diff --git a/.rat-excludes b/.rat-excludes
index c24667c18dbd..9165872b9fb2 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -86,4 +86,12 @@ local-1430917381535_2
 DESCRIPTION
 NAMESPACE
 test_support/*
+.*Rd
+help/*
+html/*
+INDEX
 .lintr
+gen-java.*
+.*avpr
+org.apache.spark.sql.sources.DataSourceRegister
+.*parquet
diff --git a/LICENSE b/LICENSE
index 42010d9f5f0e..f9e412cade34 100644
--- a/LICENSE
+++ b/LICENSE
@@ -948,6 +948,6 @@ The following components are provided under the MIT License. See project link fo
 (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org)
 (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/)
 (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt)
-(The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org)
+(The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org)
 (MIT License) jquery (https://jquery.org/license/)
 (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs)
diff --git a/R/README.md b/R/README.md
index d7d65b4f0eca..005f56da1670 100644
--- a/R/README.md
+++ b/R/README.md
@@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R
 #### Build Spark
-Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run
+Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run
 ```
   build/mvn -DskipTests -Psparkr package
 ```
diff --git a/R/create-docs.sh b/R/create-docs.sh
index 6a4687b06ecb..d2ae160b5002 100755
--- a/R/create-docs.sh
+++ b/R/create-docs.sh
@@ -39,7 +39,7 @@ pushd $FWDIR
 mkdir -p pkg/html
 pushd pkg/html
-Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")'
+Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))'
 popd
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 1edd551f8d24..59d98c9c7a64 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -34,7 +34,7 @@ LIB_DIR="$FWDIR/lib"
 mkdir -p $LIB_DIR
-pushd $FWDIR
+pushd $FWDIR > /dev/null
 # Generate Rd files if devtools is installed
 Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }'
@@ -42,4 +42,4 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo
 # Install SparkR to $LIB_DIR
 R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
-popd
+popd > /dev/null
diff --git a/R/pkg/.lintr b/R/pkg/.lintr
index b10ebd35c4ca..038236fc149e 100644
--- a/R/pkg/.lintr
+++ b/R/pkg/.lintr
@@ -1,2 +1,2 @@
-linters: with_defaults(line_length_linter(100), camel_case_linter = NULL)
+linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE))
 exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R")
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index efc85bbc4b31..a3a16c42a621 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -1,7 +1,7 @@
 Package: SparkR
 Type: Package
 Title: R frontend for
Spark -Version: 1.4.0 +Version: 1.6.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman @@ -29,7 +29,8 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'functions.R' + 'mllib.R' 'serialize.R' 'sparkR.R' 'utils.R' - 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f857222452d..9d3963070643 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,11 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# MLlib integration +exportMethods("glm", + "predict", + "summary") + # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", @@ -22,7 +27,9 @@ exportMethods("arrange", "collect", "columns", "count", + "crosstab", "describe", + "dim", "distinct", "dropna", "dtypes", @@ -39,11 +46,16 @@ exportMethods("arrange", "isLocal", "join", "limit", - "orderBy", + "merge", "mutate", + "na.omit", "names", + "ncol", + "nrow", + "orderBy", "persist", "printSchema", + "rbind", "registerTempTable", "rename", "repartition", @@ -57,9 +69,13 @@ exportMethods("arrange", "selectExpr", "show", "showDF", + "subset", "summarize", + "summary", "take", + "transform", "unionAll", + "unique", "unpersist", "where", "withColumn", @@ -68,58 +84,139 @@ exportMethods("arrange", exportClasses("Column") -exportMethods("abs", +exportMethods("%in%", + "abs", "acos", + "add_months", "alias", "approxCountDistinct", "asc", + "ascii", "asin", "atan", "atan2", "avg", + "base64", + "between", + "bin", + "bitwiseNOT", "cast", "cbrt", + "ceil", "ceiling", + "concat", + "concat_ws", "contains", + "conv", "cos", "cosh", + "count", "countDistinct", + "crc32", + "date_add", + "date_format", + "date_sub", + "datediff", + "dayofmonth", + "dayofyear", "desc", "endsWith", "exp", + "explode", "expm1", + "expr", + "factorial", + "first", "floor", + "format_number", + "format_string", + "from_unixtime", + "from_utc_timestamp", "getField", "getItem", + "greatest", + "hex", + "hour", "hypot", + "ifelse", + "initcap", + "instr", + "isNaN", "isNotNull", "isNull", "last", + "last_day", + "least", + "length", + "levenshtein", "like", + "lit", + "locate", "log", "log10", "log1p", + "log2", "lower", + "lpad", + "ltrim", "max", + "md5", "mean", "min", + "minute", + "month", + "months_between", "n", "n_distinct", + "nanvl", + "negate", + "next_day", + "otherwise", + "pmod", + "quarter", + "rand", + "randn", + "regexp_extract", + "regexp_replace", + "reverse", "rint", "rlike", + "round", + "rpad", + "rtrim", + "second", + "sha1", + "sha2", + "shiftLeft", + "shiftRight", + "shiftRightUnsigned", "sign", + "signum", "sin", "sinh", + "size", + "soundex", "sqrt", "startsWith", "substr", + "substring_index", "sum", "sumDistinct", "tan", "tanh", "toDegrees", "toRadians", - "upper") + "to_date", + "to_utc_timestamp", + "translate", + "trim", + "unbase64", + "unhex", + "unix_timestamp", + "upper", + "weekofyear", + "when", + "year") exportClasses("GroupedData") exportMethods("agg") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0af5cb8881e3..c3c189348733 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -27,9 +27,10 @@ setOldClass("jobj") #' \code{jsonFile}, \code{table} etc. 
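Note on the new MLlib exports: the NAMESPACE above starts exporting `glm`, `predict`, and `summary`, backed by the `mllib.R` file added to Collate but not included in this excerpt. A minimal usage sketch, under the assumption that `mllib.R` provides the R-formula interface over Spark ML generalized linear models that ships with this release; the column names below are made up for illustration:

```
sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)

# hypothetical training data with numeric columns y, x1, x2
training <- createDataFrame(sqlContext,
                            data.frame(y = rnorm(10), x1 = rnorm(10), x2 = rnorm(10)))

model <- glm(y ~ x1 + x2, data = training, family = "gaussian")
summary(model)                     # coefficients of the fitted model
predictions <- predict(model, training)   # a DataFrame with a prediction column
head(predictions)
```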
#' @rdname DataFrame #' @seealso jsonFile, table +#' @docType class #' -#' @param env An R environment that stores bookkeeping states of the DataFrame -#' @param sdf A Java object reference to the backing Scala DataFrame +#' @slot env An R environment that stores bookkeeping states of the DataFrame +#' @slot sdf A Java object reference to the backing Scala DataFrame #' @export setClass("DataFrame", slots = list(env = "environment", @@ -38,7 +39,7 @@ setClass("DataFrame", setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { .Object@env <- new.env() .Object@env$isCached <- isCached - + .Object@sdf <- sdf .Object }) @@ -55,12 +56,13 @@ dataFrame <- function(sdf, isCached = FALSE) { ############################ DataFrame Methods ############################################## #' Print Schema of a DataFrame -#' +#' #' Prints out the schema in tree format -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname printSchema +#' @name printSchema #' @export #' @examples #'\dontrun{ @@ -78,12 +80,13 @@ setMethod("printSchema", }) #' Get schema object -#' +#' #' Returns the schema of this DataFrame as a structType object. -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname schema +#' @name schema #' @export #' @examples #'\dontrun{ @@ -100,12 +103,13 @@ setMethod("schema", }) #' Explain -#' +#' #' Print the logical and physical Catalyst plans to the console for debugging. -#' +#' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. #' @rdname explain +#' @name explain #' @export #' @examples #'\dontrun{ @@ -135,6 +139,7 @@ setMethod("explain", #' @param x A SparkSQL DataFrame #' #' @rdname isLocal +#' @name isLocal #' @export #' @examples #'\dontrun{ @@ -158,6 +163,7 @@ setMethod("isLocal", #' @param numRows The number of rows to print. Defaults to 20. 
#' #' @rdname showDF +#' @name showDF #' @export #' @examples #'\dontrun{ @@ -169,8 +175,8 @@ setMethod("isLocal", #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) @@ -181,6 +187,7 @@ setMethod("showDF", #' @param x A SparkSQL DataFrame #' #' @rdname show +#' @name show #' @export #' @examples #'\dontrun{ @@ -200,12 +207,13 @@ setMethod("show", "DataFrame", }) #' DataTypes -#' +#' #' Return all column names and their data types as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname dtypes +#' @name dtypes #' @export #' @examples #'\dontrun{ @@ -224,12 +232,14 @@ setMethod("dtypes", }) #' Column names -#' +#' #' Return all column names as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname columns +#' @name columns +#' @aliases names #' @export #' @examples #'\dontrun{ @@ -248,21 +258,33 @@ setMethod("columns", }) #' @rdname columns -#' @aliases names,DataFrame,function-method +#' @name names setMethod("names", signature(x = "DataFrame"), function(x) { columns(x) }) +#' @rdname columns +#' @name names<- +setMethod("names<-", + signature(x = "DataFrame"), + function(x, value) { + if (!is.null(value)) { + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + } + }) + #' Register Temporary Table -#' +#' #' Registers a DataFrame as a Temporary Table in the SQLContext -#' +#' #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table -#' +#' #' @rdname registerTempTable +#' @name registerTempTable #' @export #' @examples #'\dontrun{ @@ -289,6 +311,7 @@ setMethod("registerTempTable", #' the existing rows in the table. #' #' @rdname insertInto +#' @name insertInto #' @export #' @examples #'\dontrun{ @@ -306,12 +329,13 @@ setMethod("insertInto", }) #' Cache -#' +#' #' Persist with the default storage level (MEMORY_ONLY). -#' +#' #' @param x A SparkSQL DataFrame -#' -#' @rdname cache-methods +#' +#' @rdname cache +#' @name cache #' @export #' @examples #'\dontrun{ @@ -337,6 +361,7 @@ setMethod("cache", #' #' @param x The DataFrame to persist #' @rdname persist +#' @name persist #' @export #' @examples #'\dontrun{ @@ -362,6 +387,7 @@ setMethod("persist", #' @param x The DataFrame to unpersist #' @param blocking Whether to block until all blocks are deleted #' @rdname unpersist-methods +#' @name unpersist #' @export #' @examples #'\dontrun{ @@ -387,6 +413,7 @@ setMethod("unpersist", #' @param x A SparkSQL DataFrame #' @param numPartitions The number of partitions to use. 
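The hunks above add a `truncate` argument to `showDF` and a `names<-` replacement method that delegates to the Scala-side `toDF`. A small sketch, assuming an existing DataFrame `df` with two columns `name` and `age`:

```
showDF(df, numRows = 5, truncate = FALSE)   # print full cell contents instead of truncating

names(df) <- c("full_name", "age")          # rename all columns at once via toDF
columns(df)                                 # c("full_name", "age")
```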
#' @rdname repartition +#' @name repartition #' @export #' @examples #'\dontrun{ @@ -400,7 +427,7 @@ setMethod("repartition", signature(x = "DataFrame", numPartitions = "numeric"), function(x, numPartitions) { sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) - dataFrame(sdf) + dataFrame(sdf) }) # toJSON @@ -436,6 +463,7 @@ setMethod("toJSON", #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved #' @rdname saveAsParquetFile +#' @name saveAsParquetFile #' @export #' @examples #'\dontrun{ @@ -457,6 +485,7 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkSQL DataFrame #' @rdname distinct +#' @name distinct #' @export #' @examples #'\dontrun{ @@ -473,6 +502,19 @@ setMethod("distinct", dataFrame(sdf) }) +#' @title Distinct rows in a DataFrame +# +#' @description Returns a new DataFrame containing distinct rows in this DataFrame +#' +#' @rdname unique +#' @name unique +#' @aliases distinct +setMethod("unique", + signature(x = "DataFrame"), + function(x) { + distinct(x) + }) + #' Sample #' #' Return a sampled subset of this DataFrame using a random seed. @@ -489,7 +531,7 @@ setMethod("distinct", #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) -#' collect(sample(df, FALSE, 0.5)) +#' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", @@ -504,7 +546,7 @@ setMethod("sample", }) #' @rdname sample -#' @aliases sample +#' @name sample_frac setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), @@ -513,12 +555,14 @@ setMethod("sample_frac", }) #' Count -#' +#' #' Returns the number of rows in a DataFrame -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname count +#' @name count +#' @aliases nrow #' @export #' @examples #'\dontrun{ @@ -534,13 +578,67 @@ setMethod("count", callJMethod(x@sdf, "count") }) +#' @title Number of rows for a DataFrame +#' @description Returns number of rows in a DataFrames +#' +#' @name nrow +#' +#' @rdname nrow +#' @aliases count +setMethod("nrow", + signature(x = "DataFrame"), + function(x) { + count(x) + }) + +#' Returns the number of columns in a DataFrame +#' +#' @param x a SparkSQL DataFrame +#' +#' @rdname ncol +#' @name ncol +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' ncol(df) +#' } +setMethod("ncol", + signature(x = "DataFrame"), + function(x) { + length(columns(x)) + }) + +#' Returns the dimentions (number of rows and columns) of a DataFrame +#' @param x a SparkSQL DataFrame +#' +#' @rdname dim +#' @name dim +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' dim(df) +#' } +setMethod("dim", + signature(x = "DataFrame"), + function(x) { + c(count(x), ncol(x)) + }) + #' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. #' #' @param x A SparkSQL DataFrame #' @param stringsAsFactors (Optional) A logical indicating whether or not string columns #' should be converted to factors. FALSE by default. 
- -#' @rdname collect-methods +#' @rdname collect +#' @name collect #' @export #' @examples #'\dontrun{ @@ -554,28 +652,60 @@ setMethod("count", setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - # listCols is a list of raw vectors, one per column - listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) - cols <- lapply(listCols, function(col) { - objRaw <- rawConnection(col) - numRows <- readInt(objRaw) - col <- readCol(objRaw, numRows) - close(objRaw) - col - }) - names(cols) <- columns(x) - do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) - }) + names <- columns(x) + ncol <- length(names) + if (ncol <= 0) { + # empty data.frame with 0 columns and 0 rows + data.frame() + } else { + # listCols is a list of columns + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + stopifnot(length(listCols) == ncol) + + # An empty data.frame with 0 columns and number of rows as collected + nrow <- length(listCols[[1]]) + if (nrow <= 0) { + df <- data.frame() + } else { + df <- data.frame(row.names = 1 : nrow) + } + + # Append columns one by one + for (colIndex in 1 : ncol) { + # Note: appending a column of list type into a data.frame so that + # data of complex type can be held. But getting a cell from a column + # of list type returns a list instead of a vector. So for columns of + # non-complex type, append them as vector. + col <- listCols[[colIndex]] + if (length(col) <= 0) { + df[[names[colIndex]]] <- col + } else { + # TODO: more robust check on column of primitive types + vec <- do.call(c, col) + if (class(vec) != "list") { + df[[names[colIndex]]] <- vec + } else { + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. + df[[names[colIndex]]] <- col + } + } + } + df + } + }) #' Limit -#' +#' #' Limit the resulting DataFrame to the number of rows specified. -#' +#' #' @param x A SparkSQL DataFrame #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. -#' +#' #' @rdname limit +#' @name limit #' @export #' @examples #' \dontrun{ @@ -593,8 +723,9 @@ setMethod("limit", }) #' Take the first NUM rows of a DataFrame and return a the results as a data.frame -#' +#' #' @rdname take +#' @name take #' @export #' @examples #'\dontrun{ @@ -613,8 +744,8 @@ setMethod("take", #' Head #' -#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, -#' then head() returns the first 6 rows in keeping with the current data.frame +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame #' convention in R. #' #' @param x A SparkSQL DataFrame @@ -622,6 +753,7 @@ setMethod("take", #' @return A data.frame #' #' @rdname head +#' @name head #' @export #' @examples #'\dontrun{ @@ -643,6 +775,7 @@ setMethod("head", #' @param x A SparkSQL DataFrame #' #' @rdname first +#' @name first #' @export #' @examples #'\dontrun{ @@ -658,12 +791,12 @@ setMethod("first", take(x, 1) }) -# toRDD() -# +# toRDD +# # Converts a Spark DataFrame to an RDD while preserving column names. 
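The rewritten `collect` above builds the local data.frame column by column so that columns of complex Spark SQL types can be carried as list columns, while columns of primitive types still come back as atomic vectors. An illustrative sketch of the resulting behaviour (assumes `sqlContext` exists):

```
df <- createDataFrame(sqlContext, data.frame(x = 1:3, y = c("a", "b", "c")))
localDF <- collect(df)
class(localDF$x)   # atomic vector, built with do.call(c, col)
class(localDF$y)   # "character"
# a column of complex Spark type (e.g. an array) would instead be kept as a list column
```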
-# +# # @param x A Spark DataFrame -# +# # @rdname DataFrame # @export # @examples @@ -695,6 +828,7 @@ setMethod("toRDD", #' @seealso GroupedData #' @aliases group_by #' @rdname groupBy +#' @name groupBy #' @export #' @examples #' \dontrun{ @@ -709,16 +843,16 @@ setMethod("groupBy", function(x, ...) { cols <- list(...) if (length(cols) >= 1 && class(cols[[1]]) == "character") { - sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1]) } else { jcol <- lapply(cols, function(c) { c@jc }) - sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + sgd <- callJMethod(x@sdf, "groupBy", jcol) } groupedData(sgd) }) #' @rdname groupBy -#' @aliases group_by +#' @name group_by setMethod("group_by", signature(x = "DataFrame"), function(x, ...) { @@ -730,7 +864,8 @@ setMethod("group_by", #' Compute aggregates by specifying a list of columns #' #' @param x a DataFrame -#' @rdname DataFrame +#' @rdname agg +#' @name agg #' @aliases summarize #' @export setMethod("agg", @@ -739,8 +874,8 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @rdname DataFrame -#' @aliases agg +#' @rdname agg +#' @name summarize setMethod("summarize", signature(x = "DataFrame"), function(x, ...) { @@ -816,12 +951,14 @@ getColumn <- function(x, c) { } #' @rdname select +#' @name $ setMethod("$", signature(x = "DataFrame"), function(x, name) { getColumn(x, name) }) #' @rdname select +#' @name $<- setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { stopifnot(class(value) == "Column" || is.null(value)) @@ -848,8 +985,11 @@ setMethod("$<-", signature(x = "DataFrame"), x }) -#' @rdname select -setMethod("[[", signature(x = "DataFrame"), +setClassUnion("numericOrcharacter", c("numeric", "character")) + +#' @rdname subset +#' @name [[ +setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), function(x, i) { if (is.numeric(i)) { cols <- columns(x) @@ -858,7 +998,8 @@ setMethod("[[", signature(x = "DataFrame"), getColumn(x, i) }) -#' @rdname select +#' @rdname subset +#' @name [ setMethod("[", signature(x = "DataFrame", i = "missing"), function(x, i, j, ...) { if (is.numeric(j)) { @@ -871,6 +1012,51 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), select(x, j) }) +#' @rdname subset +#' @name [ +setMethod("[", signature(x = "DataFrame", i = "Column"), + function(x, i, j, ...) { + # It could handle i as "character" but it seems confusing and not required + # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html + filtered <- filter(x, i) + if (!missing(j)) { + filtered[, j, ...] 
+ } else { + filtered + } + }) + +#' Subset +#' +#' Return subsets of DataFrame according to given conditions +#' @param x A DataFrame +#' @param subset A logical expression to filter on rows +#' @param select expression for the single Column or a list of columns to select from the DataFrame +#' @return A new DataFrame containing only the rows that meet the condition with selected columns +#' @export +#' @rdname subset +#' @name subset +#' @aliases [ +#' @family subsetting functions +#' @examples +#' \dontrun{ +#' # Columns can be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' df[,c("name", "age")] +#' # Or to filter rows +#' df[df$age > 20,] +#' # DataFrame can be subset on both rows and Columns +#' df[df$name == "Smith", c(1,2)] +#' df[df$age %in% c(19, 30), 1:2] +#' subset(df, df$age %in% c(19, 30), 1:2) +#' subset(df, df$age %in% c(19), select = c(1,2)) +#' } +setMethod("subset", signature(x = "DataFrame"), + function(x, subset, select, ...) { + x[subset, select, ...] + }) + #' Select #' #' Selects a set of columns with names or Column expressions. @@ -879,6 +1065,8 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), #' @return A new DataFrame with selected columns #' @export #' @rdname select +#' @name select +#' @family subsetting functions #' @examples #' \dontrun{ #' select(df, "*") @@ -886,15 +1074,12 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), #' select(df, df$name, df$age + 1) #' select(df, c("col1", "col2")) #' select(df, list(df$name, df$age + 1)) -#' # Columns can also be selected using `[[` and `[` -#' df[[2]] == df[["age"]] -#' df[,2] == df[,"age"] #' # Similar to R data frames columns can also be selected using `$` #' df$age #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { - sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "select", col, list(...)) dataFrame(sdf) }) @@ -905,7 +1090,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "select", jcols) dataFrame(sdf) }) @@ -921,7 +1106,7 @@ setMethod("select", col(c)@jc } }) - sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + sdf <- callJMethod(x@sdf, "select", cols) dataFrame(sdf) }) @@ -934,6 +1119,7 @@ setMethod("select", #' @param ... Additional expressions #' @return A DataFrame #' @rdname selectExpr +#' @name selectExpr #' @export #' @examples #'\dontrun{ @@ -947,7 +1133,7 @@ setMethod("selectExpr", signature(x = "DataFrame", expr = "character"), function(x, expr, ...) { exprList <- list(expr, ...) - sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + sdf <- callJMethod(x@sdf, "selectExpr", exprList) dataFrame(sdf) }) @@ -960,6 +1146,8 @@ setMethod("selectExpr", #' @param col A Column expression. #' @return A DataFrame with the new column added. #' @rdname withColumn +#' @name withColumn +#' @aliases mutate transform #' @export #' @examples #'\dontrun{ @@ -979,11 +1167,12 @@ setMethod("withColumn", #' #' Return a new DataFrame with the specified columns added. #' -#' @param x A DataFrame +#' @param .data A DataFrame #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. 
#' @rdname withColumn -#' @aliases withColumn +#' @name mutate +#' @aliases withColumn transform #' @export #' @examples #'\dontrun{ @@ -993,10 +1182,12 @@ setMethod("withColumn", #' df <- jsonFile(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 +#' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) #' } setMethod("mutate", - signature(x = "DataFrame"), - function(x, ...) { + signature(.data = "DataFrame"), + function(.data, ...) { + x <- .data cols <- list(...) stopifnot(length(cols) > 0) stopifnot(class(cols[[1]]) == "Column") @@ -1011,6 +1202,16 @@ setMethod("mutate", do.call(select, c(x, x$"*", cols)) }) +#' @export +#' @rdname withColumn +#' @name transform +#' @aliases withColumn mutate +setMethod("transform", + signature(`_data` = "DataFrame"), + function(`_data`, ...) { + mutate(`_data`, ...) + }) + #' WithColumnRenamed #' #' Rename an existing column in a DataFrame. @@ -1020,6 +1221,7 @@ setMethod("mutate", #' @param newCol The new column name. #' @return A DataFrame with the column name changed. #' @rdname withColumnRenamed +#' @name withColumnRenamed #' @export #' @examples #'\dontrun{ @@ -1050,6 +1252,7 @@ setMethod("withColumnRenamed", #' @param newCol A named pair of the form new_column_name = existing_column #' @return A DataFrame with the column name changed. #' @rdname withColumnRenamed +#' @name rename #' @aliases withColumnRenamed #' @export #' @examples @@ -1091,6 +1294,8 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param ... Additional sorting fields #' @return A DataFrame where all elements are sorted. #' @rdname arrange +#' @name arrange +#' @aliases orderby #' @export #' @examples #'\dontrun{ @@ -1106,18 +1311,18 @@ setMethod("arrange", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col, ...) { if (class(col) == "character") { - sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "sort", col, list(...)) } else if (class(col) == "Column") { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "sort", jcols) } dataFrame(sdf) }) #' @rdname arrange -#' @aliases orderBy,DataFrame,function-method +#' @name orderby setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1133,6 +1338,8 @@ setMethod("orderBy", #' or a string containing a SQL statement #' @return A DataFrame containing only the rows that meet the condition. #' @rdname filter +#' @name filter +#' @family subsetting functions #' @export #' @examples #'\dontrun{ @@ -1154,7 +1361,7 @@ setMethod("filter", }) #' @rdname filter -#' @aliases where,DataFrame,function-method +#' @name where setMethod("where", signature(x = "DataFrame", condition = "characterOrColumn"), function(x, condition) { @@ -1167,12 +1374,13 @@ setMethod("where", #' #' @param x A Spark DataFrame #' @param y A Spark DataFrame -#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a #' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join #' @param joinType The type of join to perform. The following join types are available: #' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. 
#' @rdname join +#' @name join #' @export #' @examples #'\dontrun{ @@ -1205,6 +1413,16 @@ setMethod("join", dataFrame(sdf) }) +#' @rdname merge +#' @name merge +#' @aliases join +setMethod("merge", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL, ...) { + join(x, y, joinExpr, joinType) + }) + + #' UnionAll #' #' Return a new DataFrame containing the union of rows in this DataFrame @@ -1215,6 +1433,7 @@ setMethod("join", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. #' @rdname unionAll +#' @name unionAll #' @export #' @examples #'\dontrun{ @@ -1231,6 +1450,23 @@ setMethod("unionAll", dataFrame(unioned) }) +#' @title Union two or more DataFrames +# +#' @description Returns a new DataFrame containing rows of all parameters. +# +#' @rdname rbind +#' @name rbind +#' @aliases unionAll +setMethod("rbind", + signature(... = "DataFrame"), + function(x, ..., deparse.level = 1) { + if (nargs() == 3) { + unionAll(x, ...) + } else { + unionAll(x, Recall(..., deparse.level = 1)) + } + }) + #' Intersect #' #' Return a new DataFrame containing rows only in both this DataFrame @@ -1240,6 +1476,7 @@ setMethod("unionAll", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the intersect. #' @rdname intersect +#' @name intersect #' @export #' @examples #'\dontrun{ @@ -1265,6 +1502,7 @@ setMethod("intersect", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the except operation. #' @rdname except +#' @name except #' @export #' @examples #'\dontrun{ @@ -1303,7 +1541,9 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' -#' @rdname write.df +#' @rdname write.df +#' @name write.df +#' @aliases saveDF #' @export #' @examples #'\dontrun{ @@ -1314,7 +1554,7 @@ setMethod("except", #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -1322,22 +1562,24 @@ setMethod("write.df", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] = path + options[["path"]] <- path } callJMethod(df@sdf, "save", source, jmode, options) }) #' @rdname write.df -#' @aliases saveDF +#' @name saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) 
}) @@ -1365,6 +1607,7 @@ setMethod("saveDF", #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' #' @rdname saveAsTable +#' @name saveAsTable #' @export #' @examples #'\dontrun{ @@ -1375,8 +1618,8 @@ setMethod("saveDF", #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", - signature(df = "DataFrame", tableName = 'character', source = 'character', - mode = 'character'), + signature(df = "DataFrame", tableName = "character", source = "character", + mode = "character"), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -1384,9 +1627,11 @@ setMethod("saveAsTable", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) @@ -1401,7 +1646,9 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame -#' @rdname describe +#' @rdname describe +#' @name describe +#' @aliases summary #' @export #' @examples #'\dontrun{ @@ -1417,19 +1664,33 @@ setMethod("describe", signature(x = "DataFrame", col = "character"), function(x, col, ...) { colList <- list(col, ...) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) #' @rdname describe +#' @name describe setMethod("describe", signature(x = "DataFrame"), function(x) { colList <- as.list(c(columns(x))) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) +#' @title Summary +#' +#' @description Computes statistics for numeric columns of the DataFrame +#' +#' @rdname summary +#' @name summary +setMethod("summary", + signature(x = "DataFrame"), + function(x) { + describe(x) + }) + + #' dropna #' #' Returns a new DataFrame omitting rows with null values. @@ -1444,8 +1705,10 @@ setMethod("describe", #' This overwrites the how parameter. #' @param cols Optional list of column names to consider. #' @return A DataFrame -#' +#' #' @rdname nafunctions +#' @name dropna +#' @aliases na.omit #' @export #' @examples #'\dontrun{ @@ -1465,19 +1728,20 @@ setMethod("dropna", if (is.null(minNonNulls)) { minNonNulls <- if (how == "any") { length(cols) } else { 1 } } - + naFunctions <- callJMethod(x@sdf, "na") sdf <- callJMethod(naFunctions, "drop", - as.integer(minNonNulls), listToSeq(as.list(cols))) + as.integer(minNonNulls), as.list(cols)) dataFrame(sdf) }) -#' @aliases dropna +#' @rdname nafunctions +#' @name na.omit #' @export setMethod("na.omit", - signature(x = "DataFrame"), - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - dropna(x, how, minNonNulls, cols) + signature(object = "DataFrame"), + function(object, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(object, how, minNonNulls, cols) }) #' fillna @@ -1488,17 +1752,18 @@ setMethod("na.omit", #' @param value Value to replace null values with. #' Should be an integer, numeric, character or named list. #' If the value is a named list, then cols is ignored and -#' value must be a mapping from column name (character) to +#' value must be a mapping from column name (character) to #' replacement value. 
The replacement value must be an #' integer, numeric or character. #' @param cols optional list of column names to consider. #' Columns specified in cols that do not have matching data -#' type are ignored. For example, if value is a character, and +#' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. #' @return A DataFrame -#' +#' #' @rdname nafunctions +#' @name fillna #' @export #' @examples #'\dontrun{ @@ -1515,14 +1780,14 @@ setMethod("fillna", if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { stop("value should be an integer, numeric, charactor or named list.") } - + if (class(value) == "list") { # Check column names in the named list colNames <- names(value) if (length(colNames) == 0 || !all(colNames != "")) { stop("value should be an a named list with each name being a column name.") } - + # Convert to the named list to an environment to be passed to JVM valueMap <- new.env() for (col in colNames) { @@ -1533,24 +1798,53 @@ setMethod("fillna", } valueMap[[col]] <- v } - + # When value is a named list, caller is expected not to pass in cols if (!is.null(cols)) { warning("When value is a named list, cols is ignored!") cols <- NULL } - + value <- valueMap } else if (is.integer(value)) { # Cast an integer to a numeric value <- as.numeric(value) } - + naFunctions <- callJMethod(x@sdf, "na") sdf <- if (length(cols) == 0) { callJMethod(naFunctions, "fill", value) } else { - callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + callJMethod(naFunctions, "fill", value, as.list(cols)) } dataFrame(sdf) }) + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have zero as their counts. +#' +#' @rdname statfunctions +#' @name crosstab +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlCtx, "/path/to/file.json") +#' ct = crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 051329951564..051e441d4e06 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -48,7 +48,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, # byte: The RDD stores data serialized in R. # string: The RDD stores data as strings. # row: The RDD stores the serialized rows of a DataFrame. - + # We use an environment to store mutable states inside an RDD object. 
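Two of the methods added above return their results quite differently: `summary` (an alias for `describe`) gives back another DataFrame, while `crosstab` collects a local contingency table. A short sketch; the column names are illustrative only:

```
summary(df)             # same as describe(df); still a distributed DataFrame
collect(summary(df))    # bring the summary statistics into R

ct <- crosstab(df, "title", "gender")
class(ct)               # "data.frame" -- crosstab collects the result locally
```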
# Note that R's call-by-value semantics makes modifying slots inside an # object (passed as an argument into a function, such as cache()) difficult: @@ -85,7 +85,9 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) isPipelinable <- function(rdd) { e <- rdd@env + # nolint start !(e$isCached || e$isCheckpointed) + # nolint end } if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { @@ -97,7 +99,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) # prev_serializedMode is used during the delayed computation of JRDD in getJRDD } else { pipelinedFunc <- function(partIndex, part) { - func(partIndex, prev@func(partIndex, part)) + f <- prev@func + func(partIndex, f(partIndex, part)) } .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline @@ -165,7 +168,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), serializedFuncArr, rdd@env$prev_serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } else { @@ -175,7 +177,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), rdd@env$prev_serializedMode, serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } @@ -363,7 +364,7 @@ setMethod("collectPartition", # @description # \code{collectAsMap} returns a named list as a map that contains all of the elements -# in a key-value pair RDD. +# in a key-value pair RDD. # @examples #\dontrun{ # sc <- sparkR.init() @@ -666,7 +667,7 @@ setMethod("minimum", # rdd <- parallelize(sc, 1:10) # sumRDD(rdd) # 55 #} -# @rdname sumRDD +# @rdname sumRDD # @aliases sumRDD,RDD setMethod("sumRDD", signature(x = "RDD"), @@ -843,7 +844,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- rpois(1, fraction) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { @@ -1090,11 +1091,11 @@ setMethod("sortBy", # Return: # A list of the first N elements from the RDD in the specified order. # -takeOrderedElem <- function(x, num, ascending = TRUE) { +takeOrderedElem <- function(x, num, ascending = TRUE) { if (num <= 0L) { return(list()) } - + partitionFunc <- function(part) { if (num < length(part)) { # R limitation: order works only on primitive types! @@ -1152,7 +1153,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { # @aliases takeOrdered,RDD,RDD-method setMethod("takeOrdered", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num) }) @@ -1173,7 +1174,7 @@ setMethod("takeOrdered", # @aliases top,RDD,RDD-method setMethod("top", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num, FALSE) }) @@ -1181,7 +1182,7 @@ setMethod("top", # # Aggregate the elements of each partition, and then the results for all the # partitions, using a given associative function and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param op An associative function for the folding operation. @@ -1207,7 +1208,7 @@ setMethod("fold", # # Aggregate the elements of each partition, and then the results for all the # partitions, using given combine functions and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param seqOp A function to aggregate the RDD elements. 
It may return a different @@ -1230,11 +1231,11 @@ setMethod("fold", # @aliases aggregateRDD,RDD,RDD-method setMethod("aggregateRDD", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), - function(x, zeroValue, seqOp, combOp) { + function(x, zeroValue, seqOp, combOp) { partitionFunc <- function(part) { Reduce(seqOp, part, zeroValue) } - + partitionList <- collect(lapplyPartition(x, partitionFunc), flatten = FALSE) Reduce(combOp, partitionList, zeroValue) @@ -1263,12 +1264,12 @@ setMethod("pipeRDD", signature(x = "RDD", command = "character"), function(x, command, env = list()) { func <- function(part) { - trim.trailing.func <- function(x) { + trim_trailing_func <- function(x) { sub("[\r\n]*$", "", toString(x)) } - input <- unlist(lapply(part, trim.trailing.func)) + input <- unlist(lapply(part, trim_trailing_func)) res <- system2(command, stdout = TRUE, input = input, env = env) - lapply(res, trim.trailing.func) + lapply(res, trim_trailing_func) } lapplyPartition(x, func) }) @@ -1330,7 +1331,7 @@ setMethod("setName", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithUniqueId(rdd)) +# collect(zipWithUniqueId(rdd)) # # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #} # @rdname zipWithUniqueId @@ -1426,7 +1427,7 @@ setMethod("glom", partitionFunc <- function(part) { list(part) } - + lapplyPartition(x, partitionFunc) }) @@ -1498,16 +1499,16 @@ setMethod("zipRDD", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, TRUE) }) # Cartesian product of this RDD and another one. # -# Return the Cartesian product of this RDD and another one, -# that is, the RDD of all pairs of elements (a, b) where a +# Return the Cartesian product of this RDD and another one, +# that is, the RDD of all pairs of elements (a, b) where a # is in this and b is in other. -# +# # @param x An RDD. # @param other An RDD. # @return A new RDD which is the Cartesian product of these two RDDs. @@ -1515,7 +1516,7 @@ setMethod("zipRDD", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, 1:2) -# sortByKey(cartesian(rdd, rdd)) +# sortByKey(cartesian(rdd, rdd)) # # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) #} # @rdname cartesian @@ -1528,7 +1529,7 @@ setMethod("cartesian", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, FALSE) }) @@ -1598,11 +1599,11 @@ setMethod("intersection", # Zips an RDD's partitions with one (or more) RDD(s). # Same as zipPartitions in Spark. -# +# # @param ... RDDs to be zipped. # @param func A function to transform zipped partitions. -# @return A new RDD by applying a function to the zipped partitions. -# Assumes that all the RDDs have the *same number of partitions*, but +# @return A new RDD by applying a function to the zipped partitions. +# Assumes that all the RDDs have the *same number of partitions*, but # does *not* require them to have the same number of elements in each partition. 
# @examples #\dontrun{ @@ -1610,7 +1611,7 @@ setMethod("intersection", # rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 # rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 # rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -# collect(zipPartitions(rdd1, rdd2, rdd3, +# collect(zipPartitions(rdd1, rdd2, rdd3, # func = function(x, y, z) { list(list(x, y, z))} )) # # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #} @@ -1627,7 +1628,7 @@ setMethod("zipPartitions", if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } - + rrdds <- lapply(rrdds, function(rdd) { mapPartitionsWithIndex(rdd, function(partIndex, part) { print(length(part)) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 22a4b5bf86eb..1c58fd96d750 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -41,15 +41,12 @@ infer_type <- function(x) { if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) names <- names(x) if (is.null(names)) { - list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { # StructType types <- lapply(x, infer_type) @@ -59,7 +56,7 @@ infer_type <- function(x) { do.call(structType, fields) } } else if (length(x) > 1) { - list(type = "array", elementType = type, containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { type } @@ -86,7 +83,9 @@ infer_type <- function(x) { createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD - schema <- names(data) + if (is.null(schema)) { + schema <- names(data) + } n <- nrow(data) m <- ncol(data) # get rid of factor type @@ -182,7 +181,7 @@ setMethod("toDF", signature(x = "RDD"), #' Create a DataFrame from a JSON file. #' -#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' Loads a JSON file (one object per line), returning the result as a DataFrame #' It goes through the entire dataset once to determine the schema. #' #' @param sqlContext SQLContext to use @@ -199,7 +198,7 @@ setMethod("toDF", signature(x = "RDD"), jsonFile <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- normalizePath(path) + path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") sdf <- callJMethod(sqlContext, "jsonFile", path) @@ -238,7 +237,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' Create a DataFrame from a Parquet file. -#' +#' #' Loads a Parquet file, returning the result as a DataFrame. #' #' @param sqlContext SQLContext to use @@ -249,7 +248,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { # TODO: Implement saveasParquetFile and write examples for both parquetFile <- function(sqlContext, ...) 
{ # Allow the user to have a more flexible definiton of the text file path - paths <- lapply(list(...), normalizePath) + paths <- lapply(list(...), function(x) suppressWarnings(normalizePath(x))) sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } @@ -278,7 +277,7 @@ sql <- function(sqlContext, sqlQuery) { } #' Create a DataFrame from a SparkSQL Table -#' +#' #' Returns the specified Table as a DataFrame. The Table must have already been registered #' in the SQLContext. #' @@ -298,7 +297,7 @@ sql <- function(sqlContext, sqlQuery) { table <- function(sqlContext, tableName) { sdf <- callJMethod(sqlContext, "table", tableName) - dataFrame(sdf) + dataFrame(sdf) } @@ -352,7 +351,7 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' Cache Table -#' +#' #' Caches the specified table in-memory. #' #' @param sqlContext SQLContext to use @@ -370,11 +369,11 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' } cacheTable <- function(sqlContext, tableName) { - callJMethod(sqlContext, "cacheTable", tableName) + callJMethod(sqlContext, "cacheTable", tableName) } #' Uncache Table -#' +#' #' Removes the specified table from the in-memory cache. #' #' @param sqlContext SQLContext to use @@ -455,7 +454,7 @@ dropTempTable <- function(sqlContext, tableName) { read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -504,7 +503,7 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 2fb6fae55f28..49162838b8d1 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) { # TODO: check the status code to output error information returnStatus <- readInt(conn) - stopifnot(returnStatus == 0) + if (returnStatus != 0) { + stop(readString(conn)) + } readObject(conn) } diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 23dc38780716..2403925b267c 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -27,9 +27,9 @@ # @description Broadcast variables can be created using the broadcast # function from a \code{SparkContext}. # @rdname broadcast-class -# @seealso broadcast +# @seealso broadcast # -# @param id Id of the backing Spark broadcast variable +# @param id Id of the backing Spark broadcast variable # @export setClass("Broadcast", slots = list(id = "character")) @@ -68,7 +68,7 @@ setMethod("value", # variable on workers. Not intended for use outside the package. 
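With the `invokeJava` change above, a non-zero return status no longer trips a bare `stopifnot`; the error message written back by the JVM is read from the connection and raised as an ordinary R error. Backend failures can therefore be handled with standard condition handling, as in this sketch (the table name is made up):

```
result <- tryCatch(
  sql(sqlContext, "SELECT * FROM table_that_does_not_exist"),
  error = function(e) conditionMessage(e)   # the JVM-side error text surfaces here
)
```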
# # @rdname broadcast-internal -# @seealso broadcast, value +# @seealso broadcast, value # @param bcastId The id of broadcast variable to set # @param value The value to be set diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 1281c41213e3..c811d1dac3bd 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) { con } -launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { +determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { - sparkSubmitBinName = "spark-submit" + sparkSubmitBinName <- "spark-submit" } else { - sparkSubmitBinName = "spark-submit.cmd" + sparkSubmitBinName <- "spark-submit.cmd" + } + sparkSubmitBinName +} + +generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + if (jars != "") { + jars <- paste("--jars", jars) + } + + if (!identical(packages, "")) { + packages <- paste("--packages", packages) } + combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") + combinedArgs +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) } else { sparkSubmitBin <- sparkSubmitBinName } - - if (jars != "") { - jars <- paste("--jars", jars) - } - - combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 80e92d3105a3..42e9d12179db 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -24,10 +24,9 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame column #' @description The column class supports unary, binary operations on DataFrame columns - #' @rdname column #' -#' @param jc reference to JVM DataFrame column +#' @slot jc reference to JVM DataFrame column #' @export setClass("Column", slots = list(jc = "jobj")) @@ -46,6 +45,7 @@ col <- function(x) { } #' @rdname show +#' @name show setMethod("show", "Column", function(object) { cat("Column", callJMethod(object@jc, "toString"), "\n") @@ -60,12 +60,6 @@ operators <- list( ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") -functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct", - "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", - "expm1", "floor", "log", "log10", "log1p", "rint", "sign", - "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions<- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -111,33 +105,6 @@ createColumnFunction2 <- function(name) { }) } -createStaticFunction <- function(name) { - setMethod(name, - signature(x = "Column"), - function(x) { - if (name == "ceiling") { - name <- "ceil" - } - if (name == "sign") { - name <- "signum" - } - jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) - column(jc) - }) -} - -createBinaryMathfunctions <- function(name) { - setMethod(name, - signature(y = "Column"), - function(y, x) { - if (class(x) == "Column") { - x <- x@jc - } - jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) - 
column(jc) - }) -} - createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -148,12 +115,6 @@ createMethods <- function() { for (name in column_functions2) { createColumnFunction2(name) } - for (x in functions) { - createStaticFunction(x) - } - for (name in binary_mathfunctions) { - createBinaryMathfunctions(name) - } } createMethods() @@ -161,8 +122,11 @@ createMethods() #' alias #' #' Set a new name for a column - -#' @rdname column +#' +#' @rdname alias +#' @name alias +#' @family colum_func +#' @export setMethod("alias", signature(object = "Column"), function(object, data) { @@ -177,7 +141,9 @@ setMethod("alias", #' #' An expression that returns a substring. #' -#' @rdname column +#' @rdname substr +#' @name substr +#' @family colum_func #' #' @param start starting position #' @param stop ending position @@ -187,12 +153,32 @@ setMethod("substr", signature(x = "Column"), column(jc) }) +#' between +#' +#' Test if the column is between the lower bound and upper bound, inclusive. +#' +#' @rdname between +#' @name between +#' @family colum_func +#' +#' @param bounds lower and upper bounds +setMethod("between", signature(x = "Column"), + function(x, bounds) { + if (is.vector(bounds) && length(bounds) == 2) { + jc <- callJMethod(x@jc, "between", bounds[1], bounds[2]) + column(jc) + } else { + stop("bounds should be a vector of lower and upper bounds") + } + }) + #' Casts the column to a different data type. #' -#' @rdname column +#' @rdname cast +#' @name cast +#' @family colum_func #' -#' @examples -#' \dontrun{ +#' @examples \dontrun{ #' cast(df$age, "string") #' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) #' } @@ -210,44 +196,38 @@ setMethod("cast", } }) -#' Approx Count Distinct +#' Match a column with given values. #' -#' @rdname column -#' @return the approximate number of distinct items in a group. -setMethod("approxCountDistinct", +#' @rdname match +#' @name %in% +#' @aliases %in% +#' @return a matched values as a result of comparing with given values. +#' @export +#' @examples +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) +#' } +setMethod("%in%", signature(x = "Column"), - function(x, rsd = 0.95) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) - column(jc) + function(x, table) { + jc <- callJMethod(x@jc, "in", as.list(table)) + return(column(jc)) }) -#' Count Distinct +#' otherwise #' -#' @rdname column -#' @return the number of distinct items in a group. -setMethod("countDistinct", - signature(x = "Column"), - function(x, ...) { - jcol <- lapply(list(...), function (x) { - x@jc - }) - jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) +#' If values in the specified column are null, returns the value. +#' Can be used in conjunction with `when` to specify a default value for expressions. +#' +#' @rdname otherwise +#' @name otherwise +#' @family colum_func +#' @export +setMethod("otherwise", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) - -#' @rdname column -#' @aliases countDistinct -setMethod("n_distinct", - signature(x = "Column"), - function(x, ...) { - countDistinct(x, ...) 
- }) - -#' @rdname column -#' @aliases count -setMethod("n", - signature(x = "Column"), - function(x) { - count(x) - }) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 43be9c904fdf..720990e1c608 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -121,7 +121,7 @@ parallelize <- function(sc, coll, numSlices = 1) { numSlices <- length(coll) sliceLen <- ceiling(length(coll) / numSlices) - slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)]) # Serialize each slice: obtain a list of raws, or a list of lists (slices) of # 2-tuples of raws diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 257b435607ce..ce88d0b071b7 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -18,11 +18,12 @@ # Utility functions to deserialize objects from Java. # Type mapping from Java to R -# +# # void -> NULL # Int -> integer # String -> character # Boolean -> logical +# Float -> double # Double -> double # Long -> double # Array[Byte] -> raw @@ -47,7 +48,9 @@ readTypedObject <- function(con, type) { "r" = readRaw(con), "D" = readDate(con), "t" = readTime(con), + "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -55,8 +58,10 @@ readTypedObject <- function(con, type) { readString <- function(con) { stringLen <- readInt(con) - string <- readBin(con, raw(), stringLen, endian = "big") - rawToChar(string) + raw <- readBin(con, raw(), stringLen, endian = "big") + string <- rawToChar(raw) + Encoding(string) <- "UTF-8" + string } readInt <- function(con) { @@ -84,8 +89,7 @@ readTime <- function(con) { as.POSIXct(t, origin = "1970-01-01") } -# We only support lists where all elements are of same type -readList <- function(con) { +readArray <- function(con) { type <- readType(con) len <- readInt(con) if (len > 0) { @@ -99,13 +103,45 @@ readList <- function(con) { } } +# Read a list. Types of each element may be different. +# Null objects are read as NA. +readList <- function(con) { + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + elem <- readObject(con) + if (is.null(elem)) { + elem <- NA + } + l[[i]] <- elem + } + l + } else { + list() + } +} + +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + readRaw <- function(con) { dataLen <- readInt(con) - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readRawLen <- function(con, dataLen) { - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readDeserialize <- function(con) { @@ -131,18 +167,19 @@ readDeserialize <- function(con) { } } -readDeserializeRows <- function(inputCon) { - # readDeserializeRows will deserialize a DataOutputStream composed of - # a list of lists. Since the DOS is one continuous stream and - # the number of rows varies, we put the readRow function in a while loop - # that termintates when the next row is empty. +readMultipleObjects <- function(inputCon) { + # readMultipleObjects will read multiple continuous objects from + # a DataOutputStream. 
There is no preceding field telling the count + # of the objects, so the number of objects varies, we try to read + # all objects in a loop until the end of the stream. data <- list() while(TRUE) { - row <- readRow(inputCon) - if (length(row) == 0) { + # If reaching the end of the stream, type returned should be "". + type <- readType(inputCon) + if (type == "") { break } - data[[length(data) + 1L]] <- row + data[[length(data) + 1L]] <- readTypedObject(inputCon, type) } data # this is a list of named lists now } @@ -154,31 +191,5 @@ readRowList <- function(obj) { # deserialize the row. rawObj <- rawConnection(obj, "r+") on.exit(close(rawObj)) - readRow(rawObj) -} - -readRow <- function(inputCon) { - numCols <- readInt(inputCon) - if (length(numCols) > 0 && numCols > 0) { - lapply(1:numCols, function(x) { - obj <- readObject(inputCon) - if (is.null(obj)) { - NA - } else { - obj - } - }) # each row is a list now - } else { - list() - } -} - -# Take a single column as Array[Byte] and deserialize it into an atomic vector -readCol <- function(inputCon, numRows) { - # sapply can not work with POSIXlt - do.call(c, lapply(1:numRows, function(x) { - value <- readObject(inputCon) - # Replace NULL with NA so we can coerce to vectors - if (is.null(value)) NA else value - })) + readObject(rawObj) } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R new file mode 100644 index 000000000000..94687edb0544 --- /dev/null +++ b/R/pkg/R/functions.R @@ -0,0 +1,1988 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#' @include generics.R column.R +NULL + +#' Creates a \code{Column} of literal value. +#' +#' The passed in object is returned directly if it is already a \linkS4class{Column}. +#' If the object is a Scala Symbol, it is converted into a \linkS4class{Column} also. +#' Otherwise, a new \linkS4class{Column} is created to represent the literal value. +#' +#' @family normal_funcs +#' @rdname lit +#' @name lit +#' @export +setMethod("lit", signature("ANY"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lit", + ifelse(class(x) == "Column", x@jc, x)) + column(jc) + }) + +#' abs +#' +#' Computes the absolute value. +#' +#' @rdname abs +#' @name abs +#' @family normal_funcs +#' @export +#' @examples \dontrun{abs(df$c)} +setMethod("abs", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "abs", x@jc) + column(jc) + }) + +#' acos +#' +#' Computes the cosine inverse of the given value; the returned angle is in the range +#' 0.0 through pi. 
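# [editor's note: illustrative sketch, not part of the patch] The wrappers in
# the new functions.R all follow the pattern shown for lit()/abs(): call the
# matching static method on org.apache.spark.sql.functions and wrap the result
# in a Column. They compose with select()/filter() like any other Column
# expression; `df` and its numeric column `c` are hypothetical.
select(df, df$c, abs(df$c), lit("tag"))
filter(df, abs(df$c) > 1)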
+#' +#' @rdname acos +#' @name acos +#' @family math_funcs +#' @export +#' @examples \dontrun{acos(df$c)} +setMethod("acos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "acos", x@jc) + column(jc) + }) + +#' approxCountDistinct +#' +#' Aggregate function: returns the approximate number of distinct items in a group. +#' +#' @rdname approxCountDistinct +#' @name approxCountDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{approxCountDistinct(df$c)} +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc) + column(jc) + }) + +#' ascii +#' +#' Computes the numeric value of the first character of the string column, and returns the +#' result as a int column. +#' +#' @rdname ascii +#' @name ascii +#' @family string_funcs +#' @export +#' @examples \dontrun{\dontrun{ascii(df$c)}} +setMethod("ascii", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ascii", x@jc) + column(jc) + }) + +#' asin +#' +#' Computes the sine inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. +#' +#' @rdname asin +#' @name asin +#' @family math_funcs +#' @export +#' @examples \dontrun{asin(df$c)} +setMethod("asin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "asin", x@jc) + column(jc) + }) + +#' atan +#' +#' Computes the tangent inverse of the given value. +#' +#' @rdname atan +#' @name atan +#' @family math_funcs +#' @export +#' @examples \dontrun{atan(df$c)} +setMethod("atan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "atan", x@jc) + column(jc) + }) + +#' avg +#' +#' Aggregate function: returns the average of the values in a group. +#' +#' @rdname avg +#' @name avg +#' @family agg_funcs +#' @export +#' @examples \dontrun{avg(df$c)} +setMethod("avg", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "avg", x@jc) + column(jc) + }) + +#' base64 +#' +#' Computes the BASE64 encoding of a binary column and returns it as a string column. +#' This is the reverse of unbase64. +#' +#' @rdname base64 +#' @name base64 +#' @family string_funcs +#' @export +#' @examples \dontrun{base64(df$c)} +setMethod("base64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "base64", x@jc) + column(jc) + }) + +#' bin +#' +#' An expression that returns the string representation of the binary value of the given long +#' column. For example, bin("12") returns "1100". +#' +#' @rdname bin +#' @name bin +#' @family math_funcs +#' @export +#' @examples \dontrun{bin(df$c)} +setMethod("bin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bin", x@jc) + column(jc) + }) + +#' bitwiseNOT +#' +#' Computes bitwise NOT. +#' +#' @rdname bitwiseNOT +#' @name bitwiseNOT +#' @family normal_funcs +#' @export +#' @examples \dontrun{bitwiseNOT(df$c)} +setMethod("bitwiseNOT", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bitwiseNOT", x@jc) + column(jc) + }) + +#' cbrt +#' +#' Computes the cube-root of the given value. 
+#' +#' @rdname cbrt +#' @name cbrt +#' @family math_funcs +#' @export +#' @examples \dontrun{cbrt(df$c)} +setMethod("cbrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cbrt", x@jc) + column(jc) + }) + +#' ceil +#' +#' Computes the ceiling of the given value. +#' +#' @rdname ceil +#' @name ceil +#' @family math_funcs +#' @export +#' @examples \dontrun{ceil(df$c)} +setMethod("ceil", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ceil", x@jc) + column(jc) + }) + +#' cos +#' +#' Computes the cosine of the given value. +#' +#' @rdname cos +#' @name cos +#' @family math_funcs +#' @export +#' @examples \dontrun{cos(df$c)} +setMethod("cos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cos", x@jc) + column(jc) + }) + +#' cosh +#' +#' Computes the hyperbolic cosine of the given value. +#' +#' @rdname cosh +#' @name cosh +#' @family math_funcs +#' @export +#' @examples \dontrun{cosh(df$c)} +setMethod("cosh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cosh", x@jc) + column(jc) + }) + +#' count +#' +#' Aggregate function: returns the number of items in a group. +#' +#' @rdname count +#' @name count +#' @family agg_funcs +#' @export +#' @examples \dontrun{count(df$c)} +setMethod("count", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "count", x@jc) + column(jc) + }) + +#' crc32 +#' +#' Calculates the cyclic redundancy check value (CRC32) of a binary column and +#' returns the value as a bigint. +#' +#' @rdname crc32 +#' @name crc32 +#' @family misc_funcs +#' @export +#' @examples \dontrun{crc32(df$c)} +setMethod("crc32", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "crc32", x@jc) + column(jc) + }) + +#' dayofmonth +#' +#' Extracts the day of the month as an integer from a given date/timestamp/string. +#' +#' @rdname dayofmonth +#' @name dayofmonth +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofmonth(df$c)} +setMethod("dayofmonth", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofmonth", x@jc) + column(jc) + }) + +#' dayofyear +#' +#' Extracts the day of the year as an integer from a given date/timestamp/string. +#' +#' @rdname dayofyear +#' @name dayofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofyear(df$c)} +setMethod("dayofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofyear", x@jc) + column(jc) + }) + +#' exp +#' +#' Computes the exponential of the given value. +#' +#' @rdname exp +#' @name exp +#' @family math_funcs +#' @export +#' @examples \dontrun{exp(df$c)} +setMethod("exp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "exp", x@jc) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' expm1 +#' +#' Computes the exponential of the given value minus one. 
+#' +#' @rdname expm1 +#' @name expm1 +#' @family math_funcs +#' @export +#' @examples \dontrun{expm1(df$c)} +setMethod("expm1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expm1", x@jc) + column(jc) + }) + +#' factorial +#' +#' Computes the factorial of the given value. +#' +#' @rdname factorial +#' @name factorial +#' @family math_funcs +#' @export +#' @examples \dontrun{factorial(df$c)} +setMethod("factorial", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "factorial", x@jc) + column(jc) + }) + +#' first +#' +#' Aggregate function: returns the first value in a group. +#' +#' @rdname first +#' @name first +#' @family agg_funcs +#' @export +#' @examples \dontrun{first(df$c)} +setMethod("first", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc) + column(jc) + }) + +#' floor +#' +#' Computes the floor of the given value. +#' +#' @rdname floor +#' @name floor +#' @family math_funcs +#' @export +#' @examples \dontrun{floor(df$c)} +setMethod("floor", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "floor", x@jc) + column(jc) + }) + +#' hex +#' +#' Computes hex value of the given column. +#' +#' @rdname hex +#' @name hex +#' @family math_funcs +#' @export +#' @examples \dontrun{hex(df$c)} +setMethod("hex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hex", x@jc) + column(jc) + }) + +#' hour +#' +#' Extracts the hours as an integer from a given date/timestamp/string. +#' +#' @rdname hour +#' @name hour +#' @family datetime_funcs +#' @export +#' @examples \dontrun{hour(df$c)} +setMethod("hour", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hour", x@jc) + column(jc) + }) + +#' initcap +#' +#' Returns a new string column by converting the first letter of each word to uppercase. +#' Words are delimited by whitespace. +#' +#' For example, "hello world" will become "Hello World". +#' +#' @rdname initcap +#' @name initcap +#' @family string_funcs +#' @export +#' @examples \dontrun{initcap(df$c)} +setMethod("initcap", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "initcap", x@jc) + column(jc) + }) + +#' isNaN +#' +#' Return true iff the column is NaN. +#' +#' @rdname isNaN +#' @name isNaN +#' @family normal_funcs +#' @export +#' @examples \dontrun{isNaN(df$c)} +setMethod("isNaN", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isNaN", x@jc) + column(jc) + }) + +#' last +#' +#' Aggregate function: returns the last value in a group. +#' +#' @rdname last +#' @name last +#' @family agg_funcs +#' @export +#' @examples \dontrun{last(df$c)} +setMethod("last", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc) + column(jc) + }) + +#' last_day +#' +#' Given a date column, returns the last day of the month which the given date belongs to. +#' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the +#' month in July 2015. 
+#' +#' @rdname last_day +#' @name last_day +#' @family datetime_funcs +#' @export +#' @examples \dontrun{last_day(df$c)} +setMethod("last_day", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last_day", x@jc) + column(jc) + }) + +#' length +#' +#' Computes the length of a given string or binary column. +#' +#' @rdname length +#' @name length +#' @family string_funcs +#' @export +#' @examples \dontrun{length(df$c)} +setMethod("length", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "length", x@jc) + column(jc) + }) + +#' log +#' +#' Computes the natural logarithm of the given value. +#' +#' @rdname log +#' @name log +#' @family math_funcs +#' @export +#' @examples \dontrun{log(df$c)} +setMethod("log", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log", x@jc) + column(jc) + }) + +#' log10 +#' +#' Computes the logarithm of the given value in base 10. +#' +#' @rdname log10 +#' @name log10 +#' @family math_funcs +#' @export +#' @examples \dontrun{log10(df$c)} +setMethod("log10", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log10", x@jc) + column(jc) + }) + +#' log1p +#' +#' Computes the natural logarithm of the given value plus one. +#' +#' @rdname log1p +#' @name log1p +#' @family math_funcs +#' @export +#' @examples \dontrun{log1p(df$c)} +setMethod("log1p", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log1p", x@jc) + column(jc) + }) + +#' log2 +#' +#' Computes the logarithm of the given column in base 2. +#' +#' @rdname log2 +#' @name log2 +#' @family math_funcs +#' @export +#' @examples \dontrun{log2(df$c)} +setMethod("log2", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log2", x@jc) + column(jc) + }) + +#' lower +#' +#' Converts a string column to lower case. +#' +#' @rdname lower +#' @name lower +#' @family string_funcs +#' @export +#' @examples \dontrun{lower(df$c)} +setMethod("lower", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "lower", x@jc) + column(jc) + }) + +#' ltrim +#' +#' Trim the spaces from left end for the specified string value. +#' +#' @rdname ltrim +#' @name ltrim +#' @family string_funcs +#' @export +#' @examples \dontrun{ltrim(df$c)} +setMethod("ltrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ltrim", x@jc) + column(jc) + }) + +#' max +#' +#' Aggregate function: returns the maximum value of the expression in a group. +#' +#' @rdname max +#' @name max +#' @family agg_funcs +#' @export +#' @examples \dontrun{max(df$c)} +setMethod("max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "max", x@jc) + column(jc) + }) + +#' md5 +#' +#' Calculates the MD5 digest of a binary column and returns the value +#' as a 32 character hex string. +#' +#' @rdname md5 +#' @name md5 +#' @family misc_funcs +#' @export +#' @examples \dontrun{md5(df$c)} +setMethod("md5", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "md5", x@jc) + column(jc) + }) + +#' mean +#' +#' Aggregate function: returns the average of the values in a group. +#' Alias for avg. 
+#' +#' @rdname mean +#' @name mean +#' @family agg_funcs +#' @export +#' @examples \dontrun{mean(df$c)} +setMethod("mean", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "mean", x@jc) + column(jc) + }) + +#' min +#' +#' Aggregate function: returns the minimum value of the expression in a group. +#' +#' @rdname min +#' @name min +#' @family agg_funcs +#' @export +#' @examples \dontrun{min(df$c)} +setMethod("min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "min", x@jc) + column(jc) + }) + +#' minute +#' +#' Extracts the minutes as an integer from a given date/timestamp/string. +#' +#' @rdname minute +#' @name minute +#' @family datetime_funcs +#' @export +#' @examples \dontrun{minute(df$c)} +setMethod("minute", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "minute", x@jc) + column(jc) + }) + +#' month +#' +#' Extracts the month as an integer from a given date/timestamp/string. +#' +#' @rdname month +#' @name month +#' @family datetime_funcs +#' @export +#' @examples \dontrun{month(df$c)} +setMethod("month", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "month", x@jc) + column(jc) + }) + +#' negate +#' +#' Unary minus, i.e. negate the expression. +#' +#' @rdname negate +#' @name negate +#' @family normal_funcs +#' @export +#' @examples \dontrun{negate(df$c)} +setMethod("negate", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "negate", x@jc) + column(jc) + }) + +#' quarter +#' +#' Extracts the quarter as an integer from a given date/timestamp/string. +#' +#' @rdname quarter +#' @name quarter +#' @family datetime_funcs +#' @export +#' @examples \dontrun{quarter(df$c)} +setMethod("quarter", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "quarter", x@jc) + column(jc) + }) + +#' reverse +#' +#' Reverses the string column and returns it as a new string column. +#' +#' @rdname reverse +#' @name reverse +#' @family string_funcs +#' @export +#' @examples \dontrun{reverse(df$c)} +setMethod("reverse", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "reverse", x@jc) + column(jc) + }) + +#' rint +#' +#' Returns the double value that is closest in value to the argument and +#' is equal to a mathematical integer. +#' +#' @rdname rint +#' @name rint +#' @family math_funcs +#' @export +#' @examples \dontrun{rint(df$c)} +setMethod("rint", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "rint", x@jc) + column(jc) + }) + +#' round +#' +#' Returns the value of the column `e` rounded to 0 decimal places. +#' +#' @rdname round +#' @name round +#' @family math_funcs +#' @export +#' @examples \dontrun{round(df$c)} +setMethod("round", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "round", x@jc) + column(jc) + }) + +#' rtrim +#' +#' Trim the spaces from right end for the specified string value. +#' +#' @rdname rtrim +#' @name rtrim +#' @family string_funcs +#' @export +#' @examples \dontrun{rtrim(df$c)} +setMethod("rtrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "rtrim", x@jc) + column(jc) + }) + +#' second +#' +#' Extracts the seconds as an integer from a given date/timestamp/string. 
+#' +#' @rdname second +#' @name second +#' @family datetime_funcs +#' @export +#' @examples \dontrun{second(df$c)} +setMethod("second", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "second", x@jc) + column(jc) + }) + +#' sha1 +#' +#' Calculates the SHA-1 digest of a binary column and returns the value +#' as a 40 character hex string. +#' +#' @rdname sha1 +#' @name sha1 +#' @family misc_funcs +#' @export +#' @examples \dontrun{sha1(df$c)} +setMethod("sha1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha1", x@jc) + column(jc) + }) + +#' signum +#' +#' Computes the signum of the given value. +#' +#' @rdname signum +#' @name signum +#' @family math_funcs +#' @export +#' @examples \dontrun{signum(df$c)} +setMethod("signum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "signum", x@jc) + column(jc) + }) + +#' sin +#' +#' Computes the sine of the given value. +#' +#' @rdname sin +#' @name sin +#' @family math_funcs +#' @export +#' @examples \dontrun{sin(df$c)} +setMethod("sin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sin", x@jc) + column(jc) + }) + +#' sinh +#' +#' Computes the hyperbolic sine of the given value. +#' +#' @rdname sinh +#' @name sinh +#' @family math_funcs +#' @export +#' @examples \dontrun{sinh(df$c)} +setMethod("sinh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sinh", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' soundex +#' +#' Return the soundex code for the specified expression. +#' +#' @rdname soundex +#' @name soundex +#' @family string_funcs +#' @export +#' @examples \dontrun{soundex(df$c)} +setMethod("soundex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "soundex", x@jc) + column(jc) + }) + +#' sqrt +#' +#' Computes the square root of the specified float value. +#' +#' @rdname sqrt +#' @name sqrt +#' @family math_funcs +#' @export +#' @examples \dontrun{sqrt(df$c)} +setMethod("sqrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sqrt", x@jc) + column(jc) + }) + +#' sum +#' +#' Aggregate function: returns the sum of all values in the expression. +#' +#' @rdname sum +#' @name sum +#' @family agg_funcs +#' @export +#' @examples \dontrun{sum(df$c)} +setMethod("sum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sum", x@jc) + column(jc) + }) + +#' sumDistinct +#' +#' Aggregate function: returns the sum of distinct values in the expression. +#' +#' @rdname sumDistinct +#' @name sumDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{sumDistinct(df$c)} +setMethod("sumDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sumDistinct", x@jc) + column(jc) + }) + +#' tan +#' +#' Computes the tangent of the given value. 
+#' +#' @rdname tan +#' @name tan +#' @family math_funcs +#' @export +#' @examples \dontrun{tan(df$c)} +setMethod("tan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tan", x@jc) + column(jc) + }) + +#' tanh +#' +#' Computes the hyperbolic tangent of the given value. +#' +#' @rdname tanh +#' @name tanh +#' @family math_funcs +#' @export +#' @examples \dontrun{tanh(df$c)} +setMethod("tanh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tanh", x@jc) + column(jc) + }) + +#' toDegrees +#' +#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. +#' +#' @rdname toDegrees +#' @name toDegrees +#' @family math_funcs +#' @export +#' @examples \dontrun{toDegrees(df$c)} +setMethod("toDegrees", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toDegrees", x@jc) + column(jc) + }) + +#' toRadians +#' +#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. +#' +#' @rdname toRadians +#' @name toRadians +#' @family math_funcs +#' @export +#' @examples \dontrun{toRadians(df$c)} +setMethod("toRadians", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toRadians", x@jc) + column(jc) + }) + +#' to_date +#' +#' Converts the column into DateType. +#' +#' @rdname to_date +#' @name to_date +#' @family datetime_funcs +#' @export +#' @examples \dontrun{to_date(df$c)} +setMethod("to_date", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_date", x@jc) + column(jc) + }) + +#' trim +#' +#' Trim the spaces from both ends for the specified string column. +#' +#' @rdname trim +#' @name trim +#' @family string_funcs +#' @export +#' @examples \dontrun{trim(df$c)} +setMethod("trim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "trim", x@jc) + column(jc) + }) + +#' unbase64 +#' +#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' This is the reverse of base64. +#' +#' @rdname unbase64 +#' @name unbase64 +#' @family string_funcs +#' @export +#' @examples \dontrun{unbase64(df$c)} +setMethod("unbase64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unbase64", x@jc) + column(jc) + }) + +#' unhex +#' +#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' and converts to the byte representation of number. +#' +#' @rdname unhex +#' @name unhex +#' @family math_funcs +#' @export +#' @examples \dontrun{unhex(df$c)} +setMethod("unhex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unhex", x@jc) + column(jc) + }) + +#' upper +#' +#' Converts a string column to upper case. +#' +#' @rdname upper +#' @name upper +#' @family string_funcs +#' @export +#' @examples \dontrun{upper(df$c)} +setMethod("upper", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "upper", x@jc) + column(jc) + }) + +#' weekofyear +#' +#' Extracts the week number as an integer from a given date/timestamp/string. 
+#' +#' @rdname weekofyear +#' @name weekofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{weekofyear(df$c)} +setMethod("weekofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "weekofyear", x@jc) + column(jc) + }) + +#' year +#' +#' Extracts the year as an integer from a given date/timestamp/string. +#' +#' @rdname year +#' @name year +#' @family datetime_funcs +#' @export +#' @examples \dontrun{year(df$c)} +setMethod("year", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "year", x@jc) + column(jc) + }) + +#' atan2 +#' +#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to +#' polar coordinates (r, theta). +#' +#' @rdname atan2 +#' @name atan2 +#' @family math_funcs +#' @export +#' @examples \dontrun{atan2(df$c, x)} +setMethod("atan2", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "atan2", y@jc, x) + column(jc) + }) + +#' datediff +#' +#' Returns the number of days from `start` to `end`. +#' +#' @rdname datediff +#' @name datediff +#' @family datetime_funcs +#' @export +#' @examples \dontrun{datediff(df$c, x)} +setMethod("datediff", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "datediff", y@jc, x) + column(jc) + }) + +#' hypot +#' +#' Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. +#' +#' @rdname hypot +#' @name hypot +#' @family math_funcs +#' @export +#' @examples \dontrun{hypot(df$c, x)} +setMethod("hypot", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "hypot", y@jc, x) + column(jc) + }) + +#' levenshtein +#' +#' Computes the Levenshtein distance of the two given string columns. +#' +#' @rdname levenshtein +#' @name levenshtein +#' @family string_funcs +#' @export +#' @examples \dontrun{levenshtein(df$c, x)} +setMethod("levenshtein", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "levenshtein", y@jc, x) + column(jc) + }) + +#' months_between +#' +#' Returns number of months between dates `date1` and `date2`. +#' +#' @rdname months_between +#' @name months_between +#' @family datetime_funcs +#' @export +#' @examples \dontrun{months_between(df$c, x)} +setMethod("months_between", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x) + column(jc) + }) + +#' nanvl +#' +#' Returns col1 if it is not NaN, or col2 if col1 is NaN. +#' hhBoth inputs should be floating point columns (DoubleType or FloatType). +#' +#' @rdname nanvl +#' @name nanvl +#' @family normal_funcs +#' @export +#' @examples \dontrun{nanvl(df$c, x)} +setMethod("nanvl", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "nanvl", y@jc, x) + column(jc) + }) + +#' pmod +#' +#' Returns the positive value of dividend mod divisor. 
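# [editor's note: illustrative sketch, not part of the patch] atan2, datediff,
# hypot, levenshtein, months_between and nanvl above share one shape: the first
# argument must be a Column, while the second may be a Column or a plain value
# (it is unwrapped to its Java handle only when it is a Column). Hypothetical
# usage, assuming date columns `start`/`end` and string columns `a`/`b` in `df`:
select(df, datediff(df$end, df$start), months_between(df$end, df$start))
select(df, levenshtein(df$a, df$b))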
+#' +#' @rdname pmod +#' @name pmod +#' @docType methods +#' @family math_funcs +#' @export +#' @examples \dontrun{pmod(df$c, x)} +setMethod("pmod", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "pmod", y@jc, x) + column(jc) + }) + + +#' Approx Count Distinct +#' +#' @family agg_funcs +#' @rdname approxCountDistinct +#' @name approxCountDistinct +#' @return the approximate number of distinct items in a group. +#' @export +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' @family agg_funcs +#' @rdname countDistinct +#' @name countDistinct +#' @return the number of distinct items in a group. +#' @export +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + jcol) + column(jc) + }) + + +#' concat +#' +#' Concatenates multiple input string columns together into a single string column. +#' +#' @family string_funcs +#' @rdname concat +#' @name concat +#' @export +setMethod("concat", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) + column(jc) + }) + +#' greatest +#' +#' Returns the greatest value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null if all parameters are null. +#' +#' @family normal_funcs +#' @rdname greatest +#' @name greatest +#' @export +setMethod("greatest", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) + column(jc) + }) + +#' least +#' +#' Returns the least value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' +#' @family normal_funcs +#' @rdname least +#' @name least +#' @export +setMethod("least", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) + column(jc) + }) + +#' ceiling +#' +#' Computes the ceiling of the given value. +#' +#' @family math_funcs +#' @rdname ceil +#' @name ceil +#' @aliases ceil +#' @export +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' sign +#' +#' Computes the signum of the given value. +#' +#' @family math_funcs +#' @rdname signum +#' @name signum +#' @aliases signum +#' @export +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' n_distinct +#' +#' Aggregate function: returns the number of distinct items in a group. +#' +#' @family agg_funcs +#' @rdname countDistinct +#' @name countDistinct +#' @aliases countDistinct +#' @export +setMethod("n_distinct", signature(x = "Column"), + function(x, ...) { + countDistinct(x, ...) + }) + +#' n +#' +#' Aggregate function: returns the number of items in a group. 
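# [editor's note: illustrative sketch, not part of the patch] countDistinct,
# n_distinct and n above are aggregate expressions, so they are normally used
# through agg()/summarize() on grouped data, while concat/greatest/least are
# variadic and take one or more Columns. Hypothetical usage, assuming columns
# `dept`, `age`, `first`, `last`, `a`, `b`, `c` in `df`:
agg(groupBy(df, "dept"), countDistinct(df$age), n(df$age))
select(df, concat(df$first, df$last), greatest(df$a, df$b, df$c))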
+#' +#' @family agg_funcs +#' @rdname count +#' @name count +#' @aliases count +#' @export +setMethod("n", signature(x = "Column"), + function(x) { + count(x) + }) + +#' date_format +#' +#' Converts a date/timestamp/string to a value of string in the format specified by the date +#' format given by the second argument. +#' +#' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All +#' pattern letters of \code{java.text.SimpleDateFormat} can be used. +#' +#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a +#' specialized implementation. +#' +#' @family datetime_funcs +#' @rdname date_format +#' @name date_format +#' @export +setMethod("date_format", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) + column(jc) + }) + +#' from_utc_timestamp +#' +#' Assumes given timestamp is UTC and converts to given timezone. +#' +#' @family datetime_funcs +#' @rdname from_utc_timestamp +#' @name from_utc_timestamp +#' @export +setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) + column(jc) + }) + +#' instr +#' +#' Locate the position of the first occurrence of substr column in the given string. +#' Returns null if either of the arguments are null. +#' +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. +#' +#' @family string_funcs +#' @rdname instr +#' @name instr +#' @export +setMethod("instr", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) + column(jc) + }) + +#' next_day +#' +#' Given a date column, returns the first date which is later than the value of the date column +#' that is on the specified day of the week. +#' +#' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first +#' Sunday after 2015-07-27. +#' +#' Day of the week parameter is case insensitive, and accepts: +#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". +#' +#' @family datetime_funcs +#' @rdname next_day +#' @name next_day +#' @export +setMethod("next_day", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) + column(jc) + }) + +#' to_utc_timestamp +#' +#' Assumes given timestamp is in given timezone and converts to UTC. +#' +#' @family datetime_funcs +#' @rdname to_utc_timestamp +#' @name to_utc_timestamp +#' @export +setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) + column(jc) + }) + +#' add_months +#' +#' Returns the date that is numMonths after startDate. 
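# [editor's note: illustrative sketch, not part of the patch] date_format,
# from_utc_timestamp, instr, next_day and to_utc_timestamp above all pair a
# Column with a plain character argument. Hypothetical usage, assuming a
# timestamp column `ts` in `df`:
select(df,
       date_format(df$ts, "dd.MM.yyyy"),
       next_day(df$ts, "Sunday"),
       to_utc_timestamp(df$ts, "PST"))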
+#' +#' @name add_months +#' @family datetime_funcs +#' @rdname add_months +#' @name add_months +#' @export +setMethod("add_months", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) + column(jc) + }) + +#' date_add +#' +#' Returns the date that is `days` days after `start` +#' +#' @family datetime_funcs +#' @rdname date_add +#' @name date_add +#' @export +setMethod("date_add", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) + column(jc) + }) + +#' date_sub +#' +#' Returns the date that is `days` days before `start` +#' +#' @family datetime_funcs +#' @rdname date_sub +#' @name date_sub +#' @export +setMethod("date_sub", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) + column(jc) + }) + +#' format_number +#' +#' Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, +#' and returns the result as a string column. +#' +#' If d is 0, the result has no decimal point or fractional part. +#' If d < 0, the result will be null.' +#' +#' @family string_funcs +#' @rdname format_number +#' @name format_number +#' @export +setMethod("format_number", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "format_number", + y@jc, as.integer(x)) + column(jc) + }) + +#' sha2 +#' +#' Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. +#' +#' @param y column to compute SHA-2 on. +#' @param x one of 224, 256, 384, or 512. +#' @family misc_funcs +#' @rdname sha2 +#' @name sha2 +#' @export +setMethod("sha2", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) + column(jc) + }) + +#' shiftLeft +#' +#' Shift the the given value numBits left. If the given value is a long value, this function +#' will return a long value else it will return an integer value. +#' +#' @family math_funcs +#' @rdname shiftLeft +#' @name shiftLeft +#' @export +setMethod("shiftLeft", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftLeft", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRight +#' +#' Shift the the given value numBits right. If the given value is a long value, it will return +#' a long value else it will return an integer value. +#' +#' @family math_funcs +#' @rdname shiftRight +#' @name shiftRight +#' @export +setMethod("shiftRight", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRight", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRightUnsigned +#' +#' Unsigned shift the the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. +#' +#' @family math_funcs +#' @rdname shiftRightUnsigned +#' @name shiftRightUnsigned +#' @export +setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRightUnsigned", + y@jc, as.integer(x)) + column(jc) + }) + +#' concat_ws +#' +#' Concatenates multiple input string columns together into a single string column, +#' using the given separator. 
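# [editor's note: illustrative sketch, not part of the patch] add_months,
# date_add, date_sub, format_number, sha2 and the shift* functions above pair a
# Column with a numeric argument that is coerced via as.integer() before the
# JVM call. Hypothetical usage, assuming a date column `d`, a binary column `b`
# and an integer column `n` in `df`:
select(df, date_add(df$d, 7), add_months(df$d, 2))
select(df, sha2(df$b, 256), shiftLeft(df$n, 2))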
+#' +#' @family string_funcs +#' @rdname concat_ws +#' @name concat_ws +#' @export +setMethod("concat_ws", signature(sep = "character", x = "Column"), + function(sep, x, ...) { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) + column(jc) + }) + +#' conv +#' +#' Convert a number in a string column from one base to another. +#' +#' @family math_funcs +#' @rdname conv +#' @name conv +#' @export +setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), + function(x, fromBase, toBase) { + fromBase <- as.integer(fromBase) + toBase <- as.integer(toBase) + jc <- callJStatic("org.apache.spark.sql.functions", + "conv", + x@jc, fromBase, toBase) + column(jc) + }) + +#' expr +#' +#' Parses the expression string into the column that it represents, similar to +#' DataFrame.selectExpr +#' +#' @family normal_funcs +#' @rdname expr +#' @name expr +#' @export +setMethod("expr", signature(x = "character"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) + column(jc) + }) + +#' format_string +#' +#' Formats the arguments in printf-style and returns the result as a string column. +#' +#' @family string_funcs +#' @rdname format_string +#' @name format_string +#' @export +setMethod("format_string", signature(format = "character", x = "Column"), + function(format, x, ...) { + jcols <- lapply(list(x, ...), function(arg) { arg@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", + "format_string", + format, jcols) + column(jc) + }) + +#' from_unixtime +#' +#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string +#' representing the timestamp of that moment in the current system time zone in the given +#' format. +#' +#' @family datetime_funcs +#' @rdname from_unixtime +#' @name from_unixtime +#' @export +setMethod("from_unixtime", signature(x = "Column"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", + "from_unixtime", + x@jc, format) + column(jc) + }) + +#' locate +#' +#' Locate the position of the first occurrence of substr. +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. +#' +#' @family string_funcs +#' @rdname locate +#' @name locate +#' @export +setMethod("locate", signature(substr = "character", str = "Column"), + function(substr, str, pos = 0) { + jc <- callJStatic("org.apache.spark.sql.functions", + "locate", + substr, str@jc, as.integer(pos)) + column(jc) + }) + +#' lpad +#' +#' Left-pad the string column with +#' +#' @family string_funcs +#' @rdname lpad +#' @name lpad +#' @export +setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' rand +#' +#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' +#' @family normal_funcs +#' @rdname rand +#' @name rand +#' @export +setMethod("rand", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand") + column(jc) + }) +#' @family normal_funcs +#' @rdname rand +#' @name rand +#' @export +setMethod("rand", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand", as.integer(seed)) + column(jc) + }) + +#' randn +#' +#' Generate a column with i.i.d. 
samples from the standard normal distribution. +#' +#' @family normal_funcs +#' @rdname randn +#' @name randn +#' @export +setMethod("randn", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn") + column(jc) + }) +#' @family normal_funcs +#' @rdname randn +#' @name randn +#' @export +setMethod("randn", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn", as.integer(seed)) + column(jc) + }) + +#' regexp_extract +#' +#' Extract a specific(idx) group identified by a java regex, from the specified string column. +#' +#' @family string_funcs +#' @rdname regexp_extract +#' @name regexp_extract +#' @export +setMethod("regexp_extract", + signature(x = "Column", pattern = "character", idx = "numeric"), + function(x, pattern, idx) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_extract", + x@jc, pattern, as.integer(idx)) + column(jc) + }) + +#' regexp_replace +#' +#' Replace all substrings of the specified string value that match regexp with rep. +#' +#' @family string_funcs +#' @rdname regexp_replace +#' @name regexp_replace +#' @export +setMethod("regexp_replace", + signature(x = "Column", pattern = "character", replacement = "character"), + function(x, pattern, replacement) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_replace", + x@jc, pattern, replacement) + column(jc) + }) + +#' rpad +#' +#' Right-padded with pad to a length of len. +#' +#' @family string_funcs +#' @rdname rpad +#' @name rpad +#' @export +setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "rpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' substring_index +#' +#' Returns the substring from string str before count occurrences of the delimiter delim. +#' If count is positive, everything the left of the final delimiter (counting from left) is +#' returned. If count is negative, every to the right of the final delimiter (counting from the +#' right) is returned. substring <- index performs a case-sensitive match when searching for delim. +#' +#' @family string_funcs +#' @rdname substring_index +#' @name substring_index +#' @export +setMethod("substring_index", + signature(x = "Column", delim = "character", count = "numeric"), + function(x, delim, count) { + jc <- callJStatic("org.apache.spark.sql.functions", + "substring_index", + x@jc, delim, as.integer(count)) + column(jc) + }) + +#' translate +#' +#' Translate any character in the src by a character in replaceString. +#' The characters in replaceString is corresponding to the characters in matchingString. +#' The translate will happen when any character in the string matching with the character +#' in the matchingString. +#' +#' @family string_funcs +#' @rdname translate +#' @name translate +#' @export +setMethod("translate", + signature(x = "Column", matchingString = "character", replaceString = "character"), + function(x, matchingString, replaceString) { + jc <- callJStatic("org.apache.spark.sql.functions", + "translate", x@jc, matchingString, replaceString) + column(jc) + }) + +#' unix_timestamp +#' +#' Gets current Unix timestamp in seconds. 
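# [editor's note: illustrative sketch, not part of the patch] regexp_extract,
# regexp_replace, rpad, substring_index and translate above keep the same
# Column-plus-character shape, and unix_timestamp (continued below) is offered
# with no argument, a Column, or a Column plus a format string. Hypothetical
# usage, assuming string columns `s` and `t` in `df`:
select(df, regexp_replace(df$s, "[0-9]+", "#"), substring_index(df$s, ".", 2))
select(df, unix_timestamp(df$t, "yyyy-MM-dd"))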
+#' +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "missing", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") + column(jc) + }) +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "Column", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) + column(jc) + }) +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export +setMethod("unix_timestamp", signature(x = "Column", format = "character"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) + column(jc) + }) +#' when +#' +#' Evaluates a list of conditions and returns one of multiple possible result expressions. +#' For unmatched expressions null is returned. +#' +#' @family normal_funcs +#' @rdname when +#' @name when +#' @export +setMethod("when", signature(condition = "Column", value = "ANY"), + function(condition, value) { + condition <- condition@jc + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) + column(jc) + }) + +#' ifelse +#' +#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' Otherwise \code{no} is returned for unmatched conditions. +#' +#' @family normal_funcs +#' @rdname ifelse +#' @name ifelse +#' @export +setMethod("ifelse", + signature(test = "Column", yes = "ANY", no = "ANY"), + function(test, yes, no) { + test <- test@jc + yes <- ifelse(class(yes) == "Column", yes@jc, yes) + no <- ifelse(class(no) == "Column", no@jc, no) + jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", + "when", + test, yes), + "otherwise", no) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 12e09176c9f9..43dd8d283ab6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -20,7 +20,8 @@ # @rdname aggregateRDD # @seealso reduce # @export -setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) # @rdname cache-methods # @export @@ -58,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") }) # @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) +# @rdname statfunctions +# @export +setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) @@ -130,7 +135,7 @@ setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) -# @rdname sumRDD +# @rdname sumRDD # @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) @@ -219,7 +224,7 @@ setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD # @export -setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex @@ -249,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { 
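# [editor's note: illustrative sketch, not part of the patch] when(), otherwise()
# and ifelse() above build conditional Column expressions: when() yields null
# for unmatched rows unless a default is chained with otherwise(), and ifelse()
# wraps the when/otherwise pair in a single call. Hypothetical usage, assuming
# a numeric column `age` in `df`:
select(df, otherwise(when(df$age > 17, "adult"), "minor"))
select(df, ifelse(df$age > 17, "adult", "minor"))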
standardGeneric("flatMapValues") # @rdname intersection # @export -setGeneric("intersection", function(x, other, numPartitions = 1) { - standardGeneric("intersection") }) +setGeneric("intersection", + function(x, other, numPartitions = 1) { + standardGeneric("intersection") + }) # @rdname keys # @export @@ -364,7 +371,7 @@ setGeneric("subtract", # @rdname subtractByKey # @export -setGeneric("subtractByKey", +setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") }) @@ -399,15 +406,15 @@ setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) #' @rdname nafunctions #' @export setGeneric("dropna", - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - standardGeneric("dropna") + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") }) #' @rdname nafunctions #' @export setGeneric("na.omit", - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - standardGeneric("na.omit") + function(object, ...) { + standardGeneric("na.omit") }) #' @rdname schema @@ -434,7 +441,7 @@ setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) -#' @rdname DataFrame +#' @rdname groupBy #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) @@ -454,9 +461,13 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) +#' rdname merge +#' @export +setGeneric("merge") + #' @rdname withColumn #' @export -setGeneric("mutate", function(x, ...) {standardGeneric("mutate") }) +setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname arrange #' @export @@ -484,9 +495,7 @@ setGeneric("sample", #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname saveAsParquetFile #' @export @@ -498,6 +507,10 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) +#' @rdname withColumn +#' @export +setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) @@ -522,10 +535,18 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) +# @rdname subset +# @export +setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) + #' @rdname agg #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) +#' @rdname summary +#' @export +setGeneric("summary", function(x, ...) 
{ standardGeneric("summary") }) + # @rdname tojson # @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) @@ -548,38 +569,27 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn #' @rdname withColumnRenamed #' @export -setGeneric("withColumnRenamed", function(x, existingCol, newCol) { - standardGeneric("withColumnRenamed") }) +setGeneric("withColumnRenamed", + function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) ###################### Column Methods ########################## -#' @rdname column -#' @export -setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) - #' @rdname column #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) #' @rdname column #' @export -setGeneric("avg", function(x, ...) { standardGeneric("avg") }) +setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @rdname column #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) -#' @rdname column -#' @export -setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) - #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) -#' @rdname column -#' @export -setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname column #' @export @@ -599,61 +609,377 @@ setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) #' @rdname column #' @export -setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) +setGeneric("isNull", function(x) { standardGeneric("isNull") }) #' @rdname column #' @export -setGeneric("isNull", function(x) { standardGeneric("isNull") }) +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) #' @rdname column #' @export -setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) +setGeneric("like", function(x, ...) { standardGeneric("like") }) #' @rdname column #' @export -setGeneric("last", function(x) { standardGeneric("last") }) +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @rdname column #' @export -setGeneric("like", function(x, ...) { standardGeneric("like") }) +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) #' @rdname column #' @export -setGeneric("lower", function(x) { standardGeneric("lower") }) +setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname column #' @export +setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) + + +###################### Expression Function Methods ########################## + +#' @rdname add_months +#' @export +setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) + +#' @rdname approxCountDistinct +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname ascii +#' @export +setGeneric("ascii", function(x) { standardGeneric("ascii") }) + +#' @rdname avg +#' @export +setGeneric("avg", function(x, ...) 
{ standardGeneric("avg") }) + +#' @rdname base64 +#' @export +setGeneric("base64", function(x) { standardGeneric("base64") }) + +#' @rdname bin +#' @export +setGeneric("bin", function(x) { standardGeneric("bin") }) + +#' @rdname bitwiseNOT +#' @export +setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) + +#' @rdname cbrt +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + +#' @rdname ceil +#' @export +setGeneric("ceil", function(x) { standardGeneric("ceil") }) + +#' @rdname concat +#' @export +setGeneric("concat", function(x, ...) { standardGeneric("concat") }) + +#' @rdname concat_ws +#' @export +setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) + +#' @rdname conv +#' @export +setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) + +#' @rdname countDistinct +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname crc32 +#' @export +setGeneric("crc32", function(x) { standardGeneric("crc32") }) + +#' @rdname datediff +#' @export +setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) + +#' @rdname date_add +#' @export +setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) + +#' @rdname date_format +#' @export +setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) + +#' @rdname date_sub +#' @export +setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) + +#' @rdname dayofmonth +#' @export +setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) + +#' @rdname dayofyear +#' @export +setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) + +#' @rdname explode +#' @export +setGeneric("explode", function(x) { standardGeneric("explode") }) + +#' @rdname expr +#' @export +setGeneric("expr", function(x) { standardGeneric("expr") }) + +#' @rdname from_utc_timestamp +#' @export +setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) + +#' @rdname format_number +#' @export +setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) + +#' @rdname format_string +#' @export +setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) + +#' @rdname from_unixtime +#' @export +setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) + +#' @rdname greatest +#' @export +setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) + +#' @rdname hex +#' @export +setGeneric("hex", function(x) { standardGeneric("hex") }) + +#' @rdname hour +#' @export +setGeneric("hour", function(x) { standardGeneric("hour") }) + +#' @rdname hypot +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + +#' @rdname initcap +#' @export +setGeneric("initcap", function(x) { standardGeneric("initcap") }) + +#' @rdname instr +#' @export +setGeneric("instr", function(y, x) { standardGeneric("instr") }) + +#' @rdname isNaN +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + +#' @rdname last +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname last_day +#' @export +setGeneric("last_day", function(x) { standardGeneric("last_day") }) + +#' @rdname least +#' @export +setGeneric("least", function(x, ...) 
{ standardGeneric("least") }) + +#' @rdname levenshtein +#' @export +setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) + +#' @rdname lit +#' @export +setGeneric("lit", function(x) { standardGeneric("lit") }) + +#' @rdname locate +#' @export +setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) + +#' @rdname lower +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname lpad +#' @export +setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) + +#' @rdname ltrim +#' @export +setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) + +#' @rdname md5 +#' @export +setGeneric("md5", function(x) { standardGeneric("md5") }) + +#' @rdname minute +#' @export +setGeneric("minute", function(x) { standardGeneric("minute") }) + +#' @rdname month +#' @export +setGeneric("month", function(x) { standardGeneric("month") }) + +#' @rdname months_between +#' @export +setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) + +#' @rdname count +#' @export setGeneric("n", function(x) { standardGeneric("n") }) -#' @rdname column +#' @rdname nanvl +#' @export +setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) + +#' @rdname negate +#' @export +setGeneric("negate", function(x) { standardGeneric("negate") }) + +#' @rdname next_day +#' @export +setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) + +#' @rdname countDistinct #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @rdname column +#' @rdname pmod +#' @export +setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) + +#' @rdname quarter +#' @export +setGeneric("quarter", function(x) { standardGeneric("quarter") }) + +#' @rdname rand +#' @export +setGeneric("rand", function(seed) { standardGeneric("rand") }) + +#' @rdname randn +#' @export +setGeneric("randn", function(seed) { standardGeneric("randn") }) + +#' @rdname regexp_extract +#' @export +setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) + +#' @rdname regexp_replace +#' @export +setGeneric("regexp_replace", + function(x, pattern, replacement) { standardGeneric("regexp_replace") }) + +#' @rdname reverse +#' @export +setGeneric("reverse", function(x) { standardGeneric("reverse") }) + +#' @rdname rint #' @export setGeneric("rint", function(x, ...) { standardGeneric("rint") }) -#' @rdname column +#' @rdname rpad #' @export -setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) +setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) -#' @rdname column +#' @rdname rtrim #' @export -setGeneric("startsWith", function(x, ...) 
{ standardGeneric("startsWith") }) +setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) -#' @rdname column +#' @rdname second +#' @export +setGeneric("second", function(x) { standardGeneric("second") }) + +#' @rdname sha1 +#' @export +setGeneric("sha1", function(x) { standardGeneric("sha1") }) + +#' @rdname sha2 +#' @export +setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) + +#' @rdname shiftLeft +#' @export +setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) + +#' @rdname shiftRight +#' @export +setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) + +#' @rdname shiftRightUnsigned +#' @export +setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) + +#' @rdname signum +#' @export +setGeneric("signum", function(x) { standardGeneric("signum") }) + +#' @rdname size +#' @export +setGeneric("size", function(x) { standardGeneric("size") }) + +#' @rdname soundex +#' @export +setGeneric("soundex", function(x) { standardGeneric("soundex") }) + +#' @rdname substring_index +#' @export +setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) + +#' @rdname sumDistinct #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname column +#' @rdname toDegrees #' @export setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname column +#' @rdname toRadians #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname column +#' @rdname to_date +#' @export +setGeneric("to_date", function(x) { standardGeneric("to_date") }) + +#' @rdname to_utc_timestamp +#' @export +setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) + +#' @rdname translate +#' @export +setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) + +#' @rdname trim +#' @export +setGeneric("trim", function(x) { standardGeneric("trim") }) + +#' @rdname unbase64 +#' @export +setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) + +#' @rdname unhex +#' @export +setGeneric("unhex", function(x) { standardGeneric("unhex") }) + +#' @rdname unix_timestamp +#' @export +setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) + +#' @rdname upper #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname weekofyear +#' @export +setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) + +#' @rdname year +#' @export +setGeneric("year", function(x) { standardGeneric("year") }) + + +#' @rdname glm +#' @export +setGeneric("glm") + +#' @rdname rbind +#' @export +setGeneric("rbind", signature = "...") diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index b75848199757..4cab1a69f601 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -87,7 +87,7 @@ setMethod("count", setMethod("agg", signature(x = "GroupedData"), function(x, ...) { - cols = list(...) + cols <- list(...) stopifnot(length(cols) > 0) if (is.character(cols[[1]])) { cols <- varargsToEnv(...) 
@@ -97,12 +97,12 @@ setMethod("agg", if (!is.null(ns)) { for (n in ns) { if (n != "") { - cols[[n]] = alias(cols[[n]], n) + cols[[n]] <- alias(cols[[n]], n) } } } jcols <- lapply(cols, function(c) { c@jc }) - sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1]) } else { stop("agg can only support Column or character") } @@ -124,7 +124,7 @@ createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), function(x, ...) { - sdf <- callJMethod(x@sgd, name, toSeq(...)) + sdf <- callJMethod(x@sgd, name, list(...)) dataFrame(sdf) }) } @@ -136,4 +136,3 @@ createMethods <- function() { } createMethods() - diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index a8a25230b636..0838a7bb35e0 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -16,7 +16,7 @@ # # References to objects that exist on the JVM backend -# are maintained using the jobj. +# are maintained using the jobj. #' @include generics.R NULL diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 000000000000..cea3d760d05f --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,99 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '+', '-', and '.'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). 
+#' +#' @param object A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname predict +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param x A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. +#' @rdname summary +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(x = "PipelineModel"), + function(x, ...) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", x@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", x@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 1e24286dbcae..199c3fd6ab1b 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -202,8 +202,8 @@ setMethod("partitionBy", packageNamesArr <- serialize(.sparkREnv$.packages, connection = NULL) - broadcastArr <- lapply(ls(.broadcastNames), function(name) { - get(name, .broadcastNames) }) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) jrdd <- getJRDD(x) # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], @@ -215,7 +215,6 @@ setMethod("partitionBy", serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, - as.character(.sparkREnv$libname), broadcastArr, callJMethod(jrdd, "classTag")) @@ -560,8 +559,8 @@ setMethod("join", # Left outer join two RDDs # # @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -597,8 +596,8 @@ setMethod("leftOuterJoin", # Right outer join two RDDs # # @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -634,8 +633,8 @@ setMethod("rightOuterJoin", # Full outer join two RDDs # # @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. 
Should be an RDD where each element is # list(K, V). @@ -784,7 +783,7 @@ setMethod("sortByKey", newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) - + # Subtract a pair RDD with another pair RDD. # # Return an RDD with the pairs from x whose keys are not in other. @@ -820,7 +819,7 @@ setMethod("subtractByKey", }) # Return a subset of this RDD sampled by key. -# +# # @description # \code{sampleByKey} Create a sample of this RDD using variable sampling rates # for different keys as specified by fractions, a key to sampling rate map. @@ -880,7 +879,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- rpois(1, frac) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index e442119086b1..8df1563f8ebc 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -20,7 +20,7 @@ #' structType #' -#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' Create a structType object that contains the metadata for a DataFrame. Intended for #' use with createDataFrame and toDF. #' #' @param x a structField object (created with the field() function) @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) { }) stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructType", - listToSeq(sfObjList)) + sfObjList) structType(stObj) } @@ -69,11 +69,14 @@ structType.structField <- function(x, ...) { #' @param ... further arguments passed to or from other methods print.structType <- function(x, ...) { cat("StructType\n", - sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), - "\", type = \"", field$dataType.toString(), - "\", nullable = ", field$nullable(), "\n", - sep = "") }) - , sep = "") + sapply(x$fields(), + function(field) { + paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") + }), + sep = "") } #' structField @@ -111,6 +114,55 @@ structField.jobj <- function(x) { obj } +checkType <- function(type) { + primtiveTypes <- c("byte", + "integer", + "float", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + if (type %in% primtiveTypes) { + return() + } else { + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.*),(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }) + } + + stop(paste("Unsupported type for Dataframe:", type)) +} + structField.character <- function(x, type, nullable = TRUE) { if (class(x) != "character") { stop("Field name must be a string.") @@ -121,27 +173,13 @@ structField.character <- function(x, type, nullable = TRUE) { if (class(nullable) != "logical") { stop("nullable must be either TRUE or FALSE") } - options <- c("byte", - "integer", - "double", - "numeric", - "character", - 
"string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - dataType <- if (type %in% options) { - type - } else { - stop(paste("Unsupported type for Dataframe:", type)) - } + + checkType(type) + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructField", x, - dataType, + type, nullable) structField(sfObj) } diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 3169d7968f8f..91e6b3e5609b 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -79,7 +79,7 @@ writeJobj <- function(con, value) { writeString <- function(con, value) { utfVal <- enc2utf8(value) writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) - writeBin(utfVal, con, endian = "big") + writeBin(utfVal, con, endian = "big", useBytes=TRUE) } writeInt <- function(con, value) { @@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) { serializeRow <- function(row) { rawObj <- rawConnection(raw(0), "wb") on.exit(close(rawObj)) - writeRow(rawObj, row) + writeGenericList(rawObj, row) rawConnectionValue(rawObj) } -writeRow <- function(con, row) { - numCols <- length(row) - writeInt(con, numCols) - for (i in 1:numCols) { - writeObject(con, row[[i]]) - } -} - writeRaw <- function(con, batch) { writeInt(con, length(batch)) writeBin(batch, con, endian = "big") @@ -140,8 +132,8 @@ writeType <- function(con, class) { jobj = "j", environment = "e", Date = "D", - POSIXlt = 't', - POSIXct = 't', + POSIXlt = "t", + POSIXct = "t", stop(paste("Unsupported type for serialization", class))) writeBin(charToRaw(type), con) } @@ -175,7 +167,7 @@ writeGenericList <- function(con, list) { writeObject(con, elem) } } - + # Used to pass in hash maps required on Java side. writeEnv <- function(con, env) { len <- length(env) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 2efd4f0742e7..3c57a44db257 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -17,16 +17,13 @@ .sparkREnv <- new.env() -sparkR.onLoad <- function(libname, pkgname) { - .sparkREnv$libname <- libname -} - # Utility function that returns TRUE if we have an active connection to the # backend and FALSE otherwise connExists <- function(env) { tryCatch({ exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) - }, error = function(err) { + }, + error = function(err) { return(FALSE) }) } @@ -43,7 +40,7 @@ sparkR.stop <- function() { callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) } - + if (exists(".backendLaunched", envir = env)) { callJStatic("SparkRHandler", "stopBackend") } @@ -80,7 +77,7 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes. #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -100,23 +97,21 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "") { + sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { - cat("Re-using existing Spark Context. 
Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "512m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows # URI needs four /// as from http://stackoverflow.com/a/18522792 if (.Platform$OS.type == "unix") { - collapseChar <- ":" uriSep <- "//" } else { - collapseChar <- ";" uriSep <- "////" } @@ -129,7 +124,8 @@ sparkR.init <- function( args = path, sparkHome = sparkHome, jars = jars, - sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), + packages = sparkPackages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -142,7 +138,7 @@ sparkR.init <- function( if (!file.exists(path)) { stop("JVM is not ready after 10 seconds") } - f <- file(path, open='rb') + f <- file(path, open="rb") backendPort <- readInt(f) monitorPort <- readInt(f) close(f) @@ -158,33 +154,32 @@ sparkR.init <- function( .sparkREnv$backendPort <- backendPort tryCatch({ connectBackend("localhost", backendPort) - }, error = function(err) { + }, + error = function(err) { stop("Failed to connect JVM\n") }) if (nchar(sparkHome) != 0) { - sparkHome <- normalizePath(sparkHome) - } - - if (nchar(sparkRLibDir) != 0) { - .sparkREnv$libname <- sparkRLibDir + sparkHome <- suppressWarnings(normalizePath(sparkHome)) } sparkEnvirMap <- new.env() for (varname in names(sparkEnvir)) { sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] } - + sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { - sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } for (varname in names(sparkExecutorEnv)) { sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs # Seconds resolution is good enough for this purpose, so use ints @@ -214,7 +209,7 @@ sparkR.init <- function( #' Initialize a new SQLContext. 
#' -#' This function creates a SparkContext from an existing JavaSparkContext and +#' This function creates a SparkContext from an existing JavaSparkContext and #' then uses it to initialize a new SQLContext #' #' @param jsc The existing JavaSparkContext created with SparkR.init() @@ -271,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { stop("Spark SQL is not built with Hive support") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 69b2700191c9..69a2bc728f84 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -32,7 +32,7 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, } results <- if (arrSize > 0) { - lapply(0:(arrSize - 1), + lapply(0 : (arrSize - 1), function(index) { obj <- callJMethod(jList, "get", as.integer(index)) @@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, if (isInstanceOf(obj, "scala.Tuple2")) { # JavaPairRDD[Array[Byte], Array[Byte]]. - keyBytes = callJMethod(obj, "_1") - valBytes = callJMethod(obj, "_2") + keyBytes <- callJMethod(obj, "_1") + valBytes <- callJMethod(obj, "_2") res <- list(unserialize(keyBytes), unserialize(valBytes)) } else { @@ -314,7 +314,8 @@ convertEnvsToList <- function(keys, vals) { # Utility function to capture the varargs into environment object varargsToEnv <- function(...) { - pairs <- as.list(substitute(list(...)))[-1L] + # Based on http://stackoverflow.com/a/3057419/4577954 + pairs <- list(...) env <- new.env() for (name in names(pairs)) { env[[name]] <- pairs[[name]] @@ -334,18 +335,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "MEMORY_ONLY_SER_2", "OFF_HEAP")) { match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" storageLevel <- switch(newLevel, - "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), - "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), - "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), - "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), - "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), - "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), - "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), - "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), - "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), - "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), - "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = 
callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } # Utility function for functions where an argument needs to be integer but we want to allow @@ -357,44 +361,37 @@ numToInt <- function(num) { as.integer(num) } -# create a Seq in JVM -toSeq <- function(...) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) -} - -# create a Seq in JVM from a list -listToSeq <- function(l) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) -} - # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a -# user defined function (UDF), and to examine variables in the UDF to decide +# user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. # param # node The current AST node in the traversal. # oldEnv The original function environment. # defVars An Accumulator of variables names defined in the function's calling environment, # including function argument and local variable names. -# checkedFunc An environment of function objects examined during cleanClosure. It can +# checkedFunc An environment of function objects examined during cleanClosure. It can # be considered as a "name"-to-"list of functions" mapping. # newEnv A new function environment to store necessary function dependencies, an output argument. processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { nodeLen <- length(node) - + if (nodeLen > 1 && typeof(node) == "language") { - # Recursive case: current AST node is an internal node, check for its children. + # Recursive case: current AST node is an internal node, check for its children. if (length(node[[1]]) > 1) { for (i in 1:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else { # if node[[1]] is length of 1, check for some R special functions. + } else { + # if node[[1]] is length of 1, check for some R special functions. nodeChar <- as.character(node[[1]]) - if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + if (nodeChar == "{" || nodeChar == "(") { + # Skip start symbol. for (i in 2:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "<-" || nodeChar == "=" || - nodeChar == "<<-") { # Assignment Ops. + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { + # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { # Add the defined variable name into defVars. @@ -405,14 +402,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "function") { # Function definition. + } else if (nodeChar == "function") { + # Function definition. # Add parameter names. newArgs <- names(node[[2]]) lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "$") { # Skip the field. + } else if (nodeChar == "$") { + # Skip the field. 
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) } else if (nodeChar == "::" || nodeChar == ":::") { processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) @@ -422,38 +421,43 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } } - } else if (nodeLen == 1 && + } else if (nodeLen == 1 && (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. nodeChar <- as.character(node) - if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + if (!nodeChar %in% defVars$data) { + # Not a function parameter or local variable. func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) - # Search in function environment, and function's enclosing environments + # Search in function environment, and function's enclosing environments # up to global environment. There is no need to look into package environments - # above the global or namespace environment that is not SparkR below the global, + # above the global or namespace environment that is not SparkR below the global, # as they are assumed to be loaded on workers. while (!identical(func.env, topEnv)) { # Namespaces other than "SparkR" will not be searched. - if (!isNamespace(func.env) || - (getNamespaceName(func.env) == "SparkR" && - !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { + # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) - if (is.function(obj)) { # If the node is a function call. - funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + if (is.function(obj)) { + # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, ifnotfound = list(list(NULL)))[[1]] found <- sapply(funcList, function(func) { ifelse(identical(func, obj), TRUE, FALSE) }) - if (sum(found) > 0) { # If function has been examined, ignore. + if (sum(found) > 0) { + # If function has been examined, ignore. break } # Function has not been examined, record it and recursively clean its closure. - assign(nodeChar, + assign(nodeChar, if (is.null(funcList[[1]])) { list(obj) } else { @@ -466,7 +470,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { break } } - + # Continue to search in enclosure. func.env <- parent.env(func.env) } @@ -474,8 +478,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } -# Utility function to get user defined function (UDF) dependencies (closure). -# More specifically, this function captures the values of free variables defined +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined # outside a UDF, and stores them in the function's environment. # param # func A function whose closure needs to be captured. 
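As an aside for readers of the closure-cleaning changes above, the following sketch (not part of the patch) illustrates what processClosure()/cleanClosure() accomplish. It reaches the package-internal helper through :::, in the same spirit as the package's own unit tests; the names scaleFactor and udf are hypothetical.

# A free variable referenced inside a UDF has to travel with the function
# when the function is serialized and shipped to worker processes.
scaleFactor <- 3
udf <- function(x) { x * scaleFactor }

# cleanClosure() walks the AST of udf, finds the free variable scaleFactor,
# and copies its value into a fresh environment attached to the function.
cleaned <- SparkR:::cleanClosure(udf)
get("scaleFactor", envir = environment(cleaned))  # 3: captured from the calling environment
cleaned(10)                                       # 30, even if scaleFactor is later removed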
@@ -488,11 +492,12 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { newEnv <- new.env(parent = .GlobalEnv) func.body <- body(func) oldEnv <- environment(func) - # defVars is an Accumulator of variables names defined in the function's calling + # defVars is an Accumulator of variables names defined in the function's calling # environment. First, function's arguments are added to defVars. defVars <- initAccumulator() argNames <- names(as.list(args(func))) - for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + for (i in 1:(length(argNames) - 1)) { + # Remove the ending NULL in pairlist. addItemToAccumulator(defVars, argNames[i]) } # Recursively examine variables in the function body. @@ -509,15 +514,15 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { # return value # A list of two result RDDs. appendPartitionLengths <- function(x, other) { - if (getSerializedMode(x) != getSerializedMode(other) || + if (getSerializedMode(x) != getSerializedMode(other) || getSerializedMode(x) == "byte") { # Append the number of elements in each partition to that partition so that we can later # know the boundary of elements from x and other. # - # Note that this appending also serves the purpose of reserialization, because even if + # Note that this appending also serves the purpose of reserialization, because even if # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded # as a single byte array. For example, partitions of an RDD generated from partitionBy() - # may be encoded as multiple byte arrays. + # may be encoded as multiple byte arrays. appendLength <- function(part) { len <- length(part) part[[len + 1]] <- len + 1 @@ -544,23 +549,25 @@ mergePartitions <- function(rdd, zip) { lengthOfValues <- part[[len]] lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - - # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. 
if (zip && lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) } - + if (lengthOfKeys > 1) { keys <- part[1 : (lengthOfKeys - 1)] } else { keys <- list() } if (lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[ (lengthOfKeys + 1) : (len - 1) ] } else { values <- list() } - + if (!zip) { return(mergeCompactLists(keys, values)) } @@ -578,6 +585,6 @@ mergePartitions <- function(rdd, zip) { part } } - + PipelinedRDD(rdd, partitionFunc) } diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8fe711b62208..2a8a8213d084 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,7 +16,7 @@ # .First <- function() { - home <- Sys.getenv("SPARK_HOME") - .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") + .libPaths(c(packageDir, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 773b6ecf582d..7189f1a26093 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -27,7 +27,21 @@ sc <- SparkR::sparkR.init() assign("sc", sc, envir=.GlobalEnv) sqlContext <- SparkR::sparkRSQL.init(sc) + sparkVer <- SparkR:::callJMethod(sc, "version") assign("sqlContext", sqlContext, envir=.GlobalEnv) - cat("\n Welcome to SparkR!") + cat("\n Welcome to") + cat("\n") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + if (nchar(sparkVer) == 0) { + cat("\n") + } else { + cat(" version ", sparkVer, "\n") + } + cat(" /_/", "\n") + cat("\n") + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") } diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar new file mode 100644 index 000000000000..1d5c2af631aa Binary files /dev/null and b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar differ diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/jarTest.R new file mode 100644 index 000000000000..d68bb20950b0 --- /dev/null +++ b/R/pkg/inst/tests/jarTest.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +library(SparkR) + +sc <- sparkR.init() + +helloTest <- SparkR:::callJStatic("sparkR.test.hello", + "helloWorld", + "Dave") + +basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction", + "addStuff", + 2L, + 2L) + +sparkR.stop() +output <- c(helloTest, basicFunction) +writeLines(output) diff --git a/R/pkg/inst/tests/packageInAJarTest.R b/R/pkg/inst/tests/packageInAJarTest.R new file mode 100644 index 000000000000..207a37a0cb47 --- /dev/null +++ b/R/pkg/inst/tests/packageInAJarTest.R @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +library(SparkR) +library(sparkPackageTest) + +sc <- sparkR.init() + +run1 <- myfunc(5L) + +run2 <- myfunc(-4L) + +sparkR.stop() + +if(run1 != 6) quit(save = "no", status = 1) + +if(run2 != -3) quit(save = "no", status = 1) diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R new file mode 100644 index 000000000000..dddce54d7044 --- /dev/null +++ b/R/pkg/inst/tests/test_Serde.R @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +context("SerDe functionality") + +sc <- sparkR.init() + +test_that("SerDe of primitive types", { + x <- callJStatic("SparkRHandler", "echo", 1L) + expect_equal(x, 1L) + expect_equal(class(x), "integer") + + x <- callJStatic("SparkRHandler", "echo", 1) + expect_equal(x, 1) + expect_equal(class(x), "numeric") + + x <- callJStatic("SparkRHandler", "echo", TRUE) + expect_true(x) + expect_equal(class(x), "logical") + + x <- callJStatic("SparkRHandler", "echo", "abc") + expect_equal(x, "abc") + expect_equal(class(x), "character") +}) + +test_that("SerDe of list of primitive types", { + x <- list(1L, 2L, 3L) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "integer") + + x <- list(1, 2, 3) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "numeric") + + x <- list(TRUE, FALSE) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "logical") + + x <- list("a", "b", "c") + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "character") + + # Empty list + x <- list() + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) + +test_that("SerDe of list of lists", { + x <- list(list(1L, 2L, 3L), list(1, 2, 3), + list(TRUE, FALSE), list("a", "b", "c")) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + + # List of empty lists + x <- list(list(), list()) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ca4218f3819f..f2452ed97d2e 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -20,7 +20,7 @@ context("functions on binary files") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") @@ -59,15 +59,15 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", wordCount <- lapply(words, function(word) { list(word, 1L) }) counts <- reduceByKey(wordCount, "+", 2L) - + saveAsObjectFile(counts, fileName2) counts <- objectFile(sc, fileName2) - + output <- collect(counts) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName1) unlink(fileName2, recursive = TRUE) }) @@ -82,9 +82,8 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) }) - diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 6785a7bdae8c..f054ac9a87d6 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -30,7 +30,7 @@ mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { actual <- collect(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) - + fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) 
actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") - rdd<- map(text.rdd, function(x) {x}) + rdd <- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) @@ -52,14 +52,14 @@ test_that("union on two RDDs", { test_that("cogroup on two RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) - expect_equal(actual, + expect_equal(actual, list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) - + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) @@ -71,31 +71,31 @@ test_that("zipPartitions() on RDDs", { rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 - actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { list(list(x, y, z))} )) expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - - mockFile = c("Spark is pretty.", "Spark is awesome.") + + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName, 1) - actual <- collect(zipPartitions(rdd, rdd, + actual <- collect(zipPartitions(rdd, rdd, func = function(x, y) { list(paste(x, y, sep = "\n")) })) expected <- list(paste(mockFile, mockFile, sep = "\n")) expect_equal(actual, expected) - + rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipPartitions(rdd1, rdd, + actual <- collect(zipPartitions(rdd1, rdd, func = function(x, y) { list(x + nchar(y)) })) expected <- list(0:1 + nchar(mockFile)) expect_equal(actual, expected) - + rdd <- map(rdd, function(x) { x }) - actual <- collect(zipPartitions(rdd, rdd1, + actual <- collect(zipPartitions(rdd, rdd1, func = function(x, y) { list(y + nchar(x)) })) expect_equal(actual, expected) - + unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R new file mode 100644 index 000000000000..8a20991f89af --- /dev/null +++ b/R/pkg/inst/tests/test_client.R @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in client.R") + +test_that("adding spark-testing-base as a package works", { + args <- generateSparkSubmitArgs("", "", "", "", + "holdenk:spark-testing-base:1.3.0_0.0.5") + expect_equal(gsub("[[:space:]]", "", args), + gsub("[[:space:]]", "", + "--packages holdenk:spark-testing-base:1.3.0_0.0.5")) +}) + +test_that("no package specified doesn't add packages flag", { + args <- generateSparkSubmitArgs("", "", "", "", "") + expect_equal(gsub("[[:space:]]", "", args), + "") +}) + +test_that("multiple packages don't produce a warning", { + expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R new file mode 100644 index 000000000000..cc1faeabffe3 --- /dev/null +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +context("include an external JAR in SparkContext") + +runScript <- function() { + sparkHome <- Sys.getenv("SPARK_HOME") + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + submitPath <- file.path(sparkHome, "bin/spark-submit") + res <- system2(command = submitPath, + args = c(jarPath, scriptPath), + stdout = TRUE) + tail(res, 2) +} + +test_that("sparkJars tag in SparkContext", { + testOutput <- runScript() + helloTest <- testOutput[1] + expect_equal(helloTest, "Hello, Dave") + basicFunction <- testOutput[2] + expect_equal(basicFunction, "4") +}) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 000000000000..f272de78ad4a --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db3..2552127cc547 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 03207353c31c..71aed2bb9d6a 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { @@ -250,7 +250,7 @@ test_that("flatMapValues() on pairwise RDDs", { expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) # Generate x to x+1 for every value - actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) expect_equal(actual, list(list(1L, -1), list(1L, 0), 
list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) @@ -293,7 +293,7 @@ test_that("sumRDD() on RDDs", { }) test_that("keyBy on RDDs", { - func <- function(x) { x*x } + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collect(keys) expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) @@ -311,7 +311,7 @@ test_that("repartition/coalesce on RDDs", { r2 <- repartition(rdd, 6) expect_equal(numPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) - expect_true(count >=0 && count <= 4) + expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) @@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", { expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -477,16 +477,16 @@ test_that("cartesian() on RDDs", { list(1, 1), list(1, 2), list(1, 3), list(2, 1), list(2, 2), list(2, 3), list(3, 1), list(3, 2), list(3, 3))) - + # test case where one RDD is empty emptyRdd <- parallelize(sc, list()) actual <- collect(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName) actual <- collect(cartesian(rdd, rdd)) expected <- list( @@ -495,7 +495,7 @@ test_that("cartesian() on RDDs", { list("Spark is pretty.", "Spark is pretty."), list("Spark is pretty.", "Spark is awesome.")) expect_equal(sortKeyValueList(actual), expected) - + rdd1 <- parallelize(sc, 0:1) actual <- collect(cartesian(rdd1, rdd)) expect_equal(sortKeyValueList(actual), @@ -504,11 +504,11 @@ test_that("cartesian() on RDDs", { list(0, "Spark is awesome."), list(1, "Spark is pretty."), list(1, "Spark is awesome."))) - + rdd1 <- map(rdd, function(x) { x }) actual <- collect(cartesian(rdd, rdd1)) expect_equal(sortKeyValueList(actual), expected) - + unlink(fileName) }) @@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + 
sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("sortByKey() on pairwise RDDs", { @@ -760,7 +764,7 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { - rdd <- parallelize(sc, list(1:10)) + rdd <- parallelize(sc, list(1:10)) expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R index d7dedda553c5..adf0b91d25fe 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/test_shuffle.R @@ -106,39 +106,39 @@ test_that("aggregateByKey", { zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + actual <- collect(aggregatedRDD) - + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test aggregateByKey for string keys rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) - + zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) actual <- collect(aggregatedRDD) - + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) -test_that("foldByKey", { +test_that("foldByKey", { # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test foldByKey for double keys folded <- foldByKey(doubleRdd, 0, "+", 2L) - + actual <- collect(folded) expected <- list(list(1.5, 199), list(2.5, 101)) @@ -146,15 +146,15 @@ test_that("foldByKey", { # test foldByKey for string keys stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) - + stringKeyRDD <- parallelize(sc, stringKeyPairs) folded <- foldByKey(stringKeyRDD, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list("b", 101), list("a", 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - + # test foldByKey for empty pair RDD rdd <- parallelize(sc, list()) folded <- foldByKey(rdd, 0, "+", 2L) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8946348ef801..e159a6958427 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -19,6 +19,14 @@ library(testthat) context("SparkSQL functions") +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + 
expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -41,64 +49,103 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesNa, jsonPathNa) -test_that("infer types", { +# For test complex types in DataFrame +mockLinesComplexType <- + c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", + "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", + "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") +complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesComplexType, complexTypeJsonPath) + +test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") expect_equal(infer_type("abc"), "string") expect_equal(infer_type(TRUE), "boolean") expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") - expect_equal(infer_type(c(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(a = 1L, b = "2")), - structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE))) + expect_equal(infer_type(c(1L, 2L)), "array") + expect_equal(infer_type(list(1L, 2L)), "array") + testStruct <- infer_type(list(a = 1L, b = "2")) + expect_equal(class(testStruct), "structType") + checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) + checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = TRUE)) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") }) test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) - + testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(nrow(df), 10) + expect_equal(ncol(df), 2) + expect_equal(dim(df), c(10, 2)) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, 
"DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlContext, rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- jsonFile(sqlContext, jsonPathNa) + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df2 <- createDataFrame(sqlContext, df.toRDD, schema) + expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + + localDF <- data.frame(name=c("John", "Smith", "Sarah"), + age=c(19, 23, 18), + height=c(164.10, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) }) test_that("convert NAs to null type in DataFrames", { @@ -141,26 +188,26 @@ test_that("convert NAs to null type in DataFrames", { test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -195,8 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), 
data.frame(l, stringsAsFactors = FALSE)) }) -# TODO: enable this test after fix serialization for nested object -#test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and map", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -206,25 +252,82 @@ test_that("create DataFrame with different data types", { # expect_equal(count(df), 1) # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) -#}) + + # ArrayType and MapType + e <- new.env() + assign("n", 3L, envir = e) + + l <- list(as.list(1:10), list("a", "b"), e) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"))) + expect_equal(count(df), 1) + ldf <- collect(df) + expect_equal(names(ldf), c("a", "b", "c")) + expect_equal(ldf[1, 1][[1]], l[[1]]) + expect_equal(ldf[1, 2][[1]], l[[2]]) + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) +}) + +# For test map type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + +test_that("Collect DataFrame with complex types", { + # ArrayType + df <- jsonFile(sqlContext, complexTypeJsonPath) + + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 3) + expect_equal(names(ldf), c("c1", "c2", "c3")) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # TODO: tests for StructType after it is supported +}) test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) + expect_equal(count(rdd), 3) df <- jsonRDD(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) df <- jsonRDD(sqlContext, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { @@ -239,9 +342,9 @@ test_that("test cache, uncache and clearCache", { test_that("test tableNames and tables", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlContext)) == 1) 
+ expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) - expect_true(count(df) == 1) + expect_equal(count(df), 1) dropTempTable(sqlContext, "table1") }) @@ -249,8 +352,8 @@ test_that("registerTempTable() results in a queryable table and sql() results in df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) dropTempTable(sqlContext, "table1") }) @@ -270,14 +373,14 @@ test_that("insertInto() on a registered table", { registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlContext, "select * from table1")) == 5) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlContext, "select * from table1")) == 2) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") }) @@ -285,16 +388,16 @@ test_that("table() returns a new DataFrame", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { @@ -302,9 +405,9 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -324,16 +427,16 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + 
expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { @@ -343,7 +446,7 @@ test_that("objectFile() works with row serialization", { saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -354,35 +457,69 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) + + # collect() returns data correctly from a DataFrame with 0 row + df0 <- limit(df, 0) + rdf <- collect(df0) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 0) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { df <- jsonFile(sqlContext, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) -test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { +test_that("collect() support Unicode characters", { + markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" + s + } + + lines <- c("{\"name\":\"안녕하세요\"}", + "{\"name\":\"您好\", \"age\":30}", + "{\"name\":\"こんにちは\", \"age\":19}", + "{\"name\":\"Xin chào\"}") + + jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath) + + df <- read.df(sqlContext, jsonPath, "json") + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(rdf$name[1], markUtf8("안녕하세요")) + expect_equal(rdf$name[2], markUtf8("您好")) + expect_equal(rdf$name[3], markUtf8("こんにちは")) + expect_equal(rdf$name[4], markUtf8("Xin chào")) + + df1 <- createDataFrame(sqlContext, rdf) + expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) +}) + +test_that("multiple pipeline transformations result in an RDD with the correct values", { df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -392,9 +529,9 @@ test_that("multiple pipeline transformations starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - expect_true(inherits(second, "RDD")) - 
expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) @@ -421,39 +558,51 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) + + # head() and first() return the correct data on + # a DataFrame with 0 row + df0 <- limit(df, 0) + + testHead <- head(df0) + expect_equal(nrow(testHead), 0) + expect_equal(ncol(testHead), 2) + + testFirst <- first(df0) + expect_equal(nrow(testFirst), 0) + expect_equal(ncol(testFirst), 2) }) -test_that("distinct() on DataFrames", { +test_that("distinct() and unique on DataFrames", { lines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}", @@ -463,15 +612,19 @@ test_that("distinct() on DataFrames", { df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) + + uniques2 <- unique(df) + expect_is(uniques2, "DataFrame") + expect_equal(count(uniques2), 3) }) test_that("sample on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) + expect_is(sampled, "DataFrame") sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) @@ -482,15 +635,15 @@ test_that("sample on a DataFrame", { test_that("select operators", { df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + expect_is(df$name, 
"Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age @@ -509,50 +662,91 @@ test_that("select operators", { test_that("select with column", { df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) + + df3 <- select(df, lit("x")) + expect_equal(columns(df3), c("x")) + expect_equal(count(df3), 3) + expect_equal(collect(select(df3, "x"))[[1, 1]], "x") +}) + +test_that("subsetting", { + # jsonFile returns columns in random order + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + filtered <- df[df$age > 20,] + expect_equal(count(filtered), 1) + expect_equal(columns(filtered), c("name", "age")) + expect_equal(collect(filtered)$name, "Andy") + + df2 <- df[df$age == 19, 1] + expect_is(df2, "DataFrame") + expect_equal(count(df2), 1) + expect_equal(columns(df2), c("name")) + expect_equal(collect(df2)$name, "Justin") + + df3 <- df[df$age > 20, 2] + expect_equal(count(df3), 1) + expect_equal(columns(df3), c("age")) + + df4 <- df[df$age %in% c(19, 30), 1:2] + expect_equal(count(df4), 2) + expect_equal(columns(df4), c("name", "age")) + + df5 <- df[df$age %in% c(19), c(1,2)] + expect_equal(count(df5), 1) + expect_equal(columns(df5), c("name", "age")) + + df6 <- subset(df, df$age %in% c(30), c(1,2)) + expect_equal(count(df6), 1) + expect_equal(columns(df6), c("name", "age")) }) test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) +}) + +test_that("expr() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) }) test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) # Check if we can apply a user defined schema schema <- structType(structField("name", type = "string"), structField("age", type = "double")) df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df1, "DataFrame")) + expect_is(df1, "DataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Run the same with loadDF df2 <- 
loadDF(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) @@ -560,28 +754,29 @@ test_that("write.df() as parquet file", { df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -594,17 +789,34 @@ test_that("column operators", { test_that("column functions", { c <- SparkR:::col("a") - c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) - c3 <- lower(c) + upper(c) + first(c) + last(c) - c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") - c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) - c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) - c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) - c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) - c9 <- toDegrees(c) + toRadians(c) -}) + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + df <- jsonFile(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) + + df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") +}) +# test_that("column binary mathfunctions", { lines <- c("{\"a\":1, \"b\":5}", "{\"a\":2, 
\"b\":6}", @@ -617,10 +829,19 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end + expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16) + expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) + expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) + expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") + expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01) + expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") + expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01) }) test_that("string operators", { @@ -629,73 +850,171 @@ test_that("string operators", { expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") + expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") + expect_equal(collect(select(df, concat_ws(":", df$name, df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, instr(df$name, "i")))[, 1], c(2, 0, 5)) + expect_equal(collect(select(df, format_number(df$age, 2)))[2, 1], "30.00") + expect_equal(collect(select(df, sha1(df$name)))[2, 1], + "ab5a000e88b5d9d0fa2575f5c6263eb93452405d") + expect_equal(collect(select(df, sha2(df$name, 256)))[2, 1], + "80f2aed3c618c423ddf05a2891229fba44942d907173152442cf6591441ed6dc") + expect_equal(collect(select(df, format_string("Name:%s", df$name)))[2, 1], "Name:Andy") + expect_equal(collect(select(df, format_string("%s, %d", df$name, df$age)))[2, 1], "Andy, 30") + expect_equal(collect(select(df, regexp_extract(df$name, "(n.y)", 1)))[2, 1], "ndy") + expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") + + l2 <- list(list(a = "aaads")) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) + expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + + l3 <- list(list(a = "a.b.c.d")) + df3 <- createDataFrame(sqlContext, l3) + expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") + expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") + expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") +}) + +test_that("date functions on a DataFrame", { + .originalTimeZone <- Sys.getenv("TZ") + Sys.setenv(TZ = "UTC") + l <- list(list(a = 1L, b = as.Date("2012-12-13")), + list(a = 2L, b = as.Date("2013-12-14")), + list(a = 3L, b = as.Date("2014-12-15"))) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 
14, 15)) + expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) + expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) + expect_equal(collect(select(df, year(df$b)))[, 1], c(2012, 2013, 2014)) + expect_equal(collect(select(df, month(df$b)))[, 1], c(12, 12, 12)) + expect_equal(collect(select(df, last_day(df$b)))[, 1], + c(as.Date("2012-12-31"), as.Date("2013-12-31"), as.Date("2014-12-31"))) + expect_equal(collect(select(df, next_day(df$b, "MONDAY")))[, 1], + c(as.Date("2012-12-17"), as.Date("2013-12-16"), as.Date("2014-12-22"))) + expect_equal(collect(select(df, date_format(df$b, "y")))[, 1], c("2012", "2013", "2014")) + expect_equal(collect(select(df, add_months(df$b, 3)))[, 1], + c(as.Date("2013-03-13"), as.Date("2014-03-14"), as.Date("2015-03-15"))) + expect_equal(collect(select(df, date_add(df$b, 1)))[, 1], + c(as.Date("2012-12-14"), as.Date("2013-12-15"), as.Date("2014-12-16"))) + expect_equal(collect(select(df, date_sub(df$b, 1)))[, 1], + c(as.Date("2012-12-12"), as.Date("2013-12-13"), as.Date("2014-12-14"))) + + l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), + list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) + expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) + expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + + l3 <- list(list(a = 1000), list(a = -1000)) + df3 <- createDataFrame(sqlContext, l3) + result31 <- collect(select(df3, from_unixtime(df3$a))) + expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), + c(1, 2)) + result32 <- collect(select(df3, from_unixtime(df3$a, "yyyy"))) + expect_equal(grep("\\d{4}", result32[, 1]), c(1, 2)) + Sys.setenv(TZ = .originalTimeZone) +}) + +test_that("greatest() and least() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) + expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) +}) + +test_that("when(), otherwise() and ifelse() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) }) test_that("group by", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, "name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + 
expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) }) -test_that("join() on a DataFrame", { +test_that("join() and merge() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -708,36 +1027,49 @@ test_that("join() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, 
joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) + + merged <- select(merge(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(merged), c("newAge", "name", "test")) + expect_equal(count(merged), 4) + expect_equal(collect(orderBy(merged, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) test_that("showDF()", { df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) - expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) }) test_that("isLocal()", { @@ -745,7 +1077,7 @@ test_that("isLocal()", { expect_false(isLocal(df)) }) -test_that("unionAll(), except(), and intersect() on a DataFrame", { +test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", @@ -756,50 +1088,73 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") + + unioned2 <- arrange(rbind(unioned, df, df2), df$age) + expect_is(unioned2, "DataFrame") + expect_equal(count(unioned2), 12) + expect_equal(first(unioned2)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + 
expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) -test_that("mutate() and rename()", { +test_that("mutate(), transform(), rename() and names()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") + + names(newDF2) <- c("newerName", "evenNewerAge") + expect_equal(length(names(newDF2)), 2) + expect_equal(names(newDF2)[1], "newerName") + + transformedDF <- transform(df, newAge = -df$age, newAge2 = df$age / 2) + expect_equal(length(columns(transformedDF)), 4) + expect_equal(columns(transformedDF)[3], "newAge") + expect_equal(columns(transformedDF)[4], "newAge2") + expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) + + # test if transform on local data frames works + # ensure the proper signature is used - otherwise this will fail to run + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + expect_equal(nrow(result), 153) + expect_equal(ncol(result), 2) + detach(airquality) }) test_that("write.df() on DataFrame and works with parquetFile", { df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath) - expect_true(inherits(parquetDF, "DataFrame")) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) @@ -809,110 +1164,141 @@ test_that("parquetFile works with multiple input paths", { parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) }) -test_that("describe() on a DataFrame", { +test_that("describe() and summarize() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "5.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") + + stats2 <- summary(df) + expect_equal(collect(stats2)[4, "name"], "Andy") + expect_equal(collect(stats2)[5, "age"], "30") }) -test_that("dropna() on a DataFrame", { +test_that("dropna() and na.omit() on a DataFrame", { df <- jsonFile(sqlContext, jsonPathNa) rows <- collect(df) # drop with columns - + expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) 
+ actual <- collect(na.omit(df, cols = "name")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) row.names(expected) <- row.names(actual) # identical on two dataframes does not work here. Don't know why. # use identical on all columns as a workaround. - expect_true(identical(expected$age, actual$age)) - expect_true(identical(expected$height, actual$height)) - expect_true(identical(expected$name, actual$name)) - + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) + actual <- collect(na.omit(df, cols = "age")) + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) - + expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) + # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) - expect_true(identical(expected, actual)) - + expect_identical(expected, actual) + actual <- collect(na.omit(df, "all")) + expect_identical(expected, actual) + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "any")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_true(identical(expected, actual)) - + expect_identical(expected, actual) + actual <- collect(na.omit(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + # drop with threshold - + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) - expected <- rows[as.integer(!is.na(rows$age)) + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { df <- 
jsonFile(sqlContext, jsonPathNa) rows <- collect(df) - + # fill with value - + expected <- rows expected$age[is.na(expected$age)] <- 50 expected$height[is.na(expected$height)] <- 50.6 actual <- collect(fillna(df, 50.6)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$age[is.na(expected$age)] <- 50 actual <- collect(fillna(df, 50.6, "age")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown", c("age", "name"))) - expect_true(identical(expected, actual)) - + expect_identical(expected, actual) + # fill with named list expected <- rows @@ -920,7 +1306,25 @@ test_that("fillna() on a DataFrame", { expected$height[is.na(expected$height)] <- 50.6 expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) +}) + +test_that("crosstab() on a DataFrame", { + rdd <- lapply(parallelize(sc, 0:3), function(x) { + list(paste0("a", x %% 3), paste0("b", x %% 2)) + }) + df <- toDF(rdd, list("a", "b")) + ct <- crosstab(df, "a", "b") + ordered <- ct[order(ct$a_b),] + row.names(ordered) <- NULL + expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), + stringsAsFactors = FALSE, row.names = NULL) + expect_identical(expected, ordered) +}) + +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index 7f4c7c315d78..c2c724cdc762 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,9 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) - diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 6b87b4b3e0b0..a9cf83dbdbdb 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -20,16 +20,16 @@ context("the textFile() function") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -58,7 +58,7 @@ test_that("textFile() word count works as 
expected", { expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), list("Spark", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName) }) @@ -115,13 +115,13 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { saveAsTextFile(counts, fileName2) rdd <- textFile(sc, fileName2) - + output <- collect(rdd) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expectedStr <- lapply(expected, function(x) { toString(x) }) expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) - + unlink(fileName1) unlink(fileName2) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) @@ -159,4 +159,3 @@ test_that("Pipelined operations on RDDs created using textFile", { unlink(fileName) }) - diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 539e3a3c19df..12df4cf4f65b 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -43,13 +43,13 @@ test_that("serializeToBytes on RDD", { mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") - + expect_equal(getSerializedMode(ser.rdd), "byte") + unlink(fileName) }) @@ -64,7 +64,7 @@ test_that("cleanClosure on R functions", { expect_equal(actual, y) actual <- get("g", envir = env, inherits = FALSE) expect_equal(actual, g) - + # Test for nested enclosures and package variables. env2 <- new.env() funcEnv <- new.env(parent = env2) @@ -106,7 +106,7 @@ test_that("cleanClosure on R functions", { expect_equal(length(ls(env)), 1) actual <- get("y", envir = env, inherits = FALSE) expect_equal(actual, y) - + # Test for function (and variable) definitions. f <- function(x) { g <- function(y) { y * 2 } @@ -115,11 +115,11 @@ test_that("cleanClosure on R functions", { newF <- cleanClosure(f) env <- environment(newF) expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. - + # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) - t = 4 # Override base::t in .GlobalEnv. + t <- 4 # Override base::t in .GlobalEnv. f <- function(x) { x > t } newF <- cleanClosure(f) env <- environment(newF) @@ -128,7 +128,7 @@ test_that("cleanClosure on R functions", { actual <- collect(lapply(rdd, f)) expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) expect_equal(actual, expected) - + # Test for broadcast variables. 
a <- matrix(nrow=10, ncol=10, data=rnorm(100)) aBroadcast <- broadcast(sc, a) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 7e3b5fc403b2..0c3b0d1f4be2 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -94,7 +94,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- as.list(readLines(inputCon)) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() @@ -120,7 +120,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- readLines(inputCon) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() diff --git a/README.md b/README.md index 380422ca00db..76e29b423566 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Apache Spark Spark is a fast and general cluster computing system for Big Data. It provides -high-level APIs in Scala, Java, and Python, and an optimized engine that +high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, @@ -94,5 +94,5 @@ distribution. ## Configuration -Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) +Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. diff --git a/assembly/pom.xml b/assembly/pom.xml index e9c6d26ccddc..4b60ee00ffbe 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index ed5c37e595a9..3baf8d47b4dc 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index ef0bb2ac13f0..8399033ac61e 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -22,6 +22,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") object Bagel extends Logging { val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK @@ -78,7 +79,7 @@ object Bagel extends Logging { val startTime = System.currentTimeMillis val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey( + val combinedMsgs = msgs.combineByKeyWithClassTag( combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) val grouped = combinedMsgs.groupWith(verts) val superstep_ = superstep // Create a read-only copy of superstep for capture in closure @@ -270,18 +271,21 @@ object Bagel extends Logging { } } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } +@deprecated("Uses of Bagel should migrate to GraphX", 
"1.6.0") trait Aggregator[V, A] { def createAggregator(vert: V): A def mergeAggregators(a: A, b: A): A } /** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) @@ -297,6 +301,7 @@ class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializab * Subclasses may store state along with each vertex and must * inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Vertex { def active: Boolean } @@ -307,6 +312,7 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Message[K] { def targetId: K } diff --git a/bin/pyspark b/bin/pyspark index f9dbddfa5356..8f2a3b5a7717 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -82,4 +82,4 @@ fi export PYSPARK_DRIVER_PYTHON export PYSPARK_DRIVER_PYTHON_OPTS -exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main "$@" +exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 45e9e3def512..3c6169983e76 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -35,4 +35,4 @@ set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main %* +call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %* diff --git a/bin/spark-class b/bin/spark-class index 2b59e5df5736..e38e08dec40e 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -43,17 +43,19 @@ else fi num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" -if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" -a "$SPARK_PREPEND_CLASSES" != "1" ]; then echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 echo "You need to build Spark before running this program." 1>&2 exit 1 fi -ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" -if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 - echo "$ASSEMBLY_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 +if [ -d "$ASSEMBLY_DIR" ]; then + ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" + if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 + fi fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" diff --git a/bin/spark-shell b/bin/spark-shell index a6dc863d83fc..00ab7afd118b 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -47,11 +47,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). 
stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" fi } diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 251309d67f86..b9b0f510d7f5 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %* diff --git a/build/mvn b/build/mvn index e8364181e823..ec0380afad31 100755 --- a/build/mvn +++ b/build/mvn @@ -51,11 +51,11 @@ install_app() { # check if we have curl installed # download application [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ - echo "exec: curl ${curl_opts} ${remote_tarball}" && \ + echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" # if the file still doesn't exist, lets try `wget` and cross our fingers [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ - echo "exec: wget ${wget_opts} ${remote_tarball}" && \ + echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" # if both were unsuccessful, exit [ ! -f "${local_tarball}" ] && \ @@ -82,7 +82,7 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { local zinc_path="zinc-0.3.5.3/bin/zinc" - [ ! -f "${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ "http://downloads.typesafe.com/zinc/0.3.5.3" \ "zinc-0.3.5.3.tgz" \ @@ -112,10 +112,17 @@ install_scala() { # the environment ZINC_PORT=${ZINC_PORT:-"3030"} +# Check for the `--force` flag dictating that `mvn` should be downloaded +# regardless of whether the system already has a `mvn` install +if [ "$1" == "--force" ]; then + FORCE_MVN=1 + shift +fi + # Install Maven if necessary MVN_BIN="$(command -v mvn)" -if [ ! "$MVN_BIN" ]; then +if [ ! 
"$MVN_BIN" -o -n "$FORCE_MVN" ]; then install_mvn fi @@ -128,9 +135,9 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it -if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status -port ${ZINC_PORT}`" ]; then export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} - ${ZINC_BIN} -shutdown + ${ZINC_BIN} -shutdown -port ${ZINC_PORT} ${ZINC_BIN} -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null @@ -139,5 +146,7 @@ fi # Set any `mvn` options if not already present export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} +echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 + # Last, call the `mvn` command as usual -${MVN_BIN} "$@" +${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 504be48b358f..615f84839465 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -38,8 +38,7 @@ dlog () { acquire_sbt_jar () { SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties` - URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar - URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar + URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar JAR=build/sbt-launch-${SBT_VERSION}.jar sbt_jar=$JAR @@ -51,9 +50,11 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ + mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + wget --quiet ${URL1} -O "${JAR_DL}" &&\ + mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 3a2a88219818..74c5cea94403 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -10,3 +10,9 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 43c4288912b1..c05fe381a36a 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -22,7 +22,7 @@ # - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 
1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. @@ -38,6 +38,7 @@ # - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") +# - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") # - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") diff --git a/core/pom.xml b/core/pom.xml index 40a64beccdc2..e31d90f60889 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -34,6 +34,11 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + com.google.guava guava @@ -41,44 +46,14 @@ com.twitter chill_${scala.binary.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - com.twitter chill-java - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark @@ -271,7 +246,7 @@ com.fasterxml.jackson.module - jackson-module-scala_2.10 + jackson-module-scala_${scala.binary.version} org.apache.derby @@ -291,7 +266,7 @@ org.tachyonproject tachyon-client - 0.6.4 + 0.7.1 org.apache.hadoop @@ -299,39 +274,23 @@ org.apache.curator - curator-recipes + curator-client - org.eclipse.jetty - jetty-jsp - - - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlet - - - junit - junit + org.apache.curator + curator-framework - org.powermock - powermock-module-junit4 + org.apache.curator + curator-recipes - org.powermock - powermock-api-mockito + org.tachyonproject + tachyon-underfs-glusterfs - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 @@ -353,28 +312,28 @@ test - org.mockito - mockito-all + org.hamcrest + hamcrest-core test - org.scalacheck - scalacheck_${scala.binary.version} + org.hamcrest + hamcrest-library test - junit - junit + org.mockito + mockito-core test - org.hamcrest - hamcrest-core + org.scalacheck + scalacheck_${scala.binary.version} test - org.hamcrest - hamcrest-library + junit + junit test @@ -382,6 +341,11 @@ junit-interface test + + org.apache.curator + curator-test + test + net.razorvine pyrolite diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 646496f31350..fa9acf0a15b8 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -17,23 +17,7 @@ package org.apache.spark; -import org.apache.spark.scheduler.SparkListener; -import org.apache.spark.scheduler.SparkListenerApplicationEnd; -import org.apache.spark.scheduler.SparkListenerApplicationStart; -import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; -import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; -import 
org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorAdded; -import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorRemoved; -import org.apache.spark.scheduler.SparkListenerJobEnd; -import org.apache.spark.scheduler.SparkListenerJobStart; -import org.apache.spark.scheduler.SparkListenerStageCompleted; -import org.apache.spark.scheduler.SparkListenerStageSubmitted; -import org.apache.spark.scheduler.SparkListenerTaskEnd; -import org.apache.spark.scheduler.SparkListenerTaskGettingResult; -import org.apache.spark.scheduler.SparkListenerTaskStart; -import org.apache.spark.scheduler.SparkListenerUnpersistRDD; +import org.apache.spark.scheduler.*; /** * Java clients should extend this class instead of implementing @@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } @Override public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index fbc566695905..1214d05ba606 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { onEvent(executorRemoved); } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + } diff --git a/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java b/core/src/main/java/org/apache/spark/annotation/AlphaComponent.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java rename to core/src/main/java/org/apache/spark/annotation/AlphaComponent.java diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/java/org/apache/spark/annotation/DeveloperApi.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java rename to core/src/main/java/org/apache/spark/annotation/DeveloperApi.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/java/org/apache/spark/annotation/Experimental.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Experimental.java rename to core/src/main/java/org/apache/spark/annotation/Experimental.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Private.java b/core/src/main/java/org/apache/spark/annotation/Private.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Private.java rename to core/src/main/java/org/apache/spark/annotation/Private.java diff --git a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java index 2090efd3b999..d4c42b38ac22 100644 --- a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java @@ -23,11 +23,13 @@ // See // 
http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html abstract class JavaSparkContextVarargsWorkaround { - public JavaRDD union(JavaRDD... rdds) { + + @SafeVarargs + public final JavaRDD union(JavaRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList> rest = new ArrayList>(rdds.length - 1); + List> rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } @@ -38,18 +40,19 @@ public JavaDoubleRDD union(JavaDoubleRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList rest = new ArrayList(rdds.length - 1); + List rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } return union(rdds[0], rest); } - public JavaPairRDD union(JavaPairRDD... rdds) { + @SafeVarargs + public final JavaPairRDD union(JavaPairRDD... rdds) { if (rdds.length == 0) { throw new IllegalArgumentException("Union called on empty list"); } - ArrayList> rest = new ArrayList>(rdds.length - 1); + List> rest = new ArrayList<>(rdds.length - 1); for (int i = 1; i < rdds.length; i++) { rest.add(rdds[i]); } @@ -57,7 +60,7 @@ public JavaPairRDD union(JavaPairRDD... rdds) { } // These methods take separate "first" and "rest" elements to avoid having the same type erasure - abstract public JavaRDD union(JavaRDD first, List> rest); - abstract public JavaDoubleRDD union(JavaDoubleRDD first, List rest); - abstract public JavaPairRDD union(JavaPairRDD first, List> rest); + public abstract JavaRDD union(JavaRDD first, List> rest); + public abstract JavaDoubleRDD union(JavaDoubleRDD first, List rest); + public abstract JavaPairRDD union(JavaPairRDD first, List> rest); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java similarity index 86% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java rename to core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 3f746b886bc9..0e58bb4f7101 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.serializer; import java.io.IOException; import java.io.InputStream; @@ -24,10 +24,8 @@ import scala.reflect.ClassTag; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.Platform; /** * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. @@ -35,7 +33,8 @@ * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work * around this, we pass a dummy no-op serializer. 
*/ -final class DummySerializerInstance extends SerializerInstance { +@Private +public final class DummySerializerInstance extends SerializerInstance { public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); @@ -50,7 +49,7 @@ public void flush() { try { s.flush(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } @@ -65,7 +64,7 @@ public void close() { try { s.close(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } }; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d3d6280284be..0b8b604e1849 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final Serializer serializer; /** Array of file writers, one for each partition */ - private BlockObjectWriter[] partitionWriters; + private DiskBlockObjectWriter[] partitionWriters; public BypassMergeSortShuffleWriter( SparkConf conf, @@ -101,7 +101,7 @@ public void insertAll(Iterator> records) throws IOException { } final SerializerInstance serInstance = serializer.newInstance(); final long openStartTime = System.nanoTime(); - partitionWriters = new BlockObjectWriter[numPartitions]; + partitionWriters = new DiskBlockObjectWriter[numPartitions]; for (int i = 0; i < numPartitions; i++) { final Tuple2 tempShuffleBlockIdPlusFile = blockManager.diskBlockManager().createTempShuffleBlock(); @@ -121,7 +121,7 @@ public void insertAll(Iterator> records) throws IOException { partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } } @@ -169,7 +169,7 @@ public void stop() throws IOException { if (partitionWriters != null) { try { final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { // This method explicitly does _not_ throw exceptions: writer.revertPartialWritesAndClose(); if (!diskBlockManager.getFile(writer.blockId()).delete()) { diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 9e9ed94b7890..e73ba3946882 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.unsafe; +import javax.annotation.Nullable; import java.io.File; import java.io.IOException; import java.util.LinkedList; @@ -30,10 +31,14 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.storage.BlockManager; +import 
org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempShuffleBlockId; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; @@ -58,15 +63,15 @@ final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); - private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; private final int initialSize; private final int numPartitions; - private final TaskMemoryManager memoryManager; + private final int pageSizeBytes; + @VisibleForTesting + final int maxRecordSizeBytes; + private final TaskMemoryManager taskMemoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; @@ -85,9 +90,12 @@ final class UnsafeShuffleExternalSorter { private final LinkedList spills = new LinkedList(); + /** Peak memory used by this sorter so far, in bytes. **/ + private long peakMemoryUsedBytes; + // These variables are reset after spilling: - private UnsafeShuffleInMemorySorter sorter; - private MemoryBlock currentPage = null; + @Nullable private UnsafeShuffleInMemorySorter inMemSorter; + @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; @@ -100,17 +108,24 @@ public UnsafeShuffleExternalSorter( int numPartitions, SparkConf conf, ShuffleWriteMetrics writeMetrics) throws IOException { - this.memoryManager = memoryManager; + this.taskMemoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; this.initialSize = initialSize; + this.peakMemoryUsedBytes = initialSize; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - + this.pageSizeBytes = (int) Math.min( + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); + this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); + + // preserve first page to ensure that we have at least one page to work with. Otherwise, + // other operators in the same task may starve this sorter (SPARK-9709). + acquireNewPageIfNecessary(pageSizeBytes); } /** @@ -125,7 +140,7 @@ private void initializeForWriting() throws IOException { throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } - this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); } /** @@ -152,11 +167,11 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // This call performs the actual sort. final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = - sorter.getSortedIterator(); + inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. 
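// (Editor's note, not part of the patch.) These hunks switch the writer type from
// BlockObjectWriter to DiskBlockObjectWriter; the lifecycle the surrounding code relies on
// is unchanged, roughly:
//
//   writer.write(writeBuffer, 0, toTransfer);   // append serialized bytes
//   writer.commitAndClose();                    // on success, finalize the partition file
//   writer.revertPartialWritesAndClose();       // on error, discard any partial output
//
// Only the calls visible in this diff are assumed; no other DiskBlockObjectWriter API is
// implied.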
- BlockObjectWriter writer; + DiskBlockObjectWriter writer; // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer @@ -198,18 +213,14 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = memoryManager.getPage(recordPointer); - final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); - int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); + int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); - PlatformDependent.copyMemory( - recordPage, - recordReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -261,9 +272,9 @@ void spill() throws IOException { spills.size() > 1 ? " times" : " time"); writeSortedFile(false); - final long sorterMemoryUsage = sorter.getMemoryUsage(); - sorter = null; - shuffleMemoryManager.release(sorterMemoryUsage); + final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + shuffleMemoryManager.release(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); @@ -271,13 +282,33 @@ void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } private long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - memoryManager.freePage(block); + taskMemoryManager.freePage(block); shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } @@ -291,77 +322,76 @@ private long freeMemory() { /** * Force all memory and spill files to be deleted; called by shuffle error-handling code. */ - public void cleanupAfterError() { + public void cleanupResources() { freeMemory(); for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (sorter != null) { - shuffleMemoryManager.release(sorter.getMemoryUsage()); - sorter = null; + if (inMemSorter != null) { + shuffleMemoryManager.release(inMemSorter.getMemoryUsage()); + inMemSorter = null; } } /** - * Checks whether there is enough space to insert a new record into the sorter. 
- * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. - - * @return true if the record can be inserted without requiring more allocations, false otherwise. + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. */ - private boolean haveSpaceForRecord(int requiredSpace) { - assert (requiredSpace > 0); - return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); - } - - /** - * Allocates more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. - * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. - */ - private void allocateSpaceForRecord(int requiredSpace) throws IOException { - if (!sorter.hasSpaceForAnotherRecord()) { + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); if (memoryAcquired < memoryToGrowPointerArray) { shuffleMemoryManager.release(memoryAcquired); spill(); } else { - sorter.expandPointerArray(); + inMemSorter.expandPointerArray(); shuffleMemoryManager.release(oldPointerArrayMemoryUsage); } } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). + */ + private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { + growPointerArrayIfNecessary(); if (requiredSpace > freeSpaceInCurrentPage) { logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, freeSpaceInCurrentPage); // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. 
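// (Editor's note, not part of the patch.) The block below follows the acquire-or-spill
// pattern used throughout this file: try to reserve one full page from the
// ShuffleMemoryManager; if less than a page is granted, release the partial grant, spill
// the in-memory data to disk and retry once; only if the retry also comes up short is an
// IOException thrown.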
- if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } @@ -375,27 +405,54 @@ public void insertRecord( long recordBaseOffset, int lengthInBytes, int partitionId) throws IOException { + + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. final int totalSpaceRequired = lengthInBytes + 4; - if (!haveSpaceForRecord(totalSpaceRequired)) { - allocateSpaceForRecord(totalSpaceRequired); + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). 
+ acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; } + final Object dataPageBaseObject = dataPage.getBaseObject(); final long recordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); - final Object dataPageBaseObject = currentPage.getBaseObject(); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); - currentPagePosition += 4; - freeSpaceInCurrentPage -= 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - currentPagePosition, - lengthInBytes); - currentPagePosition += lengthInBytes; - freeSpaceInCurrentPage -= lengthInBytes; - sorter.insertRecord(recordAddress, partitionId); + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + dataPagePosition += 4; + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, partitionId); } /** @@ -407,14 +464,14 @@ public void insertRecord( */ public SpillInfo[] closeAndGetSpills() throws IOException { try { - if (sorter != null) { + if (inMemSorter != null) { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { - cleanupAfterError(); + cleanupResources(); throw e; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index ad7eb04afcd8..fdb309e365f6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -17,14 +17,15 @@ package org.apache.spark.shuffle.unsafe; +import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; import java.util.Iterator; -import javax.annotation.Nullable; import scala.Option; import scala.Product2; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; +import scala.collection.immutable.Map; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -37,10 +38,10 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.LZFCompressionCodec; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -52,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.TaskMemoryManager; @Private @@ -78,8 +79,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; - private MapStatus mapStatus = null; - private UnsafeShuffleExternalSorter sorter = null; + 
@Nullable private MapStatus mapStatus; + @Nullable private UnsafeShuffleExternalSorter sorter; + private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { @@ -129,16 +131,43 @@ public UnsafeShuffleWriter( open(); } + @VisibleForTesting + public int maxRecordSizeBytes() { + assert(sorter != null); + return sorter.maxRecordSizeBytes; + } + + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + /** * This convenience method should only be called in test code. */ @VisibleForTesting public void write(Iterator> records) throws IOException { - write(JavaConversions.asScalaIterator(records)); + write(JavaConverters.asScalaIteratorConverter(records).asScala()); } @Override public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we encountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. boolean success = false; try { while (records.hasNext()) { @@ -147,8 +176,19 @@ public void write(scala.collection.Iterator> records) throws IOEx closeAndWriteOutput(); success = true; } finally { - if (!success) { - sorter.cleanupAfterError(); + if (sorter != null) { + try { + sorter.cleanupResources(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. 
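// (Editor's note, not part of the patch.) cleanupResources() (previously
// cleanupAfterError()) now runs on both the success and the failure path; an exception
// raised during cleanup is rethrown only when the write itself succeeded, otherwise it is
// logged so that the original write failure is not masked.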
+ if (success) { + throw e; + } else { + logger.error("In addition to a failure during writing, we failed during " + + "cleanup.", e); + } + } } } } @@ -170,6 +210,8 @@ private void open() throws IOException { @VisibleForTesting void closeAndWriteOutput() throws IOException { + assert(sorter != null); + updatePeakMemoryUsed(); serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); @@ -190,6 +232,7 @@ void closeAndWriteOutput() throws IOException { @VisibleForTesting void insertRecordIntoSorter(Product2 record) throws IOException { + assert(sorter != null); final K key = record._1(); final int partitionId = partitioner.getPartition(key); serBuffer.reset(); @@ -201,7 +244,7 @@ void insertRecordIntoSorter(Product2 record) throws IOException { assert (serializedRecordSize > 0); sorter.insertRecord( - serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting @@ -412,6 +455,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th @Override public Option stop(boolean success) { try { + // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) + Map> internalAccumulators = + taskContext.internalMetricsToAccumulators(); + if (internalAccumulators != null) { + internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) + .add(getPeakMemoryUsedBytes()); + } + if (stopping) { return Option.apply(null); } else { @@ -431,7 +482,7 @@ public Option stop(boolean success) { if (sorter != null) { // If sorter is non-null, then this implies that we called stop() in response to an error, // so we need to clean up memory and spill files created by the sorter - sorter.cleanupAfterError(); + sorter.cleanupResources(); } } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java similarity index 57% rename from unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java rename to core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 0b4d8d286f5f..b24eed3952fd 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -17,34 +17,48 @@ package org.apache.spark.unsafe.map; -import java.lang.Override; -import java.lang.UnsupportedOperationException; +import javax.annotation.Nullable; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import org.apache.spark.unsafe.*; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; /** * An append-only hash map where keys and values are contiguous regions of bytes. - *

+ *
  * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
  * which is guaranteed to exhaust the space.
- * <p>
+ *
  * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
  * probably be using sorting instead of hashing for better cache locality.
- * <p>
- * This class is not thread safe. + * + * The key and values under the hood are stored together, in the following format: + * Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4 + * Bytes 4 to 8: len(k) + * Bytes 8 to 8 + len(k): key data + * Bytes 8 + len(k) to 8 + len(k) + len(v): value data + * + * This means that the first four bytes store the entire record (key + value) length. This format + * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, + * so we can pass records from this map directly into the sorter to sort records in place. */ public final class BytesToBytesMap { + private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; @@ -54,7 +68,9 @@ public final class BytesToBytesMap { */ private static final int END_OF_PAGE_MARKER = -1; - private final TaskMemoryManager memoryManager; + private final TaskMemoryManager taskMemoryManager; + + private final ShuffleMemoryManager shuffleMemoryManager; /** * A linked list for tracking all allocated data pages so that we can free all of our memory. @@ -74,17 +90,11 @@ public final class BytesToBytesMap { */ private long pageCursor = 0; - /** - * The size of the data pages that hold key and value data. Map entries cannot span multiple - * pages, so this limits the maximum entry size. - */ - private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes - /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be - * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since - * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array - * entries per key, giving us a maximum capacity of (1 << 29). + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, + * since that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). */ @VisibleForTesting static final int MAX_CAPACITY = (1 << 29); @@ -98,7 +108,7 @@ public final class BytesToBytesMap { * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. */ - private LongArray longArray; + @Nullable private LongArray longArray; // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode // and exploit word-alignment to use fewer bits to hold the address. This might let us store // only one long per map entry, increasing the chance that this array will fit in cache at the @@ -113,14 +123,20 @@ public final class BytesToBytesMap { * A {@link BitSet} used to track location of the map where the key is set. * Size of the bitset should be half of the size of the long array. */ - private BitSet bitset; + @Nullable private BitSet bitset; private final double loadFactor; + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private final long pageSizeBytes; + /** * Number of keys defined in the map. */ - private int size; + private int numElements; /** * The map will be expanded once the number of keys exceeds this threshold. 
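(Editor's illustration, not part of the patch.) The class comment above documents the new
key/value record layout; a minimal sketch of decoding a record that starts at recordOffset
inside a data page, using the Platform helpers already used in this file (page and
recordOffset are assumed to come from the surrounding map code):

    int totalLength  = Platform.getInt(page, recordOffset);        // len(k) + len(v) + 4
    int keyLength    = Platform.getInt(page, recordOffset + 4);    // len(k)
    int valueLength  = totalLength - keyLength - 4;                // len(v)
    long keyOffset   = recordOffset + 8;                           // key bytes start here
    long valueOffset = recordOffset + 8 + keyLength;               // value bytes follow the key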
@@ -149,14 +165,20 @@ public final class BytesToBytesMap { private long numHashCollisions = 0; + private long peakMemoryUsedBytes = 0L; + public BytesToBytesMap( - TaskMemoryManager memoryManager, + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, double loadFactor, + long pageSizeBytes, boolean enablePerfMetrics) { - this.memoryManager = memoryManager; + this.taskMemoryManager = taskMemoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; this.loadFactor = loadFactor; this.loc = new Location(); + this.pageSizeBytes = pageSizeBytes; this.enablePerfMetrics = enablePerfMetrics; if (initialCapacity <= 0) { throw new IllegalArgumentException("Initial capacity must be greater than 0"); @@ -165,46 +187,82 @@ public BytesToBytesMap( throw new IllegalArgumentException( "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); } + if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) { + throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); + } allocate(initialCapacity); + + // Acquire a new page as soon as we construct the map to ensure that we have at least + // one page to work with. Otherwise, other operators in the same task may starve this + // map (SPARK-9747). + acquireNewPage(); } - public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { - this(memoryManager, initialCapacity, 0.70, false); + public BytesToBytesMap( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + int initialCapacity, + long pageSizeBytes) { + this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); } public BytesToBytesMap( - TaskMemoryManager memoryManager, + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, + long pageSizeBytes, boolean enablePerfMetrics) { - this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + this( + taskMemoryManager, + shuffleMemoryManager, + initialCapacity, + 0.70, + pageSizeBytes, + enablePerfMetrics); } /** * Returns the number of keys defined in the map. */ - public int size() { return size; } + public int numElements() { return numElements; } - private static final class BytesToBytesMapIterator implements Iterator { + public static final class BytesToBytesMapIterator implements Iterator { private final int numRecords; private final Iterator dataPagesIterator; private final Location loc; + private MemoryBlock currentPage = null; private int currentRecordNumber = 0; private Object pageBaseObject; private long offsetInPage; - BytesToBytesMapIterator(int numRecords, Iterator dataPagesIterator, Location loc) { + // If this iterator destructive or not. When it is true, it frees each page as it moves onto + // next one. 
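// (Editor's note, not part of the patch.) In the destructive mode added below,
// advanceToNextPage() removes the current page from dataPages, frees it through the
// TaskMemoryManager and releases its size back to the ShuffleMemoryManager, which is why
// the map must not be touched again once destructiveIterator() has been called.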
+ private boolean destructive = false; + private BytesToBytesMap bmap; + + private BytesToBytesMapIterator( + int numRecords, Iterator dataPagesIterator, Location loc, + boolean destructive, BytesToBytesMap bmap) { this.numRecords = numRecords; this.dataPagesIterator = dataPagesIterator; this.loc = loc; + this.destructive = destructive; + this.bmap = bmap; if (dataPagesIterator.hasNext()) { advanceToNextPage(); } } private void advanceToNextPage() { - final MemoryBlock currentPage = dataPagesIterator.next(); + if (destructive && currentPage != null) { + dataPagesIterator.remove(); + this.bmap.taskMemoryManager.freePage(currentPage); + this.bmap.shuffleMemoryManager.release(currentPage.size()); + } + currentPage = dataPagesIterator.next(); pageBaseObject = currentPage.getBaseObject(); offsetInPage = currentPage.getBaseOffset(); } @@ -216,13 +274,13 @@ public boolean hasNext() { @Override public Location next() { - int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); - if (keyLength == END_OF_PAGE_MARKER) { + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); + if (totalLength == END_OF_PAGE_MARKER) { advanceToNextPage(); - keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + totalLength = Platform.getInt(pageBaseObject, offsetInPage); } - loc.with(pageBaseObject, offsetInPage); - offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + loc.with(currentPage, offsetInPage); + offsetInPage += 4 + totalLength; currentRecordNumber++; return loc; } @@ -241,8 +299,22 @@ public void remove() { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. */ - public Iterator iterator() { - return new BytesToBytesMapIterator(size, dataPages.iterator(), loc); + public BytesToBytesMapIterator iterator() { + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this); + } + + /** + * Returns a destructive iterator for iterating over the entries of this map. It frees each page + * as it moves onto next one. Notice: it is illegal to call any method on the map after + * `destructiveIterator()` has been called. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public BytesToBytesMapIterator destructiveIterator() { + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this); } /** @@ -255,6 +327,23 @@ public Location lookup( Object keyBaseObject, long keyBaseOffset, int keyRowLengthBytes) { + safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + return loc; + } + + /** + * Looks up a key, and saves the result in provided `loc`. + * + * This is a thread-safe version of `lookup`, could be used by multiple threads. + */ + public void safeLookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes, + Location loc) { + assert(bitset != null); + assert(longArray != null); + if (enablePerfMetrics) { numKeyLookups++; } @@ -267,7 +356,8 @@ public Location lookup( } if (!bitset.isSet(pos)) { // This is a new key. 
- return loc.with(pos, hashcode, false); + loc.with(pos, hashcode, false); + return; } else { long stored = longArray.get(pos * 2 + 1); if ((int) (stored) == hashcode) { @@ -277,7 +367,7 @@ public Location lookup( final MemoryLocation keyAddress = loc.getKeyAddress(); final Object storedKeyBaseObject = keyAddress.getBaseObject(); final long storedKeyBaseOffset = keyAddress.getBaseOffset(); - final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals( + final boolean areEqual = ByteArrayMethods.arrayEquals( keyBaseObject, keyBaseOffset, storedKeyBaseObject, @@ -285,7 +375,7 @@ public Location lookup( keyRowLengthBytes ); if (areEqual) { - return loc; + return; } else { if (enablePerfMetrics) { numHashCollisions++; @@ -318,23 +408,33 @@ public final class Location { private int keyLength; private int valueLength; + /** + * Memory page containing the record. Only set if created by {@link BytesToBytesMap#iterator()}. + */ + @Nullable private MemoryBlock memoryPage; + private void updateAddressesAndSizes(long fullKeyAddress) { updateAddressesAndSizes( - memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress)); + taskMemoryManager.getPage(fullKeyAddress), + taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(Object page, long keyOffsetInPage) { - long position = keyOffsetInPage; - keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); - position += 8; // word used to store the key size - keyMemoryLocation.setObjAndOffset(page, position); - position += keyLength; - valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); - position += 8; // word used to store the key size - valueMemoryLocation.setObjAndOffset(page, position); + private void updateAddressesAndSizes(final Object page, final long offsetInPage) { + long position = offsetInPage; + final int totalLength = Platform.getInt(page, position); + position += 4; + keyLength = Platform.getInt(page, position); + position += 4; + valueLength = totalLength - keyLength - 4; + + keyMemoryLocation.setObjAndOffset(page, position); + + position += keyLength; + valueMemoryLocation.setObjAndOffset(page, position); } - Location with(int pos, int keyHashcode, boolean isDefined) { + private Location with(int pos, int keyHashcode, boolean isDefined) { + assert(longArray != null); this.pos = pos; this.isDefined = isDefined; this.keyHashcode = keyHashcode; @@ -345,12 +445,21 @@ Location with(int pos, int keyHashcode, boolean isDefined) { return this; } - Location with(Object page, long keyOffsetInPage) { + private Location with(MemoryBlock page, long offsetInPage) { this.isDefined = true; - updateAddressesAndSizes(page, keyOffsetInPage); + this.memoryPage = page; + updateAddressesAndSizes(page.getBaseObject(), offsetInPage); return this; } + /** + * Returns the memory page that contains the current record. + * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. + */ + public MemoryBlock getMemoryPage() { + return this.memoryPage; + } + /** * Returns true if the key is defined at this position, and false otherwise. */ @@ -401,27 +510,37 @@ public int getValueLength() { /** * Store a new key and value. This method may only be called once for a given key; if you want * to update the value associated with a key, then you can directly manipulate the bytes stored - * at the value address. + * at the value address. The return value indicates whether the put succeeded or whether it + * failed because additional memory could not be acquired. *

* It is only valid to call this method immediately after calling `lookup()` using the same key. + *

*

* The key and value must be word-aligned (that is, their sizes must be multiples of 8). + *

*

* After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` * will return information on the data stored by this `putNewKey` call. + *

*

* As an example usage, here's the proper way to store a new key: - *

+ *

*
      *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
      *   if (!loc.isDefined()) {
-     *     loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)
+     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *       // handle failure to grow map (by spilling, for example)
+     *     }
      *   }
      * </pre>
*

* Unspecified behavior if the key is not defined. + *

+ * + * @return true if the put() was successful and false if the put() failed because memory could + * not be acquired. */ - public void putNewKey( + public boolean putNewKey( Object keyBaseObject, long keyBaseOffset, int keyLengthBytes, @@ -431,64 +550,126 @@ public void putNewKey( assert (!isDefined) : "Can only set value once for a key"; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); - if (size == MAX_CAPACITY) { + assert(bitset != null); + assert(longArray != null); + + if (numElements == MAX_CAPACITY) { throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); } + // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. - // (8 byte key length) (key) (8 byte value length) (value) - final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. - size++; - bitset.set(pos); - - // If there's not enough space in the current page, allocate a new page (8 bytes are reserved - // for the end-of-page marker). - if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) { + // (8 byte key length) (key) (value) + final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; + + // --- Figure out where to insert the new record --------------------------------------------- + + final MemoryBlock dataPage; + final Object dataPageBaseObject; + final long dataPageInsertOffset; + boolean useOverflowPage = requiredSize > pageSizeBytes - 8; + if (useOverflowPage) { + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryRequested = requiredSize + 8; + final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryGranted != memoryRequested) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", memoryRequested); + return false; + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested); + dataPages.add(overflowPage); + dataPage = overflowPage; + dataPageBaseObject = overflowPage.getBaseObject(); + dataPageInsertOffset = overflowPage.getBaseOffset(); + } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { + // The record can fit in a data page, but either we have not allocated any pages yet or + // the current page does not have enough space. if (currentDataPage != null) { // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); - dataPages.add(newPage); - pageCursor = 0; - currentDataPage = newPage; + if (!acquireNewPage()) { + return false; + } + dataPage = currentDataPage; + dataPageBaseObject = currentDataPage.getBaseObject(); + dataPageInsertOffset = currentDataPage.getBaseOffset(); + } else { + // There is enough space in the current data page. 
+ dataPage = currentDataPage; + dataPageBaseObject = currentDataPage.getBaseObject(); + dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor; } - // Compute all of our offsets up-front: - final Object pageBaseObject = currentDataPage.getBaseObject(); - final long pageBaseOffset = currentDataPage.getBaseOffset(); - final long keySizeOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += 8; // word used to store the key size - final long keyDataOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += keyLengthBytes; - final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += 8; // word used to store the value size - final long valueDataOffsetInPage = pageBaseOffset + pageCursor; - pageCursor += valueLengthBytes; + // --- Append the key and value data to the current data page -------------------------------- + long insertCursor = dataPageInsertOffset; + + // Compute all of our offsets up-front: + final long recordOffset = insertCursor; + insertCursor += 4; + final long keyLengthOffset = insertCursor; + insertCursor += 4; + final long keyDataOffsetInPage = insertCursor; + insertCursor += keyLengthBytes; + final long valueDataOffsetInPage = insertCursor; + insertCursor += valueLengthBytes; // word used to store the value size + + Platform.putInt(dataPageBaseObject, recordOffset, + keyLengthBytes + valueLengthBytes + 4); + Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key - PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); - PlatformDependent.copyMemory( - keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + Platform.copyMemory( + keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); // Copy the value - PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); - PlatformDependent.copyMemory( - valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes); + Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, + valueDataOffsetInPage, valueLengthBytes); + + // --- Update bookeeping data structures ----------------------------------------------------- + + if (useOverflowPage) { + // Store the end-of-page marker at the end of the data page + Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + } else { + pageCursor += requiredSize; + } - final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( - currentDataPage, keySizeOffsetInPage); + numElements++; + bitset.set(pos); + final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( + dataPage, recordOffset); longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); isDefined = true; - if (size > growthThreshold && longArray.size() < MAX_CAPACITY) { + if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) { growAndRehash(); } + return true; + } + } + + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * @return whether there is enough space to allocate the new page. 
+ */ + private boolean acquireNewPage() { + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; } + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + return true; } /** @@ -500,9 +681,9 @@ public void putNewKey( private void allocate(int capacity) { assert (capacity >= 0); // The capacity needs to be divisible by 64 so that our bit set can be sized properly - capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); + capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); - longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2)); + longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); this.growthThreshold = (int) (capacity * loadFactor); @@ -513,31 +694,60 @@ private void allocate(int capacity) { * Free all allocated memory associated with this map, including the storage for keys and values * as well as the hash map array itself. * - * This method is idempotent. + * This method is idempotent and can be called multiple times. */ public void free() { - if (longArray != null) { - memoryManager.free(longArray.memoryBlock()); - longArray = null; - } - if (bitset != null) { - // The bitset's heap memory isn't managed by a memory manager, so no need to free it here. - bitset = null; - } + updatePeakMemoryUsed(); + longArray = null; + bitset = null; Iterator dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { - memoryManager.freePage(dataPagesIterator.next()); + MemoryBlock dataPage = dataPagesIterator.next(); dataPagesIterator.remove(); + taskMemoryManager.freePage(dataPage); + shuffleMemoryManager.release(dataPage.size()); } assert(dataPages.isEmpty()); } - /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + public TaskMemoryManager getTaskMemoryManager() { + return taskMemoryManager; + } + + public ShuffleMemoryManager getShuffleMemoryManager() { + return shuffleMemoryManager; + } + + public long getPageSizeBytes() { + return pageSizeBytes; + } + + /** + * Returns the total amount of memory, in bytes, consumed by this map's managed structures. + */ public long getTotalMemoryConsumption() { - return ( - dataPages.size() * PAGE_SIZE_BYTES + - bitset.memoryBlock().size() + - longArray.memoryBlock().size()); + long totalDataPagesSize = 0L; + for (MemoryBlock dataPage : dataPages) { + totalDataPagesSize += dataPage.size(); + } + return totalDataPagesSize + + ((bitset != null) ? bitset.memoryBlock().size() : 0L) + + ((longArray != null) ? longArray.memoryBlock().size() : 0L); + } + + private void updatePeakMemoryUsed() { + long mem = getTotalMemoryConsumption(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } /** @@ -550,7 +760,6 @@ public long getTimeSpentResizingNs() { return timeSpentResizingNs; } - /** * Returns the average number of probes per key lookup. 
*/ @@ -569,7 +778,7 @@ public long getNumHashCollisions() { } @VisibleForTesting - int getNumDataPages() { + public int getNumDataPages() { return dataPages.size(); } @@ -578,6 +787,9 @@ int getNumDataPages() { */ @VisibleForTesting void growAndRehash() { + assert(bitset != null); + assert(longArray != null); + long resizeStartTime = -1; if (enablePerfMetrics) { resizeStartTime = System.nanoTime(); @@ -613,16 +825,8 @@ void growAndRehash() { } } - // Deallocate the old data structures. - memoryManager.free(oldLongArray.memoryBlock()); if (enablePerfMetrics) { timeSpentResizingNs += System.nanoTime() - resizeStartTime; } } - - /** Returns the next number greater or equal num that is power of 2. */ - private static long nextPowerOf2(long num) { - final long highBit = Long.highestOneBit(num); - return (highBit == num) ? num : highBit << 1; - } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java rename to core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java new file mode 100644 index 000000000000..45b78829e4cf --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.annotation.Private; + +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +@Private +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java new file mode 100644 index 000000000000..71b76d5ddfaa --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.primitives.UnsignedLongs; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.Utils; + +@Private +public class PrefixComparators { + private PrefixComparators() {} + + public static final StringPrefixComparator STRING = new StringPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator(); + public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); + + public static final class StringPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + + public static long computePrefix(UTF8String value) { + return value == null ? 0L : value.getPrefix(); + } + } + + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class BinaryPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + + public static long computePrefix(byte[] bytes) { + if (bytes == null) { + return 0L; + } else { + /** + * TODO: If a wrapper for BinaryType is created (SPARK-8786), + * these codes below will be in the wrapper class. + */ + final int minLen = Math.min(bytes.length, 8); + long p = 0; + for (int i = 0; i < minLen; ++i) { + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) + << (56 - 8 * i); + } + return p; + } + } + } + + public static final class BinaryPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class LongPrefixComparator extends PrefixComparator { + @Override + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + } + + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); + } + + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } + + public static final class DoublePrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); + } + + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java similarity index 58% rename from core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala rename to core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index a4568e849fa1..09e425879220 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -15,27 +15,23 @@ * limitations under the License. */ -package org.apache.spark.network.nio +package org.apache.spark.util.collection.unsafe.sort; -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -private[nio] -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - - val size: Int = if (buffer == null) 0 else buffer.remaining - - lazy val buffers: ArrayBuffer[ByteBuffer] = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { - override def toString: String = { - "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" - } + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java new file mode 100644 index 000000000000..0c4ebde407cf --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 000000000000..fc364e0a895b --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,528 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; + +import javax.annotation.Nullable; + +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * External sorter based on {@link UnsafeInMemorySorter}. 
+ */ +public final class UnsafeExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + private final long pageSizeBytes; + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private final TaskMemoryManager taskMemoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList<>(); + + private final LinkedList spillWriters = new LinkedList<>(); + + // These variables are reset after spilling: + @Nullable private UnsafeInMemorySorter inMemSorter; + // Whether the in-mem sorter is created internally, or passed in from outside. + // If it is passed in from outside, we shouldn't release the in-mem sorter's memory. + private boolean isInMemSorterExternal = false; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + private long peakMemoryUsedBytes = 0; + + public static UnsafeExternalSorter createWithExistingInMemorySorter( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + UnsafeInMemorySorter inMemorySorter) throws IOException { + return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); + } + + public static UnsafeExternalSorter create( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes) throws IOException { + return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); + } + + private UnsafeExternalSorter( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException { + this.taskMemoryManager = taskMemoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.fileBufferSizeBytes = 32 * 1024; 
+ this.pageSizeBytes = pageSizeBytes; + this.writeMetrics = new ShuffleWriteMetrics(); + + if (existingInMemorySorter == null) { + initializeForWriting(); + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. + acquireNewPage(); + } else { + this.isInMemSorterExternal = true; + this.inMemSorter = existingInMemorySorter; + } + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). + taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + cleanupResources(); + return null; + } + }); + } + + // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + final long pointerArrayMemory = + UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory); + if (memoryAcquired != pointerArrayMemory) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory"); + } + + this.inMemSorter = + new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); + this.isInMemSorterExternal = false; + } + + /** + * Marks the current page as no-more-space-available, and as a result, either allocate a + * new page or spill when we see the next record. + */ + @VisibleForTesting + public void closeCurrentPage() { + freeSpaceInCurrentPage = 0; + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + public void spill() throws IOException { + assert(inMemSorter != null); + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + // We only write out contents of the inMemSorter if it is not empty. + if (inMemSorter.numRecords() > 0) { + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + inMemSorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); + } + + final long spillSize = freeMemory(); + // Note that this is more-or-less going to be a multiple of the page size, so wasted space in + // pages will currently be counted as memory spilled even though that space isn't actually + // written to disk. 
This also counts the space needed to store the sorter's pointer array. + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeForWriting(); + } + + /** + * Return the total memory usage of this sorter, including the data pages and the sorter's pointer + * array. + */ + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); + } + + /** + * Free this sorter's in-memory data structures, including its data pages and pointer array. + * + * @return the number of bytes freed. + */ + private long freeMemory() { + updatePeakMemoryUsed(); + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + taskMemoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + if (inMemSorter != null) { + if (!isInMemSorterExternal) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + memoryFreed += sorterMemoryUsage; + shuffleMemoryManager.release(sorterMemoryUsage); + } + inMemSorter = null; + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Deletes any spill files created by this sorter. + */ + private void deleteSpillFiles() { + for (UnsafeSorterSpillWriter spill : spillWriters) { + File file = spill.getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + + /** + * Frees this sorter's in-memory data structures and cleans up its spill files. + */ + public void cleanupResources() { + deleteSpillFiles(); + freeMemory(); + } + + /** + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. + */ + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + inMemSorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. 
This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). + */ + private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { + assert (requiredSpace <= pageSizeBytes); + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. + if (requiredSpace > pageSizeBytes) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + pageSizeBytes + ")"); + } else { + acquireNewPage(); + } + } + } + + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * + * If there is not enough space to allocate the new page, spill all existing ones + * and try again. If there is still not enough space, report error to the caller. + */ + private void acquireNewPage() throws IOException { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); + } + } + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = pageSizeBytes; + allocatedPages.add(currentPage); + } + + /** + * Write a record to the sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws IOException { + + growPointerArrayIfNecessary(); + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). 
+ acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; + } + final Object dataPageBaseObject = dataPage.getBaseObject(); + + // --- Insert the record ---------------------------------------------------------------------- + + final long recordAddress = + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + dataPagePosition += 4; + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, prefix); + } + + /** + * Write a key-value record to the sorter. The key and value will be put together in-memory, + * using the following format: + * + * record length (4 bytes), key length (4 bytes), key data, value data + * + * record length = key length + value length + 4 + */ + public void insertKVRecord( + Object keyBaseObj, long keyOffset, int keyLen, + Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException { + + growPointerArrayIfNecessary(); + final int totalSpaceRequired = keyLen + valueLen + 4 + 4; + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). 
+ acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; + } + final Object dataPageBaseObject = dataPage.getBaseObject(); + + // --- Insert the record ---------------------------------------------------------------------- + + final long recordAddress = + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); + dataPagePosition += 4; + + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen); + dataPagePosition += 4; + + Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); + dataPagePosition += keyLen; + + Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); + + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, prefix); + } + + /** + * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + */ + public UnsafeSorterIterator getSortedIterator() throws IOException { + assert(inMemSorter != null); + final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator(); + int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); + if (spillWriters.isEmpty()) { + return inMemoryIterator; + } else { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + spillMerger.addSpillIfNotEmpty(inMemoryIterator); + + return spillMerger.getSortedIterator(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 000000000000..f7787e1019c2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. 
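// A standalone sketch of the prefix-first comparison that the surrounding class comment describes:
// each entry carries an 8-byte prefix of its key, prefixes are compared first, and the full record
// comparison only runs when the prefixes tie. PrefixedRecord and the byte[] records are illustrative
// stand-ins for the (pointer, prefix) pairs; bytes are treated as unsigned here, which differs
// slightly from the encoding used by the patch's BinaryPrefixComparator.
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

class PrefixSortSketch {
  static class PrefixedRecord {
    final long prefix;    // first 8 bytes of the record, packed big-endian
    final byte[] record;  // the full record, only consulted on prefix ties
    PrefixedRecord(byte[] record) {
      this.record = record;
      long p = 0L;
      for (int i = 0; i < Math.min(8, record.length); i++) {
        p |= (record[i] & 0xFFL) << (56 - 8 * i);
      }
      this.prefix = p;
    }
  }

  static int compareRecords(byte[] x, byte[] y) {
    int n = Math.min(x.length, y.length);
    for (int i = 0; i < n; i++) {
      int c = (x[i] & 0xFF) - (y[i] & 0xFF);
      if (c != 0) return c;
    }
    return x.length - y.length;
  }

  public static void main(String[] args) {
    Comparator<PrefixedRecord> cmp = (a, b) -> {
      int byPrefix = Long.compareUnsigned(a.prefix, b.prefix);
      if (byPrefix != 0) {
        return byPrefix;                           // fast path: no record access needed
      }
      return compareRecords(a.record, b.record);   // slow path: full record comparison
    };
    List<PrefixedRecord> entries = new ArrayList<>();
    entries.add(new PrefixedRecord("banana".getBytes(StandardCharsets.UTF_8)));
    entries.add(new PrefixedRecord("applesauce".getBytes(StandardCharsets.UTF_8)));
    entries.add(new PrefixedRecord("apple".getBytes(StandardCharsets.UTF_8)));
    entries.sort(cmp);
    for (PrefixedRecord r : entries) {
      System.out.println(new String(r.record, StandardCharsets.UTF_8));  // apple, applesauce, banana
    }
  }
}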
When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r1.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length + return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + } else { + return prefixComparisonResult; + } + } + } + + private final TaskMemoryManager memoryManager; + private final Sorter sorter; + private final Comparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] pointerArray; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize * 2]; + this.memoryManager = memoryManager; + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } + + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pointerArrayInsertPosition / 2; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + static long getMemoryRequirementsForPointerArray(long numEntries) { + return numEntries * 2L * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 2 < pointerArray.length; + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a 4-byte integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. 
+ * @param keyPrefix a user-defined key prefix + */ + public void insertRecord(long recordPointer, long keyPrefix) { + if (!hasSpaceForAnotherRecord()) { + expandPointerArray(); + } + pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArrayInsertPosition++; + pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArrayInsertPosition++; + } + + public static final class SortedIterator extends UnsafeSorterIterator { + + private final TaskMemoryManager memoryManager; + private final int sortBufferInsertPosition; + private final long[] sortBuffer; + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + + private SortedIterator( + TaskMemoryManager memoryManager, + int sortBufferInsertPosition, + long[] sortBuffer) { + this.memoryManager = memoryManager; + this.sortBufferInsertPosition = sortBufferInsertPosition; + this.sortBuffer = sortBuffer; + } + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length + recordLength = Platform.getInt(baseObject, baseOffset - 4); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public SortedIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 000000000000..d09c728a7a63 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * Supports sorting an array of (record pointer, key prefix) pairs. 
+ * Used in {@link UnsafeInMemorySorter}. + * <p>
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat extends SortDataFormat { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); + } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 000000000000..16ac2e8d821b --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java new file mode 100644 index 000000000000..3874a9f9cbdb --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.PriorityQueue; + +final class UnsafeSorterSpillMerger { + + private final PriorityQueue priorityQueue; + + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + final int numSpills) { + final Comparator comparator = new Comparator() { + + @Override + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + priorityQueue = new PriorityQueue(numSpills, comparator); + } + + /** + * Add an UnsafeSorterIterator to this merger + */ + public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + // We only add the spillReader to the priorityQueue if it is not empty. We do this to + // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator + // does not return wrong result because hasNext will returns true + // at least priorityQueue.size() times. If we allow n spillReaders in the + // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator. 
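The merge strategy used here, keeping one head element per non-empty run in a priority queue, emitting the smallest, then advancing and re-enqueueing that run, can be sketched with plain arrays. The class below is illustrative only (it is not part of this patch) and uses sorted int runs in place of spill readers:

```
import java.util.PriorityQueue;

// Illustrative k-way merge of already-sorted runs via a priority queue, mirroring the
// "advance, then re-enqueue" loop of the spill merger's sorted iterator.
public class KWayMergeSketch {
  public static void main(String[] args) {
    int[][] runs = { {1, 4, 9}, {2, 3, 10}, {}, {5, 6, 7} };
    // Order run heads by their current value; empty runs are never enqueued,
    // just as addSpillIfNotEmpty skips empty spill readers.
    PriorityQueue<int[]> heap = new PriorityQueue<>(
        (a, b) -> Integer.compare(runs[a[0]][a[1]], runs[b[0]][b[1]]));
    for (int r = 0; r < runs.length; r++) {
      if (runs[r].length > 0) {
        heap.add(new int[] {r, 0});                  // {run index, position within run}
      }
    }
    while (!heap.isEmpty()) {
      int[] head = heap.remove();
      System.out.print(runs[head[0]][head[1]] + " ");
      if (head[1] + 1 < runs[head[0]].length) {
        heap.add(new int[] {head[0], head[1] + 1});  // re-enqueue the advanced run
      }
    }
    System.out.println();                            // prints: 1 2 3 4 5 6 7 9 10
  }
}
```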
+ spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return spillReader.getKeyPrefix(); } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 000000000000..4989b05d63e2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import com.google.common.io.ByteStreams; + +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.Platform; + +/** + * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description + * of the file format). 
+ */ +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + + private final File file; + private InputStream in; + private DataInputStream din; + + // Variables that change with every record read: + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; + private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + this.file = file; + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public void loadNext() throws IOException { + recordLength = din.readInt(); + keyPrefix = din.readLong(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + ByteStreams.readFully(in, arr, 0, recordLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + file.delete(); + in = null; + din = null; + } + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 000000000000..e59a84ff8d11 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Tuple2; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.Platform; + +/** + * Spills a list of sorted records to disk. Spill files have the following format: + * + * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] 
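A rough, self-contained illustration of this framing, writing to an in-memory buffer with plain java.io streams rather than through DiskBlockObjectWriter (the class name and data below are invented for the example):

```
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

// Writes and reads back the [# of records][len][prefix][data]... framing.
public class SpillFramingSketch {
  public static void main(String[] args) throws IOException {
    byte[][] payloads = {
        "first record".getBytes(StandardCharsets.UTF_8),
        "second".getBytes(StandardCharsets.UTF_8)
    };
    long[] prefixes = {42L, 7L};

    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    out.writeInt(payloads.length);                   // [# of records (int)]
    for (int i = 0; i < payloads.length; i++) {
      out.writeInt(payloads[i].length);              // [len (int)]
      out.writeLong(prefixes[i]);                    // [prefix (long)]
      out.write(payloads[i]);                        // [data (bytes)]
    }
    out.close();

    DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
    int numRecords = in.readInt();
    for (int i = 0; i < numRecords; i++) {
      int len = in.readInt();
      long prefix = in.readLong();
      byte[] data = new byte[len];
      in.readFully(data);
      System.out.println(prefix + " -> " + new String(data, StandardCharsets.UTF_8));
    }
  }
}
```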
+ */ +final class UnsafeSorterSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private DiskBlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeLong. + private void writeLongToBuffer(long v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 56); + writeBuffer[offset + 1] = (byte)(v >>> 48); + writeBuffer[offset + 2] = (byte)(v >>> 40); + writeBuffer[offset + 3] = (byte)(v >>> 32); + writeBuffer[offset + 4] = (byte)(v >>> 24); + writeBuffer[offset + 5] = (byte)(v >>> 16); + writeBuffer[offset + 6] = (byte)(v >>> 8); + writeBuffer[offset + 7] = (byte)(v >>> 0); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. 
+ * @param keyPrefix a sort key prefix + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + writeIntToBuffer(recordLength, 0); + writeLongToBuffer(keyPrefix, 4); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + Platform.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + writer.recordWritten(); + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public File getFile() { + return file; + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index b146f8a78412..689afea64f8d 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 3a2a88219818..27006e45e932 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 0b450dc76bc3..3c8ddddf07b1 100644 --- 
a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -19,6 +19,9 @@ * to be registered after the page loads. */ $(function() { $("span.expand-additional-metrics").click(function(){ + var status = window.localStorage.getItem("expand-additional-metrics") == "true"; + status = !status; + // Expand the list of additional metrics. var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); $(additionalMetricsDiv).toggleClass('collapsed'); @@ -26,17 +29,31 @@ $(function() { // Switch the class of the arrow from open to closed. $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-additional-metrics", "" + status); }); + if (window.localStorage.getItem("expand-additional-metrics") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-additional-metrics", "false"); + $("span.expand-additional-metrics").trigger("click"); + } + stripeSummaryTable(); $('input[type="checkbox"]').click(function() { - var column = "table ." + $(this).attr("name"); + var name = $(this).attr("name") + var column = "table ." + name; + var status = window.localStorage.getItem(name) == "true"; + status = !status; $(column).toggle(); stripeSummaryTable(); + window.localStorage.setItem(name, "" + status); }); $("#select-all-metrics").click(function() { + var status = window.localStorage.getItem("select-all-metrics") == "true"; + status = !status; if (this.checked) { // Toggle all un-checked options. $('input[type="checkbox"]:not(:checked)').trigger('click'); @@ -44,6 +61,21 @@ $(function() { // Toggle all checked options. $('input[type="checkbox"]:checked').trigger('click'); } + window.localStorage.setItem("select-all-metrics", "" + status); + }); + + if (window.localStorage.getItem("select-all-metrics") == "true") { + $("#select-all-metrics").attr('checked', status); + } + + $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() { + var name = $(this).attr("name") + // If name is undefined, then skip it because it's the "select-all-metrics" checkbox + if (name && window.localStorage.getItem(name) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(name, "false"); + $(this).trigger("click") + } }); // Trigger a click on the checkbox if a user clicks the label next to it. diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 9fa53baaf421..83dbea40b63f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -72,6 +72,14 @@ var StagePageVizConstants = { rankSep: 40 }; +/* + * Return "expand-dag-viz-arrow-job" if forJob is true. + * Otherwise, return "expand-dag-viz-arrow-stage". + */ +function expandDagVizArrowKey(forJob) { + return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage"; +} + /* * Show or hide the RDD DAG visualization. * @@ -79,6 +87,9 @@ var StagePageVizConstants = { * This is the narrow interface called from the Scala UI code. 
*/ function toggleDagViz(forJob) { + var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true"; + status = !status; + var arrowSelector = ".expand-dag-viz-arrow"; $(arrowSelector).toggleClass('arrow-closed'); $(arrowSelector).toggleClass('arrow-open'); @@ -93,8 +104,24 @@ function toggleDagViz(forJob) { // Save the graph for later so we don't have to render it again graphContainer().style("display", "none"); } + + window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status); } +$(function (){ + if ($("#stage-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(false), "false"); + toggleDagViz(false); + } else if ($("#job-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(true), "false"); + toggleDagViz(true); + } +}); + /* * Render the RDD DAG visualization. * diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index ca74ef9d7e94..f4453c71df1e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -66,14 +66,27 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { setupJobEventAction(); $("span.expand-application-timeline").click(function() { + var status = window.localStorage.getItem("expand-application-timeline") == "true"; + status = !status; + $("#application-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. $(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-application-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-application-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-application-timeline", "false"); + $("span.expand-application-timeline").trigger('click'); + } +}); + function drawJobTimeline(groupArray, eventObjArray, startTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -125,14 +138,27 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { setupStageEventAction(); $("span.expand-job-timeline").click(function() { + var status = window.localStorage.getItem("expand-job-timeline") == "true"; + status = !status; + $("#job-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. 
$(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-job-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-job-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-job-timeline", "false"); + $("span.expand-job-timeline").trigger('click'); + } +}); + function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -176,14 +202,27 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); $("span.expand-task-assignment-timeline").click(function() { + var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true"; + status = !status; + $("#task-assignment-timeline").toggleClass("collapsed"); // Switch the class of the arrow from open to closed. $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); + + window.localStorage.setItem("expand-task-assignment-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-task-assignment-timeline", "false"); + $("span.expand-task-assignment-timeline").trigger('click'); + } +}); + function setupExecutorEventAction() { $(".item.box.executor").each(function () { $(this).hover( diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b1cef4704224..04f3070d25b4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -207,7 +207,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. 
*/ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time { +.serialization_time, .getting_result_time, .peak_execution_memory { display: none; } @@ -224,3 +224,11 @@ span.additional-metric-title { a.expandbutton { cursor: pointer; } + +.executor-thread { + background: #E6E6E6; +} + +.non-executor-thread { + background: #FAFAFA; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5a8d17bd9993..5592b75afb75 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -20,7 +20,8 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import scala.ref.WeakReference import scala.reflect.ClassTag @@ -39,25 +40,44 @@ import org.apache.spark.util.Utils * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` * @param name human-readable name for use in Spark's web UI + * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported + * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be + * thread safe so that they can be reported correctly. * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ -class Accumulable[R, T] ( - @transient initialValue: R, +class Accumulable[R, T] private[spark] ( + initialValue: R, param: AccumulableParam[R, T], - val name: Option[String]) + val name: Option[String], + internal: Boolean) extends Serializable { + private[spark] def this( + @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = { + this(initialValue, param, None, internal) + } + + def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = + this(initialValue, param, name, false) + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) val id: Long = Accumulators.newId - @transient private var value_ = initialValue // Current value on master + @volatile @transient private var value_ : R = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false - Accumulators.register(this, true) + Accumulators.register(this) + + /** + * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver + * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be + * reported correctly. + */ + private[spark] def isInternal: Boolean = internal /** * Add more data to this accumulator / accumulable @@ -132,7 +152,15 @@ class Accumulable[R, T] ( in.defaultReadObject() value_ = zero deserialized = true - Accumulators.register(this, false) + // Automatically register the accumulator when it is deserialized with the task closure. + // + // Note internal accumulators sent with task are deserialized before the TaskContext is created + // and are registered in the TaskContext constructor. Other internal accumulators, such SQL + // metrics, still need to register here. 
+ val taskContext = TaskContext.get() + if (taskContext != null) { + taskContext.registerAccumulator(this) + } } override def toString: String = if (value_ == null) "null" else value_.toString @@ -227,10 +255,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T, T](initialValue, param, name) { +class Accumulator[T] private[spark] ( + @transient private[spark] val initialValue: T, + param: AccumulatorParam[T], + name: Option[String], + internal: Boolean) + extends Accumulable[T, T](initialValue, param, name, internal) { + + def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { + this(initialValue, param, name, false) + } - def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) + def this(initialValue: T, param: AccumulatorParam[T]) = { + this(initialValue, param, None, false) + } } /** @@ -284,16 +322,7 @@ private[spark] object Accumulators extends Logging { * It keeps weak references to these objects so that accumulators can be garbage-collected * once the RDDs and user-code that reference them are cleaned up. */ - val originals = Map[Long, WeakReference[Accumulable[_, _]]]() - - /** - * This thread-local map holds per-task copies of accumulators; it is used to collect the set - * of accumulator updates to send back to the driver when tasks complete. After tasks complete, - * this map is cleared by `Accumulators.clear()` (see Executor.scala). - */ - private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() - } + val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() private var lastId: Long = 0 @@ -302,19 +331,8 @@ private[spark] object Accumulators extends Logging { lastId } - def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { - if (original) { - originals(a.id) = new WeakReference[Accumulable[_, _]](a) - } else { - localAccums.get()(a.id) = a - } - } - - // Clear the local (non-original) accumulators for the current thread - def clear() { - synchronized { - localAccums.get.clear() - } + def register(a: Accumulable[_, _]): Unit = synchronized { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) } def remove(accId: Long) { @@ -323,15 +341,6 @@ private[spark] object Accumulators extends Logging { } } - // Get the values of the local accumulators for the current thread (by ID) - def values: Map[Long, Any] = synchronized { - val ret = Map[Long, Any]() - for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue - } - return ret - } - // Add values to the original accumulators with some given IDs def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { @@ -349,7 +358,42 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) +} + +private[spark] object InternalAccumulator { + val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" + val TEST_ACCUMULATOR = "testAccumulator" + + // For testing only. + // This needs to be a def since we don't want to reuse the same accumulator across stages. 
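The thread-safety requirement called out for internal accumulators (the task thread adds to the value while the heartbeat path reads it) can be demonstrated with a small Spark-free toy. The class below is invented for illustration, with AtomicLong standing in for a thread-safe accumulator value:

```
import java.util.concurrent.atomic.AtomicLong;

// A metric updated by a "task" thread while a "heartbeat" thread reads it concurrently.
public class InternalMetricSketch {
  private static final AtomicLong peakExecutionMemory = new AtomicLong(0L);

  public static void main(String[] args) throws InterruptedException {
    Thread task = new Thread(() -> {
      for (int i = 0; i < 1000; i++) {
        peakExecutionMemory.addAndGet(1024);   // e.g. a data structure grew by 1 KB
      }
    });
    Thread heartbeat = new Thread(() -> {
      for (int i = 0; i < 5; i++) {
        // A plain, non-volatile long could be read stale here; AtomicLong cannot.
        System.out.println("heartbeat sees " + peakExecutionMemory.get() + " bytes");
        try { Thread.sleep(1); } catch (InterruptedException e) { return; }
      }
    });
    task.start();
    heartbeat.start();
    task.join();
    heartbeat.join();
    System.out.println("final value: " + peakExecutionMemory.get() + " bytes");
  }
}
```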
+ private def maybeTestAccumulator: Option[Accumulator[Long]] = { + if (sys.props.contains("spark.testing")) { + Some(new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) + } else { + None + } + } - def stringifyValue(value: Any): String = "%s".format(value) + /** + * Accumulators for tracking internal metrics. + * + * These accumulators are created with the stage such that all tasks in the stage will + * add to the same set of accumulators. We do this to report the distribution of accumulator + * values across all tasks within each stage. + */ + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators + } } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ceeb58075d34..289aab9bd9e5 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -58,12 +58,7 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } @@ -89,13 +84,18 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non-optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } + + /** Update task metrics after populating the external map. 
*/ + private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = { + Option(context).foreach { c => + c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + c.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + } + } } diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 37198d887b07..d23c1533db75 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.{RDDCheckpointData, RDD} +import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.util.Utils /** @@ -231,11 +231,14 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - /** Perform checkpoint cleanup. */ + /** + * Clean up checkpoint files written to a reliable storage. + * Locally checkpointed files are cleaned up separately through RDD cleanups. + */ def doCleanCheckpoint(rddId: Int): Unit = { try { logDebug("Cleaning rdd checkpoint data " + rddId) - RDDCheckpointData.clearRDDCheckpointData(sc, rddId) + ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId) listeners.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index fc8cdde9348e..9aafc9eb1cde 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.reflect.ClassTag + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -65,8 +67,8 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) */ @DeveloperApi -class ShuffleDependency[K, V, C]( - @transient _rdd: RDD[_ <: Product2[K, V]], +class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( + @transient private val _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, @@ -76,6 +78,13 @@ class ShuffleDependency[K, V, C]( override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] + private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName + private[spark] val valueClassName: String = reflect.classTag[V].runtimeClass.getName + // Note: It's possible that the combiner class tag is null, if the combineByKey + // methods in PairRDDFunctions are used instead of combineByKeyWithClassTag. 
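For readers unfamiliar with the combiner terminology used in Aggregator and combineByKey, here is a deliberately tiny, Spark-free sketch: new keys get a fresh combiner, existing keys merge in the new value, and a crude peak-size figure is kept as a stand-in for the peakMemoryUsedBytes that feeds the peak-execution-memory accumulator above. All names are invented:

```
import java.util.HashMap;
import java.util.Map;

// Combine-by-key over (word, 1) pairs with a toy "peak size" metric.
public class CombineByKeySketch {
  public static void main(String[] args) {
    String[] keys = {"a", "b", "a", "c", "b", "a"};
    Map<String, Integer> combiners = new HashMap<>();
    long peakEntries = 0;                      // crude proxy for peak memory used
    for (String key : keys) {
      // createCombiner for a new key, mergeValue for an existing one.
      combiners.merge(key, 1, Integer::sum);
      peakEntries = Math.max(peakEntries, combiners.size());
    }
    System.out.println(combiners);             // {a=3, b=2, c=1}
    System.out.println("peak entries: " + peakEntries);
  }
}
```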
+ private[spark] val combinerClassName: Option[String] = + Option(reflect.classTag[C]).map(_.runtimeClass.getName) + val shuffleId: Int = _rdd.context.newShuffleId() val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 443830f8d03b..842bfdbadc94 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -24,11 +24,23 @@ package org.apache.spark private[spark] trait ExecutorAllocationClient { /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to like to run on that host. + * This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ - private[spark] def requestTotalExecutors(numExecutors: Int): Boolean + private[spark] def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean /** * Request an additional number of executors from the cluster manager. diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 49329423dca7..b93536e6536e 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.TimeUnit import scala.collection.mutable +import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} @@ -102,7 +103,7 @@ private[spark] class ExecutorAllocationManager( "spark.dynamicAllocation.executorIdleTimeout", "60s") private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s") + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -160,6 +161,12 @@ private[spark] class ExecutorAllocationManager( // (2) an executor idle timeout has elapsed. @volatile private var initializing: Boolean = true + // Number of locality aware tasks, used for executor placement. + private var localityAwareTasks = 0 + + // Host to possible task running on it, used for executor placement. + private var hostToLocalTaskCount: Map[String, Int] = Map.empty + /** * Verify that the settings specified through the config are valid. 
* If not, throw an appropriate exception. @@ -211,7 +218,16 @@ private[spark] class ExecutorAllocationManager( listenerBus.addListener(listener) val scheduleTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + override def run(): Unit = { + try { + schedule() + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } @@ -285,7 +301,7 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { - client.requestTotalExecutors(numExecutorsTarget) + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") } @@ -339,7 +355,8 @@ private[spark] class ExecutorAllocationManager( return 0 } - val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) + val addRequestAcknowledged = testing || + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) if (addRequestAcknowledged) { val executorsString = "executor" + { if (delta > 1) "s" else "" } logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + @@ -509,6 +526,12 @@ private[spark] class ExecutorAllocationManager( // Number of tasks currently running on the cluster. Should be 0 when no stages are active. private var numRunningTasks: Int = _ + // stageId to tuple (the number of task with locality preferences, a map where each pair is a + // node and the number of tasks that would like to be scheduled on that node) map, + // maintain the executor placement hints for each stage Id used by resource framework to better + // place the executors. 
+ private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])] + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { initializing = false val stageId = stageSubmitted.stageInfo.stageId @@ -516,6 +539,24 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks allocationManager.onSchedulerBacklogged() + + // Compute the number of tasks requested by the stage on each host + var numTasksPending = 0 + val hostToLocalTaskCountPerStage = new mutable.HashMap[String, Int]() + stageSubmitted.stageInfo.taskLocalityPreferences.foreach { locality => + if (!locality.isEmpty) { + numTasksPending += 1 + locality.foreach { location => + val count = hostToLocalTaskCountPerStage.getOrElse(location.host, 0) + 1 + hostToLocalTaskCountPerStage(location.host) = count + } + } + } + stageIdToExecutorPlacementHints.put(stageId, + (numTasksPending, hostToLocalTaskCountPerStage.toMap)) + + // Update the executor placement hints + updateExecutorPlacementHints() } } @@ -524,6 +565,10 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks -= stageId stageIdToTaskIndices -= stageId + stageIdToExecutorPlacementHints -= stageId + + // Update the executor placement hints + updateExecutorPlacementHints() // If this is the last stage with pending tasks, mark the scheduler queue as empty // This is needed in case the stage is aborted for any reason @@ -554,14 +599,8 @@ private[spark] class ExecutorAllocationManager( // If this is the last pending task, mark the scheduler queue as empty stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() - } + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerQueueEmpty() } // Mark the executor on which this task is scheduled as busy @@ -573,6 +612,8 @@ private[spark] class ExecutorAllocationManager( override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId + val taskIndex = taskEnd.taskInfo.index + val stageId = taskEnd.stageId allocationManager.synchronized { numRunningTasks -= 1 // If the executor is no longer running any scheduled tasks, mark it as idle @@ -583,6 +624,16 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorIdle(executorId) } } + + // If the task failed, we expect it to be resubmitted later. To ensure we have + // enough resources to run the resubmitted task, we need to mark the scheduler + // as backlogged again if it's not already marked as such (SPARK-8366) + if (taskEnd.reason != Success) { + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerBacklogged() + } + stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + } } } @@ -627,6 +678,29 @@ private[spark] class ExecutorAllocationManager( def isExecutorIdle(executorId: String): Boolean = { !executorIdToTaskIds.contains(executorId) } + + /** + * Update the Executor placement hints (the number of tasks with locality preferences, + * a map where each pair is a node and the number of tasks that would like to be scheduled + * on that node). 
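The per-stage hints described above, a count of locality-aware pending tasks plus a host-to-task-count map, are folded into the two totals that accompany requestTotalExecutors. A minimal sketch of that folding, with invented names and plain Java collections in place of the Scala ones:

```
import java.util.HashMap;
import java.util.Map;

// Aggregates per-stage (pendingTasks, host -> taskCount) hints into global totals.
public class PlacementHintsSketch {
  static final class StageHints {
    final int pendingTasks;
    final Map<String, Integer> hostToTaskCount;
    StageHints(int pendingTasks, Map<String, Integer> hostToTaskCount) {
      this.pendingTasks = pendingTasks;
      this.hostToTaskCount = hostToTaskCount;
    }
  }

  public static void main(String[] args) {
    Map<Integer, StageHints> stageIdToHints = new HashMap<>();
    Map<String, Integer> stage1 = new HashMap<>();
    stage1.put("host1", 2);
    stage1.put("host2", 1);
    Map<String, Integer> stage2 = new HashMap<>();
    stage2.put("host2", 2);
    stageIdToHints.put(1, new StageHints(3, stage1));
    stageIdToHints.put(2, new StageHints(2, stage2));

    int localityAwareTasks = 0;
    Map<String, Integer> hostToLocalTaskCount = new HashMap<>();
    for (StageHints hints : stageIdToHints.values()) {
      localityAwareTasks += hints.pendingTasks;
      hints.hostToTaskCount.forEach((host, count) ->
          hostToLocalTaskCount.merge(host, count, Integer::sum));
    }
    System.out.println(localityAwareTasks);      // 5
    System.out.println(hostToLocalTaskCount);    // {host1=2, host2=3}
  }
}
```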
+ * + * These hints are updated when stages arrive and complete, so are not up-to-date at task + * granularity within stages. + */ + def updateExecutorPlacementHints(): Unit = { + var localityAwareTasks = 0 + val localityToCount = new mutable.HashMap[String, Int]() + stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) => + localityAwareTasks += numTasksPending + localities.foreach { case (hostname, count) => + val updatedCount = localityToCount.getOrElse(hostname, 0) + count + localityToCount(hostname) = updatedCount + } + } + + allocationManager.localityAwareTasks = localityAwareTasks + allocationManager.hostToLocalTaskCount = localityToCount.toMap + } } /** diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 6909015ff66e..ee60d697d879 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet private[spark] case object ExpireDeadHosts +private case class ExecutorRegistered(executorId: String) + +private case class ExecutorRemoved(executorId: String) + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. 
*/ -private[spark] class HeartbeatReceiver(sc: SparkContext) - extends ThreadSafeRpcEndpoint with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) + extends ThreadSafeRpcEndpoint with SparkListener with Logging { + + def this(sc: SparkContext) { + this(sc, new SystemClock) + } + + sc.addSparkListener(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv @@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def onStart(): Unit = { timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ExpireDeadHosts)) + Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) } }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - override def receive: PartialFunction[Any, Unit] = { - case ExpireDeadHosts => - expireDeadHosts() + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + + // Messages sent and received locally + case ExecutorRegistered(executorId) => + executorLastSeen(executorId) = clock.getTimeMillis() + context.reply(true) + case ExecutorRemoved(executorId) => + executorLastSeen.remove(executorId) + context.reply(true) case TaskSchedulerIsSet => scheduler = sc.taskScheduler - } + context.reply(true) + case ExpireDeadHosts => + expireDeadHosts() + context.reply(true) - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // Messages received from executors case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - executorLastSeen(executorId) = System.currentTimeMillis() - eventLoopThread.submit(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - context.reply(response) - } - }) + if (executorLastSeen.contains(executorId)) { + executorLastSeen(executorId) = clock.getTimeMillis() + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) + } else { + // This may happen if we get an executor's in-flight heartbeat immediately + // after we just removed it. It's not really an error condition so we should + // not log warning here. Otherwise there may be a lot of noise especially if + // we explicitly remove executors (SPARK-4134). + logDebug(s"Received heartbeat from unknown executor $executorId") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -119,23 +147,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } + /** + * If the heartbeat receiver is not stopped, notify it of executor registrations. + */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + } + + /** + * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't + * log superfluous errors. 
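The bookkeeping this listener maintains, recording a last-seen timestamp on registration and on each heartbeat, dropping it on removal, and expiring entries whose gap exceeds the timeout, can be sketched independently of Spark. The names below are invented, and a LongSupplier stands in for the injectable Clock used above so that tests can advance time artificially:

```
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.function.LongSupplier;

// Heartbeat bookkeeping with an injectable clock so expiry is testable without waiting.
public class HeartbeatMonitorSketch {
  private final Map<String, Long> lastSeen = new HashMap<>();
  private final LongSupplier clock;
  private final long timeoutMs;

  HeartbeatMonitorSketch(LongSupplier clock, long timeoutMs) {
    this.clock = clock;
    this.timeoutMs = timeoutMs;
  }

  void register(String executorId) { lastSeen.put(executorId, clock.getAsLong()); }

  void remove(String executorId) { lastSeen.remove(executorId); }

  // An in-flight heartbeat from an executor that was never registered (or was already
  // removed) is not an error: report it so the caller can ask for re-registration.
  boolean heartbeat(String executorId) {
    if (!lastSeen.containsKey(executorId)) {
      return false;
    }
    lastSeen.put(executorId, clock.getAsLong());
    return true;
  }

  void expireDeadExecutors() {
    long now = clock.getAsLong();
    Iterator<Map.Entry<String, Long>> it = lastSeen.entrySet().iterator();
    while (it.hasNext()) {
      Map.Entry<String, Long> entry = it.next();
      if (now - entry.getValue() > timeoutMs) {
        System.out.println("expiring " + entry.getKey());
        it.remove();
      }
    }
  }

  public static void main(String[] args) {
    long[] fakeNow = {0L};
    HeartbeatMonitorSketch monitor = new HeartbeatMonitorSketch(() -> fakeNow[0], 100L);
    monitor.register("exec-1");
    monitor.register("exec-2");
    fakeNow[0] = 50L;
    monitor.heartbeat("exec-1");
    fakeNow[0] = 120L;
    monitor.expireDeadExecutors();   // expires exec-2 but not exec-1
  }
}
```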
+ * + * Note that we must do this after the executor is actually removed to guard against the + * following race condition: if we remove an executor's metadata from our data structure + * prematurely, we may get an in-flight heartbeat from the executor before the executor is + * actually removed, in which case we will still mark the executor as a dead host later + * and expire it with loud error messages. + */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + } + private def expireDeadHosts(): Unit = { logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") - val now = System.currentTimeMillis() + val now = clock.getTimeMillis() for ((executorId, lastSeenMs) <- executorLastSeen) { if (now - lastSeenMs > executorTimeoutMs) { logWarning(s"Removing executor $executorId with no recent heartbeats: " + s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms") scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + s"timed out after ${now - lastSeenMs} ms")) - if (sc.supportDynamicAllocation) { // Asynchronously kill the executor to avoid blocking the current thread - killExecutorThread.submit(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - sc.killExecutor(executorId) - } - }) - } + killExecutorThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + // Note: we want to get an executor back after expiring this one, + // so do not simply call `sc.killExecutor` here (SPARK-8119) + sc.killAndReplaceExecutor(executorId) + } + }) executorLastSeen.remove(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 7fcb7830e7b0..f0598816d6c0 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,6 +121,7 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { + // scalastyle:off println if (Utils.isInInterpreter) { val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { @@ -141,6 +142,7 @@ trait Logging { System.err.println(s"Spark was unable to load $defaultLogProps") } } + // scalastyle:on println } } Logging.initialized = true @@ -157,7 +159,7 @@ private object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. - val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala new file mode 100644 index 000000000000..f8a6f1d0d8cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Holds statistics about the output sizes in a map stage. May become a DeveloperApi in the future. + * + * @param shuffleId ID of the shuffle + * @param bytesByPartitionId approximate number of output bytes for each map output partition + * (may be inexact due to use of compressed map statuses) + */ +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 862ffe868f58..94eb8daa85c5 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,17 +18,18 @@ package org.apache.spark import java.io._ +import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashMap, HashSet, Map} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage @@ -124,13 +125,51 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * Called from executors to get the server URIs and output sizes of the map outputs of - * a given shuffle. + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given reduce task. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") + val statuses = getStatuses(shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } + } + + /** + * Return statistics about all of the outputs for a given shuffle. 
+ */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + val statuses = getStatuses(dep.shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) */ - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + private def getStatuses(shuffleId: Int): Array[MapStatus] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { // Someone else is fetching it; wait for them to be done @@ -152,7 +191,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so + // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { @@ -167,19 +206,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${System.currentTimeMillis - startTime} ms") + if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } + return fetchedStatuses } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } + return statuses } } @@ -387,7 +425,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { protected val mapStatuses: Map[Int, Array[MapStatus]] = - new ConcurrentHashMap[Int, Array[MapStatus]] + new ConcurrentHashMap[Int, Array[MapStatus]]().asScala } private[spark] object MapOutputTracker extends Logging { @@ -421,23 +459,38 @@ private[spark] object MapOutputTracker extends Logging { } } - // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If - // any of the statuses is null (indicating a missing location due to a failed mapper), - // throw a FetchFailedException. + /** + * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block + * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that + * block manager. + * + * If any of the statuses is null (indicating a missing location due to a failed mapper), + * throws a FetchFailedException. 
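The new return shape of getMapSizesByExecutorId / convertMapStatuses carries one entry per block manager, each listing the (shuffle block, size) pairs stored there. A simplified, self-contained sketch of that grouping step; Location and BlockAndSize are placeholder types standing in for BlockManagerId and ShuffleBlockId sizes, not Spark classes:

import scala.collection.mutable.{ArrayBuffer, HashMap}

object GroupByLocationExample extends App {
  // Placeholder stand-ins for BlockManagerId and per-block size information.
  case class Location(host: String, port: Int)
  case class BlockAndSize(mapId: Int, sizeBytes: Long)

  // statuses(i) = (location of map output i, size of the block this reducer needs from it)
  def groupByLocation(statuses: Seq[(Location, Long)]): Seq[(Location, Seq[BlockAndSize])] = {
    val byAddress = new HashMap[Location, ArrayBuffer[BlockAndSize]]
    for (((loc, size), mapId) <- statuses.zipWithIndex) {
      // One bucket per block manager; each entry remembers which map produced it.
      byAddress.getOrElseUpdate(loc, ArrayBuffer()) += BlockAndSize(mapId, size)
    }
    byAddress.toSeq
  }

  val a = Location("host-a", 7337)
  val b = Location("host-b", 7337)
  // The two host-a blocks end up grouped under a single entry.
  println(groupByLocation(Seq(a -> 10L, b -> 20L, a -> 30L)))
}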
+ * + * @param shuffleId Identifier for the shuffle + * @param reduceId Identifier for the reduce task + * @param statuses List of map statuses, indexed by map ID. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ private def convertMapStatuses( shuffleId: Int, reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - statuses.map { - status => - if (status == null) { - logError("Missing an output location for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) - } else { - (status.location, status.getSizeForBlock(reduceId)) - } + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.zipWithIndex) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage) + } else { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId))) + } } + + splitsByAddress.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 82889bcd3098..e4df7af81a6d 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -56,7 +56,7 @@ object Partitioner { */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse - for (r <- bySize if r.partitioner.isDefined) { + for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) { return r.partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { @@ -76,6 +76,8 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { + require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { @@ -102,8 +104,8 @@ class HashPartitioner(partitions: Int) extends Partitioner { * the value of `partitions`. */ class RangePartitioner[K : Ordering : ClassTag, V]( - @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K, V]], + partitions: Int, + rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -289,7 +291,7 @@ private[spark] object RangePartitioner { while ((i < numCandidates) && (j < partitions - 1)) { val (key, weight) = ordered(i) cumWeight += weight - if (cumWeight > target) { + if (cumWeight >= target) { // Skip duplicate values. 
if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { bounds += key diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af..3b9c885bf97a 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -18,6 +18,10 @@ package org.apache.spark import java.io.File +import java.security.NoSuchAlgorithmException +import javax.net.ssl.SSLContext + +import scala.collection.JavaConverters._ import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +42,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +52,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +68,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -76,7 +81,6 @@ private[spark] case class SSLOptions( * object. It can be used then to compose the ultimate Akka configuration. */ def createAkkaConfig: Option[Config] = { - import scala.collection.JavaConversions._ if (enabled) { Some(ConfigFactory.empty() .withValue("akka.remote.netty.tcp.security.key-store", @@ -94,7 +98,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.asJava)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +106,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. 
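The supportedAlgorithms computation completed just below intersects the configured cipher suites with what the running JVM's security provider actually offers. A minimal illustration of that filtering using only the standard javax.net.ssl API; the cipher names are placeholders chosen for the example:

import javax.net.ssl.SSLContext

object CipherFilterExample extends App {
  // "SOME_UNSUPPORTED_CIPHER" is deliberately bogus to show the filtering effect.
  val requested = Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SOME_UNSUPPORTED_CIPHER")
  val provider = SSLContext.getDefault.getServerSocketFactory.getSupportedCipherSuites.toSet
  println(s"usable:  ${requested & provider}")    // ciphers the provider knows
  println(s"dropped: ${requested &~ provider}")   // candidates for a debug log message
}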
*/ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 673ef49e7c1c..746d2081d439 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -310,7 +310,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) setViewAcls(Set[String](defaultUser), allowedUsers) } - def getViewAcls: String = viewAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getViewAcls: String = { + if (viewAcls.contains("*")) { + "*" + } else { + viewAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. If you modify the admin @@ -321,7 +330,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) } - def getModifyAcls: String = modifyAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getModifyAcls: String = { + if (modifyAcls.contains("*")) { + "*" + } else { + modifyAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. 
If you modify the admin @@ -394,7 +412,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - !aclsEnabled || user == null || viewAcls.contains(user) + !aclsEnabled || user == null || viewAcls.contains(user) || viewAcls.contains("*") } /** @@ -409,7 +427,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",")) - !aclsEnabled || user == null || modifyAcls.contains(user) + !aclsEnabled || user == null || modifyAcls.contains(user) || modifyAcls.contains("*") } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 6cf36fbbd625..b344b5e173d6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,11 +18,12 @@ package org.apache.spark import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet +import org.apache.avro.{SchemaNormalization, Schema} + import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils @@ -161,6 +162,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + private final val avroNamespace = "avro.schema." + + /** + * Use Kryo serialization and register the given set of Avro schemas so that the generic + * record serializer can decrease network IO + */ + def registerAvroSchemas(schemas: Schema*): SparkConf = { + for (schema <- schemas) { + set(avroNamespace + SchemaNormalization.parsingFingerprint64(schema), schema.toString) + } + this + } + + /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */ + def getAvroSchema: Map[Long, String] = { + getAll.filter { case (k, v) => k.startsWith(avroNamespace) } + .map { case (k, v) => (k.substring(avroNamespace.length).toLong, v) } + .toMap + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) @@ -228,6 +249,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.byteStringAsBytes(get(key, defaultValue)) } + /** + * Get a size parameter as bytes, falling back to a default if not set. + */ + def getSizeAsBytes(key: String, defaultValue: Long): Long = { + Utils.byteStringAsBytes(get(key, defaultValue + "B")) + } + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. @@ -361,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" val driverLibraryPathKey = "spark.driver.extraLibraryPath" + val sparkExecutorInstances = "spark.executor.instances" // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -448,6 +477,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + if (!contains(sparkExecutorInstances)) { + sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => + val warning = + s""" + |SPARK_WORKER_INSTANCES was detected (set to '$value'). 
+ |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with --num-executors to specify the number of executors + | - Or set SPARK_EXECUTOR_INSTANCES + | - spark.executor.instances to configure the number of instances in the spark config. + """.stripMargin + logWarning(warning) + + set("spark.executor.instances", value) + } + } } /** @@ -527,7 +574,9 @@ private[spark] object SparkConf extends Logging { "spark.rpc.askTimeout" -> Seq( AlternateConfig("spark.akka.askTimeout", "1.4")), "spark.rpc.lookupTimeout" -> Seq( - AlternateConfig("spark.akka.lookupTimeout", "1.4")) + AlternateConfig("spark.akka.lookupTimeout", "1.4")), + "spark.streaming.fileStream.minRememberDuration" -> Seq( + AlternateConfig("spark.streaming.minRememberDuration", "1.5")) ) /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 141276ac901f..a2f34eafa2c3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -26,13 +26,14 @@ import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} -import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -114,13 +115,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * :: DeveloperApi :: * Alternative constructor for setting preferred locations where Spark will create executors. * + * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] * from a list of input files or InputFormats for the application. */ + @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") @DeveloperApi def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { this(config) + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") this.preferredNodeLocationData = preferredNodeLocationData } @@ -143,6 +147,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. + * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. + * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] + * from a list of input files or InputFormats for the application. 
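The registerAvroSchemas / getAvroSchema pair added to SparkConf above can be exercised as a round trip. An illustrative sketch assuming the methods from this patch; the schema JSON is a made-up example:

import org.apache.avro.Schema
import org.apache.spark.SparkConf

object AvroSchemaRegistrationExample extends App {
  // Any valid Avro record schema works the same way.
  val schemaJson =
    """{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}"""
  val schema = new Schema.Parser().parse(schemaJson)

  // Register once on the driver; the generic-record serializer can then ship a
  // 64-bit schema fingerprint per record instead of the full schema text.
  val conf = new SparkConf().registerAvroSchemas(schema)

  // fingerprint -> schema JSON, as read back by the serializer
  val registered: Map[Long, String] = conf.getAvroSchema
  println(registered.keys)
}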
*/ def this( master: String, @@ -153,6 +160,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) + if (preferredNodeLocationData.nonEmpty) { + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") + } this.preferredNodeLocationData = preferredNodeLocationData } @@ -315,6 +325,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId @@ -330,8 +348,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // the those of the children threads, which has confusing semantics (SPARK-10563). + SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } @@ -463,7 +485,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli .orElse(Option(System.getenv("SPARK_MEM")) .map(warnSparkMem)) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(1024) // Convert java options to env vars as a work around // since we can't set env vars directly in sbt. @@ -490,7 +512,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) - _heartbeatReceiver.send(TaskSchedulerIsSet) + _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor @@ -520,11 +542,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } // Optionally scale number of executors dynamically based on workload. Exposed for testing. 
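The cloned childValue given to localProperties above means a child thread receives a snapshot of the parent's properties rather than a live view (the motivation behind SPARK-10563). A small standalone sketch of that behaviour, using the same commons-lang SerializationUtils the patch imports; the property key is chosen arbitrarily:

import java.util.Properties
import org.apache.commons.lang.SerializationUtils

object LocalPropertiesExample extends App {
  val localProps = new InheritableThreadLocal[Properties] {
    // Deep-copy the parent's properties when a child thread is created.
    override protected def childValue(parent: Properties): Properties =
      SerializationUtils.clone(parent).asInstanceOf[Properties]
    override protected def initialValue(): Properties = new Properties()
  }

  localProps.get().setProperty("spark.job.description", "parent job")
  new Thread(new Runnable {
    override def run(): Unit =
      // Prints "parent job": the child saw a copy taken at thread creation,
      // and later changes in either thread stay independent.
      println(localProps.get().getProperty("spark.job.description"))
  }).start()
}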
- val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) + if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + _executorAllocationManager = if (dynamicAllocationEnabled) { - assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") Some(new ExecutorAllocationManager(this, listenerBus, _conf)) } else { None @@ -545,7 +569,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() - _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) @@ -554,7 +577,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. - _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + _shutdownHookRef = ShutdownHookManager.addShutdownHook( + ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") stop() } @@ -624,7 +648,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * [[org.apache.spark.SparkContext.setLocalProperty]]. */ def getLocalProperty(key: String): String = - Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) + Option(localProperties.get).map(_.getProperty(key)).orNull /** Set a human readable description of the current job. */ def setJobDescription(value: String) { @@ -824,7 +848,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -835,7 +863,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. 
(see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -845,7 +873,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -861,7 +888,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * Do - * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`, * * then `rdd` contains * {{{ @@ -871,9 +898,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( @@ -884,7 +915,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -902,8 +933,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * '''Note:''' We ensure that the byte array for each record in the resulting RDD * has the provided record length. * - * @param path Directory to the input data files + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param recordLength The length at which to split the records + * @param conf Configuration for setting up the dataset. + * * @return An RDD of data with values, represented as byte arrays */ @Experimental @@ -1063,7 +1097,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = job.getConfiguration + val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1186,7 +1220,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } protected[spark] def checkpointFile[T: ClassTag](path: String): RDD[T] = withScope { - new CheckpointRDD[T](this, path) + new ReliableCheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ @@ -1353,13 +1387,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli postEnvironmentUpdate() } - /** - * Return whether dynamically adjusting the amount of resources allocated to - * this application is supported. This is currently only available for YARN. 
- */ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) - /** * :: DeveloperApi :: * Register a listener to receive up-calls from events that happen during execution. @@ -1370,16 +1397,27 @@ } /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. - * This is currently only supported in YARN mode. Return whether the request is received. - */ - private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { - assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] + ): Boolean = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.requestTotalExecutors(numExecutors) + b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount) case _ => logWarning("Requesting executors is only supported in coarse-grained mode") false @@ -1389,12 +1427,10 @@ /** * :: DeveloperApi :: * Request an additional number of executors from the cluster manager. - * This is currently only supported in YARN mode. Return whether the request is received. + * @return whether the request is received. */ @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { - assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1407,12 +1443,16 @@ /** * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. - * This is currently only supported in YARN mode. Return whether the request is received. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executors it kills + * through this method with new ones, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * + * @return whether the request is received.
*/ @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { - assert(supportDynamicAllocation, - "Killing executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) @@ -1424,12 +1464,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: - * Request that cluster manager the kill the specified executor. - * This is currently only supported in Yarn mode. Return whether the request is received. + * Request that the cluster manager kill the specified executor. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executor it kills + * through this method with a new one, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * + * @return whether the request is received. */ @DeveloperApi override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) + /** + * Request that the cluster manager kill the specified executor without adjusting the + * application resource requirements. + * + * The effect is that a new executor will be launched in place of the one killed by + * this request. This assumes the cluster manager will automatically and eventually + * fulfill all missing application resource requests. + * + * Note: The replace is by no means guaranteed; another application on the same cluster + * can steal the window of opportunity and acquire this application's resources in the + * mean time. + * + * @return whether the request is received. + */ + private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(Seq(executorId), replace = true) + case _ => + logWarning("Killing executors is only supported in coarse-grained mode") + false + } + } + /** The version of Spark on which this application is running. */ def version: String = SPARK_VERSION @@ -1451,8 +1521,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + getRDDStorageInfo(_ => true) + } + + private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) } @@ -1481,7 +1555,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getAllPools: Seq[Schedulable] = { assertNotStopped() // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq + taskScheduler.rootPool.schedulableQueue.asScala.toSeq } /** @@ -1525,11 +1599,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { - _executorAllocationManager.foreach { _ => - logWarning( - s"Dynamic allocation currently does not support cached RDDs. 
Cached data for RDD " + - s"${rdd.id} will be lost when executors are removed.") - } persistentRdds(rdd.id) = rdd } @@ -1625,36 +1694,60 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli return } if (_shutdownHookRef != null) { - Utils.removeShutdownHook(_shutdownHookRef) + ShutdownHookManager.removeShutdownHook(_shutdownHookRef) } - postApplicationEnd() - _ui.foreach(_.stop()) + Utils.tryLogNonFatalError { + postApplicationEnd() + } + Utils.tryLogNonFatalError { + _ui.foreach(_.stop()) + } if (env != null) { - env.metricsSystem.report() + Utils.tryLogNonFatalError { + env.metricsSystem.report() + } } if (metadataCleaner != null) { - metadataCleaner.cancel() + Utils.tryLogNonFatalError { + metadataCleaner.cancel() + } + } + Utils.tryLogNonFatalError { + _cleaner.foreach(_.stop()) + } + Utils.tryLogNonFatalError { + _executorAllocationManager.foreach(_.stop()) } - _cleaner.foreach(_.stop()) - _executorAllocationManager.foreach(_.stop()) if (_dagScheduler != null) { - _dagScheduler.stop() + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } _dagScheduler = null } if (_listenerBusStarted) { - listenerBus.stop() - _listenerBusStarted = false + Utils.tryLogNonFatalError { + listenerBus.stop() + _listenerBusStarted = false + } + } + Utils.tryLogNonFatalError { + _eventLogger.foreach(_.stop()) } - _eventLogger.foreach(_.stop()) if (env != null && _heartbeatReceiver != null) { - env.rpcEnv.stop(_heartbeatReceiver) + Utils.tryLogNonFatalError { + env.rpcEnv.stop(_heartbeatReceiver) + } + } + Utils.tryLogNonFatalError { + _progressBar.foreach(_.stop()) } - _progressBar.foreach(_.stop()) _taskScheduler = null // TODO: Cache.stop()? if (_env != null) { - _env.stop() + Utils.tryLogNonFatalError { + _env.stop() + } SparkEnv.set(null) } SparkContext.clearActiveContext() @@ -1710,16 +1803,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. The allowLocal - * flag specifies whether the scheduler can run the computation on the driver rather than - * shipping it out to the cluster, for short actions like first(). + * handler function. This is the main entry point for all actions in Spark. */ def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit) { + resultHandler: (Int, U) => Unit): Unit = { if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } @@ -1729,54 +1819,104 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (conf.getBoolean("spark.logLineage", false)) { logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) } - dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, - resultHandler, localProperties.get) + dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } /** - * Run a function on a given set of partitions in an RDD and return the results as an array. The - * allowLocal flag specifies whether the scheduler can run the computation on the driver rather - * than shipping it out to the cluster, for short actions like first(). + * Run a function on a given set of partitions in an RDD and return the results as an array. 
+ */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int]): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res) + results + } + + /** + * Run a job on a given set of partitions of an RDD, but take a function of type + * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: Iterator[T] => U, + partitions: Seq[Int]): Array[U] = { + val cleanedFunc = clean(func) + runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions) + } + + + /** + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean, + resultHandler: (Int, U) => Unit): Unit = { + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions, resultHandler) + } + + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. + */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val results = new Array[U](partitions.size) - runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) - results + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + * + * The allowLocal argument is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val cleanedFunc = clean(func) - runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal) + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** * Run a job on all partitions in an RDD and return the results in an array. 
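With the allowLocal parameter gone, the surviving runJob overloads are called as below. An illustrative sketch assuming a local master and the signatures added in this patch:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object RunJobExample extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("runJob-example"))
  val rdd = sc.parallelize(1 to 100, 4)

  // Per-partition sums, restricted to the first two partitions; no allowLocal flag anywhere.
  val partialSums: Array[Int] = sc.runJob(rdd, (it: Iterator[Int]) => it.sum, Seq(0, 1))
  println(partialSums.toSeq)   // List(325, 950)

  // Handler variant: results are delivered per partition index instead of as an array.
  sc.runJob(rdd, (ctx: TaskContext, it: Iterator[Int]) => it.size,
    0 until rdd.partitions.length,
    (index: Int, count: Int) => println(s"partition $index has $count records"))

  sc.stop()
}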
*/ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** @@ -1787,7 +1927,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) { - runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler) } /** @@ -1799,7 +1939,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit) { val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler) } /** @@ -1844,12 +1984,28 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (context: TaskContext, iter: Iterator[T]) => cleanF(iter), partitions, callSite, - allowLocal = false, resultHandler, localProperties.get) new SimpleFutureAction(waiter, resultFunc) } + /** + * Submit a map stage for execution. This is currently an internal API only, but might be + * promoted to DeveloperApi in the future. + */ + private[spark] def submitMapStage[K, V, C](dependency: ShuffleDependency[K, V, C]) + : SimpleFutureAction[MapOutputStatistics] = { + assertNotStopped() + val callSite = getCallSite() + var result: MapOutputStatistics = null + val waiter = dagScheduler.submitMapStage( + dependency, + (r: MapOutputStatistics) => { result = r }, + callSite, + localProperties.get) + new SimpleFutureAction[MapOutputStatistics](waiter, result) + } + /** * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] * for more information. @@ -1897,6 +2053,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * be a HDFS path if running on a cluster. */ def setCheckpointDir(directory: String) { + + // If we are running on a cluster, log a warning if the directory is local. + // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from + // its own local file system, which is incorrect because the checkpoint files + // are actually on the executor machines. + if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { + logWarning("Checkpoint directory must be non-local " + + "if Spark is running on a cluster: " + directory) + } + checkpointDir = Option(directory).map { dir => val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) @@ -1946,7 +2112,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli for (className <- listenerClassNames) { // Use reflection to find the right constructor val constructors = { - val listenerClass = Class.forName(className) + val listenerClass = Utils.classForName(className) listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] } val constructorTakingSparkConf = constructors.find { c => @@ -2481,7 +2647,7 @@ object SparkContext extends Logging { "\"yarn-standalone\" is deprecated as of Spark 1.0. 
Use \"yarn-cluster\" instead.") } val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { @@ -2493,7 +2659,7 @@ object SparkContext extends Logging { } val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -2506,8 +2672,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { - val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2519,7 +2684,7 @@ object SparkContext extends Logging { val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -2537,7 +2702,7 @@ object SparkContext extends Logging { val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url) + new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) } else { new MesosSchedulerBackend(scheduler, sc, url) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index b0665570e268..c6fef7f91f00 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,7 +22,6 @@ import java.net.Socket import akka.actor.ActorSystem -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties @@ -34,7 +33,6 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} @@ -77,7 +75,7 @@ class SparkEnv ( val conf: SparkConf) extends Logging { // TODO Remove actorSystem - @deprecated("Actor system is no longer supported as of 1.4") + @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false @@ -90,39 +88,42 @@ class SparkEnv ( private var driverTmpDirToDelete: Option[String] = None private[spark] def stop() { - isStopped = true - pythonWorkers.foreach { case(key, worker) => worker.stop() } - Option(httpFileServer).foreach(_.stop()) - 
mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() - - // Note that blockTransferService is stopped by BlockManager since it is started by it. - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. - // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the - // current working dir in executor which we do not need to delete. - driverTmpDirToDelete match { - case Some(path) => { - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + Option(httpFileServer).foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. + driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } @@ -171,7 +172,7 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. 
*/ - @deprecated("Use SparkEnv.get instead", "1.2") + @deprecated("Use SparkEnv.get instead", "1.2.0") def getThreadLocal: SparkEnv = { env } @@ -259,7 +260,7 @@ object SparkEnv extends Logging { // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) + val cls = Utils.classForName(className) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { @@ -322,15 +323,9 @@ object SparkEnv extends Logging { val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) - val blockTransferService = - conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { - case "netty" => - new NettyBlockTransferService(conf, securityManager, numUsableCores) - case "nio" => - new NioBlockTransferService(conf, securityManager) - } + val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 2ebd7a7151a5..977a27bdfe1b 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -30,3 +30,10 @@ class SparkException(message: String, cause: Throwable) */ private[spark] class SparkDriverExecutionException(cause: Throwable) extends SparkException("Execution error", cause) + +/** + * Exception thrown when the main user code is run as a child process (e.g. pyspark) and we want + * the parent SparkSubmit process to exit with the same exit code. + */ +private[spark] case class SparkUserAppException(exitCode: Int) + extends SparkException(s"User application exited with $exitCode") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index f5dd36cbcfe6..ac6eaab20d8d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. 
*/ private[spark] -class SparkHadoopWriter(@transient jobConf: JobConf) +class SparkHadoopWriter(jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { @@ -104,8 +104,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - SparkHadoopMapRedUtil.commitTask( - getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index d09e17dea091..63cca80b2d73 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -32,7 +33,20 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. + */ + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc eq null) { + 0 + } else { + tc.partitionId() + } + } + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. @@ -45,6 +59,14 @@ object TaskContext { * Unset the thread local TaskContext. Internal to Spark. */ protected[spark] def unset(): Unit = taskContext.remove() + + /** + * An empty task context that does not represent an actual task. + */ + private[spark] def empty(): TaskContextImpl = { + new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) + } + } @@ -135,8 +157,39 @@ abstract class TaskContext extends Serializable { @DeveloperApi def taskMetrics(): TaskMetrics + /** + * ::DeveloperApi:: + * Returns all metrics sources with the given name which are associated with the instance + * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + */ + @DeveloperApi + def getMetricsSources(sourceName: String): Seq[Source] + /** * Returns the manager for this task's managed memory. */ private[spark] def taskMemoryManager(): TaskMemoryManager + + /** + * Register an accumulator that belongs to this task. Accumulators must call this method when + * deserializing in executors. + */ + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit + + /** + * Return the local values of internal accumulators that belong to this task. The key of the Map + * is the accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectInternalAccumulators(): Map[Long, Any] + + /** + * Return the local values of accumulators that belong to this task. The key of the Map is the + * accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectAccumulators(): Map[Long, Any] + + /** + * Accumulators for tracking internal metrics indexed by the name. 
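TaskContext.getPartitionId(), added above, is safe to call from inside task code and falls back to 0 when no task is active. A small usage sketch assuming a local SparkContext:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskContextExample extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("taskcontext-example"))
  val data = sc.parallelize(1 to 10, 2)

  // TaskContext.get() is thread-local and only set while a task is running;
  // getPartitionId() returns 0 when called outside of any task.
  val tagged = data.mapPartitions { iter =>
    val partition = TaskContext.getPartitionId()
    iter.map(value => (partition, value))
  }
  tagged.collect().sorted.foreach(println)
  sc.stop()
}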
+ */ + private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index b4d572cb5231..5df94c6d3a10 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,18 +17,22 @@ package org.apache.spark +import scala.collection.mutable.{ArrayBuffer, HashMap} + import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -import scala.collection.mutable.ArrayBuffer - private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, + @transient private val metricsSystem: MetricsSystem, + internalAccumulators: Seq[Accumulator[Long]], val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -94,5 +98,28 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = runningLocally override def isInterrupted(): Boolean = interrupted -} + override def getMetricsSources(sourceName: String): Seq[Source] = + metricsSystem.getSourcesByName(sourceName) + + @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] + + private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { + accumulators(a.id) = a + } + + private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized { + accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap + } + + private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { + accumulators.mapValues(_.localValue).toMap + } + + private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { + // Explicitly register internal accumulators here because these are + // not captured in the task closure and are already deserialized + internalAccumulators.foreach(registerAccumulator) + internalAccumulators.map { a => (a.name.get, a) }.toMap + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 48fd3e7e23d5..7137246bc34f 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -46,6 +48,8 @@ case object Success extends TaskEndReason sealed trait TaskFailedReason extends TaskEndReason { /** Error message displayed in the web UI. */ def toErrorString: String + + def shouldEventuallyFailJob: Boolean = true } /** @@ -90,6 +94,10 @@ case class FetchFailed( * * `fullStackTrace` is a better representation of the stack trace because it contains the whole * stack trace including the exception and its causes + * + * `exception` is the actual exception that caused the task to fail. It may be `None` in + * the case that the exception is not in fact serializable. 
If a task fails more than + * once (due to retries), `exception` is that one that caused the last failure. */ @DeveloperApi case class ExceptionFailure( @@ -97,11 +105,26 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics]) + metrics: Option[TaskMetrics], + private val exceptionWrapper: Option[ThrowableSerializationWrapper]) extends TaskFailedReason { + /** + * `preserveCause` is used to keep the exception itself so it is available to the + * driver. This may be set to `false` in the event that the exception is not in fact + * serializable. + */ + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + } + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + this(e, metrics, preserveCause = true) + } + + def exception: Option[Throwable] = exceptionWrapper.flatMap { + (w: ThrowableSerializationWrapper) => Option(w.exception) } override def toErrorString: String = @@ -127,6 +150,25 @@ case class ExceptionFailure( } } +/** + * A class for recovering from exceptions when deserializing a Throwable that was + * thrown in user task code. If the Throwable cannot be deserialized it will be null, + * but the stacktrace and message will be preserved correctly in SparkException. + */ +private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends + Serializable with Logging { + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeObject(exception) + } + private def readObject(in: ObjectInputStream): Unit = { + try { + exception = in.readObject().asInstanceOf[Throwable] + } catch { + case e : Exception => log.warn("Task exception could not be deserialized", e) + } + } +} + /** * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before @@ -151,9 +193,18 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. */ @DeveloperApi -case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptNumber: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + - s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber" + /** + * If a task failed because its attempt to commit was denied, do not count this failure + * towards failing the stage. This is intended to prevent spurious stage failures in cases + * where many speculative tasks are launched and denied to commit. + */ + override def shouldEventuallyFailJob: Boolean = false } /** @@ -162,8 +213,14 @@ case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extend * the task crashed the JVM. 
*/ @DeveloperApi -case class ExecutorLostFailure(execId: String) extends TaskFailedReason { - override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" +case class ExecutorLostFailure(execId: String, isNormalExit: Boolean = false) + extends TaskFailedReason { + override def toErrorString: String = { + val exitBehavior = if (isNormalExit) "normally" else "abnormally" + s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})" + } + + override def shouldEventuallyFailJob: Boolean = !isNormalExit } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index a1ebbecf93b7..888763a3e8eb 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -19,11 +19,12 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} +import java.nio.charset.StandardCharsets +import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -71,7 +72,7 @@ private[spark] object TestUtils { files.foreach { case (k, v) => val entry = new JarEntry(k) jarStream.putNextEntry(entry) - ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(StandardCharsets.UTF_8)), jarStream) } jarStream.close() jarFile.toURI.toURL @@ -125,7 +126,7 @@ private[spark] object TestUtils { } else { Seq() } - compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call() + compiler.getTask(null, null, null, options.asJava, null, Arrays.asList(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala new file mode 100644 index 000000000000..af483e361e33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation + +import scala.annotation.StaticAnnotation +import scala.annotation.meta._ + +/** + * A Scala annotation that specifies the Spark version when a definition was added. + * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and + * hence works for overridden methods that inherit API documentation directly from parents. 
+ * The limitation is that it does not show up in the generated Java API documentation. + */ +@param @field @getter @setter @beanGetter @beanSetter +private[spark] class Since(version: String) extends StaticAnnotation diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala index 0ae0b4ec042e..891bcddeac28 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapred.InputSplit @@ -37,7 +37,7 @@ class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala index ec4f3964d75e..0f49279f3e64 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapreduce.InputSplit @@ -37,7 +37,7 @@ class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 8441bb3a3047..8344f6368ac4 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.java import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -142,7 +142,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions.asScala, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). 
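// Illustrative aside (not part of the patch): the hunks in this file, and in TestUtils,
// JavaHadoopRDD and JavaNewHadoopRDD above, migrate from the implicit
// `scala.collection.JavaConversions._` to the explicit `scala.collection.JavaConverters._`
// API. Below is a minimal, self-contained sketch of the explicit style these changes rely
// on; the object and method names are illustrative only and do not appear in the patch.

import java.util.{List => JList, Map => JMap}

import scala.collection.JavaConverters._

object JavaConverterSketch {
  // Java -> Scala: asScala wraps the Java collection in place (no copy is made).
  def keysToScala(m: JMap[String, String]): Seq[String] = m.asScala.keys.toSeq

  // Scala -> Java: asJava wraps the Scala collection for Java callers.
  def toJavaList(xs: Seq[Int]): JList[Int] = xs.asJava
}

// Unlike JavaConversions, nothing converts silently: every crossing of the Java/Scala
// collection boundary is spelled out as .asScala or .asJava, which is exactly what
// replacements such as `fractions.asScala` in the sampleByKey hunk above do.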
@@ -173,7 +173,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) /** * ::Experimental:: @@ -239,7 +239,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) mapSideCombine: Boolean, serializer: Serializer): JavaPairRDD[K, C] = { implicit val ctag: ClassTag[C] = fakeClassTag - fromRDD(rdd.combineByKey( + fromRDD(rdd.combineByKeyWithClassTag( createCombiner, mergeValue, mergeCombiners, @@ -768,7 +768,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. */ - def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) + def lookup(key: K): JList[V] = rdd.lookup(key).asJava /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( @@ -987,30 +987,27 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) object JavaPairRDD { private[spark] def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Iterable[T])]): RDD[(K, JIterable[T])] = { - rddToPairRDDFunctions(rdd).mapValues(asJavaIterable) + rddToPairRDDFunctions(rdd).mapValues(_.asJava) } private[spark] def cogroupResultToJava[K: ClassTag, V, W]( rdd: RDD[(K, (Iterable[V], Iterable[W]))]): RDD[(K, (JIterable[V], JIterable[W]))] = { - rddToPairRDDFunctions(rdd).mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava)) } private[spark] def cogroupResult2ToJava[K: ClassTag, V, W1, W2]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava)) } private[spark] def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => - (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava, x._4.asJava)) } def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index c95615a5a930..fc817cdd6a3f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,7 +21,6 @@ import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} import java.util.{Comparator, List => JList, Iterator => JIterator} -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -59,10 +58,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] @deprecated("Use partitions() instead.", "1.1.0") - def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def splits: JList[Partition] = 
rdd.partitions.toSeq.asJava /** Set of partitions in this RDD. */ - def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def partitions: JList[Partition] = rdd.partitions.toSeq.asJava /** The partitioner of this RDD. */ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) @@ -82,7 +81,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * subclasses of RDD. */ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = - asJavaIterator(rdd.iterator(split, taskContext)) + rdd.iterator(split, taskContext).asJava // Transformations (return a new RDD) @@ -99,7 +98,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -153,7 +152,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -164,7 +163,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) @@ -175,7 +174,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } @@ -186,7 +185,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -197,7 +196,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) @@ -209,7 +208,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends 
Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) @@ -219,14 +218,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to each partition of this RDD. */ def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { - rdd.foreachPartition((x => f.call(asJavaIterator(x)))) + rdd.foreachPartition((x => f.call(x.asJava))) } /** * Return an RDD created by coalescing all elements within each partition into an array. */ def glom(): JavaRDD[JList[T]] = - new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + new JavaRDD(rdd.glom().map(_.toSeq.asJava)) /** * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of @@ -266,13 +265,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command)) + rdd.pipe(command.asScala) /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) + rdd.pipe(command.asScala, env.asScala) /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, @@ -294,8 +293,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { - (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) @@ -333,28 +331,22 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. */ - def collect(): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.collect().toSeq - new java.util.ArrayList(arr) - } + def collect(): JList[T] = + rdd.collect().toSeq.asJava /** * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. */ - def toLocalIterator(): JIterator[T] = { - import scala.collection.JavaConversions._ - rdd.toLocalIterator - } - + def toLocalIterator(): JIterator[T] = + asJavaIteratorConverter(rdd.toLocalIterator).asJava /** * Return an array that contains all of the elements in this RDD. 
* @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead */ - @Deprecated + @deprecated("use collect()", "1.0.0") def toArray(): JList[T] = collect() /** @@ -363,9 +355,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. - import scala.collection.JavaConversions._ - val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true) - res.map(x => new java.util.ArrayList(x.toSeq)).toArray + val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) + res.map(_.toSeq.asJava) } /** @@ -489,20 +480,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * it will be slow if a lot of partitions are required. In that case, use collect() to get the * whole RDD instead. */ - def take(num: Int): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.take(num).toSeq - new java.util.ArrayList(arr) - } + def take(num: Int): JList[T] = + rdd.take(num).toSeq.asJava def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq - new java.util.ArrayList(arr) - } + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = + rdd.takeSample(withReplacement, num, seed).toSeq.asJava /** * Return the first element in this RDD. @@ -582,10 +567,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def top(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.top(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -607,10 +589,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -696,7 +675,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * applies a function f to each partition of this RDD. 
*/ def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { - new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x.asJava)), { x => null.asInstanceOf[Void] }) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 02e49a853c5f..609496ccdfef 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,8 +21,7 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import scala.collection.JavaConversions -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -104,7 +103,7 @@ class JavaSparkContext(val sc: SparkContext) */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment, Map())) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala, Map())) private[spark] val env = sc.env @@ -118,7 +117,7 @@ class JavaSparkContext(val sc: SparkContext) def appName: String = sc.appName - def jars: util.List[String] = sc.jars + def jars: util.List[String] = sc.jars.asJava def startTime: java.lang.Long = sc.startTime @@ -142,7 +141,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag - sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) + sc.parallelize(list.asScala, numSlices) } /** Get an RDD that has no partitions or elements. */ @@ -161,7 +160,7 @@ class JavaSparkContext(val sc: SparkContext) : JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[V] = fakeClassTag - JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) + JavaPairRDD.fromRDD(sc.parallelize(list.asScala, numSlices)) } /** Distribute a local Scala collection to form an RDD. */ @@ -170,8 +169,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = - JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), - numSlices)) + JavaDoubleRDD.fromRDD(sc.parallelize(list.asScala.map(_.doubleValue()), numSlices)) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = @@ -519,7 +517,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { - val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[T]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[T] = first.classTag sc.union(rdds) } @@ -527,7 +525,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. 
*/ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { - val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[(K, V)] = first.classTag implicit val ctagK: ClassTag[K] = first.kClassTag implicit val ctagV: ClassTag[V] = first.vClassTag @@ -536,7 +534,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { - val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) + val rdds: Seq[RDD[Double]] = (Seq(first) ++ rest.asScala).map(_.srdd) new JavaDoubleRDD(sc.union(rdds)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index b959b683d167..a7dfa1d257cf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,15 +17,17 @@ package org.apache.spark.api.python -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, SparkException} +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ -import scala.util.{Failure, Success, Try} -import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * :: Experimental :: @@ -68,7 +70,6 @@ private[python] class WritableToJavaConverter( * object representation */ private def convertWritable(writable: Writable): Any = { - import collection.JavaConversions._ writable match { case iw: IntWritable => iw.get() case dw: DoubleWritable => dw.get() @@ -89,9 +90,7 @@ private[python] class WritableToJavaConverter( aw.get().map(convertWritable(_)) case mw: MapWritable => val map = new java.util.HashMap[Any, Any]() - mw.foreach { case (k, v) => - map.put(convertWritable(k), convertWritable(v)) - } + mw.asScala.foreach { case (k, v) => map.put(convertWritable(k), convertWritable(v)) } map case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other @@ -122,7 +121,6 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { * supported out-of-the-box. 
*/ private def convertToWritable(obj: Any): Writable = { - import collection.JavaConversions._ obj match { case i: java.lang.Integer => new IntWritable(i) case d: java.lang.Double => new DoubleWritable(d) @@ -134,7 +132,7 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { case null => NullWritable.get() case map: java.util.Map[_, _] => val mapWritable = new MapWritable() - map.foreach { case (k, v) => + map.asScala.foreach { case (k, v) => mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable @@ -161,9 +159,8 @@ private[python] object PythonHadoopUtil { * Convert a [[java.util.Map]] of properties to a [[org.apache.hadoop.conf.Configuration]] */ def mapToConf(map: java.util.Map[String, String]): Configuration = { - import collection.JavaConversions._ val conf = new Configuration() - map.foreach{ case (k, v) => conf.set(k, v) } + map.asScala.foreach { case (k, v) => conf.set(k, v) } conf } @@ -172,9 +169,8 @@ private[python] object PythonHadoopUtil { * any matching keys in left */ def mergeConfs(left: Configuration, right: Configuration): Configuration = { - import collection.JavaConversions._ val copy = new Configuration(left) - right.iterator().foreach(entry => copy.set(entry.getKey, entry.getValue)) + right.asScala.foreach(entry => copy.set(entry.getKey, entry.getValue)) copy } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index dc9f62f39e6d..69da180593bb 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,7 +21,7 @@ import java.io._ import java.net._ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials @@ -41,7 +41,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal private[spark] class PythonRDD( - @transient parent: RDD[_], + parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -66,11 +66,11 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { - envVars += ("SPARK_REUSE_WORKER" -> "1") + envVars.put("SPARK_REUSE_WORKER", "1") } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool @volatile var released = false @@ -150,7 +150,7 @@ private[spark] class PythonRDD( // Check whether the worker is ready to be re-used. 
if (stream.readInt() == SpecialLengths.END_OF_STREAM) { if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) released = true } } @@ -207,6 +207,7 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { + TaskContext.setTaskContext(context) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index @@ -216,13 +217,13 @@ private[spark] class PythonRDD( // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { + dataOut.writeInt(pythonIncludes.size()) + for (include <- pythonIncludes.asScala) { PythonRDD.writeUTF(include, dataOut) } // Broadcast variables val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet + val newBids = broadcastVars.asScala.map(_.id).toSet // number of different broadcasts val toRemove = oldBids.diff(newBids) val cnt = toRemove.size + newBids.diff(oldBids).size @@ -232,7 +233,7 @@ private[spark] class PythonRDD( dataOut.writeLong(- bid - 1) // bid >= 0 oldBids.remove(bid) } - for (broadcast <- broadcastVars) { + for (broadcast <- broadcastVars.asScala) { if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) @@ -263,11 +264,6 @@ private[spark] class PythonRDD( if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } @@ -291,7 +287,7 @@ private[spark] class PythonRDD( if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap, worker) + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -358,15 +354,14 @@ private[spark] object PythonRDD extends Logging { def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int], - allowLocal: Boolean): Int = { + partitions: JArrayList[Int]): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, - s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") + s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}") } /** @@ -790,7 +785,7 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. 
*/ -private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) +private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") @@ -799,7 +794,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added - * by the DAGScheduler's single-threaded actor anyway. + * by the DAGScheduler's single-threaded RpcEndpoint anyway. */ @transient var socket: Socket = _ @@ -824,7 +819,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) - for (array <- val2) { + for (array <- val2.asScala) { out.writeInt(array.length) out.write(array) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 90dacaeb9342..31e534f160ee 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,10 +17,10 @@ package org.apache.spark.api.python -import java.io.{File} +import java.io.File import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -51,7 +51,14 @@ private[spark] object PythonUtils { * Convert list of T into seq of T (for calling API with varargs) */ def toSeq[T](vs: JList[T]): Seq[T] = { - vs.toList.toSeq + vs.asScala + } + + /** + * Convert list of T into a (Scala) List of T + */ + def toList[T](vs: JList[T]): List[T] = { + vs.asScala.toList } /** @@ -65,6 +72,6 @@ private[spark] object PythonUtils { * Convert java map of K, V into Map of K, V (for calling API with varargs) */ def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { - jm.toMap + jm.asScala.toMap } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index e314408c067e..7039b734d2e4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.python import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.util.Arrays import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.util.{RedirectThread, Utils} @@ -108,9 +109,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: 
workerEnv.put("PYTHONUNBUFFERED", "YES") @@ -151,9 +152,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 1f1debcf84ad..fd27276e70bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList} import org.apache.spark.api.java.JavaRDD -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Failure @@ -214,7 +213,7 @@ private[spark] object SerDeUtil extends Logging { new AutoBatchedPickler(cleaned) } else { val pickle = new Pickler - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index 8f30ff9202c8..ee1fb056f0d9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -20,6 +20,8 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} import java.{util => ju} +import scala.collection.JavaConverters._ + import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ @@ -62,10 +64,9 @@ private[python] class TestInputKeyConverter extends Converter[Any, Any] { } private[python] class TestInputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] - seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) + m.keySet.asScala.map(_.asInstanceOf[DoubleWritable].get()).toSeq.asJava } } @@ -76,9 +77,8 @@ private[python] class TestOutputKeyConverter extends Converter[Any, Any] { } private[python] class TestOutputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): DoubleWritable = { - new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().iterator().next()) } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 1a5f2bca26c2..b7e72d4d0ed0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -95,7 +95,9 @@ private[spark] class RBackend { private[spark] object RBackend extends Logging { def main(args: Array[String]): Unit = { if (args.length < 1) { + // scalastyle:off println System.err.println("Usage: RBackend ") + // 
scalastyle:on println System.exit(-1) } val sparkRBackend = new RBackend() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 2e86984c66b3..2a792d81994f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,12 +20,14 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.HashMap +import scala.language.existentials import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging import org.apache.spark.api.r.SerDe._ +import org.apache.spark.util.Utils /** * Handler for RBackend @@ -51,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend) if (objId == "SparkRHandler") { methodName match { + // This function is for test-purpose only + case "echo" => + val args = readArgs(numArgs, dis) + assert(numArgs == 1) + + writeInt(dos, 0) + writeObject(dos, args(0)) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") @@ -67,8 +76,11 @@ private[r] class RBackendHandler(server: RBackend) case e: Exception => logError(s"Removing $objId failed", e) writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") } - case _ => dos.writeInt(-1) + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") } } else { handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) @@ -98,7 +110,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - Class.forName(objId) + Utils.classForName(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) @@ -113,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend) val methods = cls.getMethods val selectedMethods = methods.filter(m => m.getName == methodName) if (selectedMethods.length > 0) { - val methods = selectedMethods.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - } - if (methods.isEmpty) { + val index = findMatchedSignature( + selectedMethods.map(_.getParameterTypes), + args) + + if (index.isEmpty) { logWarning(s"cannot find matching method ${cls}.$methodName. " + s"Candidates are:") selectedMethods.foreach { method => @@ -124,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args : _*) + + val ret = selectedMethods(index.get).invoke(obj, args : _*) // Write status bit writeInt(dos, 0) writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head + val ctors = cls.getConstructors + val index = findMatchedSignature( + ctors.map(_.getParameterTypes), + args) + + if (index.isEmpty) { + logWarning(s"cannot find matching constructor for ${cls}. 
" + + s"Candidates are:") + ctors.foreach { ctor => + logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched constructor found for $cls") + } - val obj = ctor.newInstance(args : _*) + val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) @@ -144,46 +168,89 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed", e) + logError(s"$methodName on $objId failed") writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. + writeString(dos, Utils.exceptionString(e.getCause)) } } // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => + (0 until numArgs).map { _ => readObject(dis) }.toArray } - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false - } + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- 0 until parameterTypesOfMethods.length) { + val parameterTypes = parameterTypesOfMethods(index) - for (i <- 0 to numArgs - 1) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. 
+ } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } + + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. + + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + (0 until numArgs).map { i => + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) } - } - if (!parameterWrapperType.isInstance(args(i))) { - return false } } - true + None } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 4dfa7325934f..9d5bbb5d609f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.r import java.io._ import java.net.{InetAddress, ServerSocket} +import java.util.Arrays import java.util.{Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try @@ -39,7 +40,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { protected var dataStream: DataInputStream = _ @@ -60,7 +60,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( // The stdout/stderr is shared by multiple tasks, because we use one daemon // to launch child process as worker. - val errThread = RRDD.createRWorker(rLibDir, listenPort) + val errThread = RRDD.createRWorker(listenPort) // We use two sockets to separate input and output, then it's easy to manage // the lifecycle of them to avoid deadlock. 
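// Illustrative aside (not part of the patch): the new findMatchedSignature above boils
// down to "check the arity, box primitive parameter types, then isInstance-check every
// argument, letting a Java array stand in for a Seq parameter". Below is a condensed,
// standalone sketch of that rule; it deliberately omits the in-place conversion of matched
// array arguments to Seq that the real method performs, and the names are illustrative only.

import java.lang.{Boolean => JBoolean, Double => JDouble, Integer => JInteger}

object SignatureMatchSketch {
  // Box a primitive parameter type so it can be compared against Array[Object] arguments.
  private def boxed(t: Class[_]): Class[_] = t match {
    case java.lang.Integer.TYPE => classOf[JInteger]
    case java.lang.Double.TYPE  => classOf[JDouble]
    case java.lang.Boolean.TYPE => classOf[JBoolean]
    case other                  => other
  }

  // An argument matches a parameter if it is an instance of the (boxed) parameter type,
  // or if the parameter expects a Seq and the argument arrived as a Java array.
  private def argMatches(paramType: Class[_], arg: Object): Boolean =
    (paramType == classOf[Seq[Any]] && arg != null && arg.getClass.isArray) ||
      boxed(paramType).isInstance(arg)

  // Index of the first signature whose arity and parameter types accept the arguments.
  def findIndex(signatures: Array[Array[Class[_]]], args: Array[Object]): Option[Int] = {
    val i = signatures.indexWhere { paramTypes =>
      paramTypes.length == args.length &&
        paramTypes.zip(args).forall { case (t, a) => argMatches(t, a) }
    }
    if (i == -1) None else Some(i)
  }
}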
@@ -113,6 +113,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( partition: Int): Unit = { val env = SparkEnv.get + val taskContext = TaskContext.get() val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val stream = new BufferedOutputStream(output, bufferSize) @@ -120,6 +121,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( override def run(): Unit = { try { SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) val dataOut = new DataOutputStream(stream) dataOut.writeInt(partition) @@ -161,7 +163,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { // write string(for StringRRDD) + // scalastyle:off println printOut.println(elem) + // scalastyle:on println } } @@ -233,11 +237,10 @@ private class PairwiseRRDD[T: ClassTag]( hashFunc: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, (Int, Array[Byte])]( parent, numPartitions, hashFunc, deserializer, - SerializationFormats.BYTE, packageNames, rLibDir, + SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): (Int, Array[Byte]) = { @@ -264,10 +267,9 @@ private class RRDD[T: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, Array[Byte]]( - parent, -1, func, deserializer, serializer, packageNames, rLibDir, + parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): Array[Byte] = { @@ -291,10 +293,9 @@ private class StringRRDD[T: ClassTag]( func: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, String]( - parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): String = { @@ -365,11 +366,11 @@ private[r] object RRDD { sparkConf.setIfMissing("spark.master", "local") } - for ((name, value) <- sparkEnvirMap) { - sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkEnvirMap.asScala) { + sparkConf.set(name.toString, value.toString) } - for ((name, value) <- sparkExecutorEnvMap) { - sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkExecutorEnvMap.asScala) { + sparkConf.setExecutorEnv(name.toString, value.toString) } val jsc = new JavaSparkContext(sparkConf) @@ -390,11 +391,12 @@ private[r] object RRDD { thread } - private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { - val rCommand = "Rscript" + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir + "/SparkR/worker/" + script - val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for 
workers. // This is set by R CMD check as startup.Rs // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) @@ -411,7 +413,7 @@ private[r] object RRDD { /** * ProcessBuilder used to launch worker R processes. */ - def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + def createRWorker(port: Int): BufferedStreamThread = { val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) if (!Utils.isWindows && useDaemon) { synchronized { @@ -419,7 +421,7 @@ private[r] object RRDD { // we expect one connections val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + errThread = createRProcess(daemonPort, "daemon.R") // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() @@ -441,7 +443,7 @@ private[r] object RRDD { errThread } } else { - createRProcess(rLibDir, port, "worker.R") + createRProcess(port, "worker.R") } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala new file mode 100644 index 000000000000..9e807cc52f18 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.File +import java.util.Arrays + +import org.apache.spark.{SparkEnv, SparkException} + +private[spark] object RUtils { + /** + * Get the SparkR package path in the local spark distribution. + */ + def localSparkRPackagePath: Option[String] = { + val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.test.home")) + sparkHome.map( + Seq(_, "R", "lib").mkString(File.separator) + ) + } + + /** + * Get the SparkR package path in various deployment modes. + * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` + * and environment variable `SPARK_HOME` are set. + */ + def sparkRPackagePath(isDriver: Boolean): String = { + val (master, deployMode) = + if (isDriver) { + (sys.props("spark.master"), sys.props("spark.submit.deployMode")) + } else { + val sparkConf = SparkEnv.get.conf + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + } + + val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" + val isYarnClient = master != null && master.contains("yarn") && deployMode == "client" + + // In YARN mode, the SparkR package is distributed as an archive symbolically + // linked to the "sparkr" file in the current directory. Note that this does not apply + // to the driver in client mode because it is run outside of the cluster. 
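The `createRProcess` change above resolves both the R interpreter (overridable via `spark.sparkr.r.command`) and the SparkR library directory at launch time, instead of threading an `rLibDir` parameter through every RDD constructor. A rough sketch of that command assembly, assuming only what the diff shows (the directory and interpreter paths in the example are made up):

```scala
import java.io.File
import java.util.Arrays

object RWorkerCommandSketch {
  // Assemble the command line for an R worker process: the interpreter defaults to
  // Rscript unless spark.sparkr.r.command overrides it, and the worker script is
  // looked up under <SparkR lib dir>/SparkR/worker/.
  def rWorkerCommand(
      conf: Map[String, String],
      sparkRLibDir: String,
      script: String): java.util.List[String] = {
    val rCommand = conf.getOrElse("spark.sparkr.r.command", "Rscript")
    val rOptions = "--vanilla"
    val rExecScript = Seq(sparkRLibDir, "SparkR", "worker", script).mkString(File.separator)
    Arrays.asList(rCommand, rOptions, rExecScript)
  }

  def main(args: Array[String]): Unit = {
    // Defaults to Rscript when no override is configured.
    println(rWorkerCommand(Map.empty, "/opt/spark/R/lib", "worker.R"))
    // Respects a custom interpreter path.
    println(rWorkerCommand(
      Map("spark.sparkr.r.command" -> "/usr/local/bin/Rscript"),
      "/opt/spark/R/lib", "daemon.R"))
  }
}
```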
+ if (isYarnCluster || (isYarnClient && !isDriver)) { + new File("sparkr").getAbsolutePath + } else { + // Otherwise, assume the package is local + // TODO: support this for Mesos + localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + } + } + } + + /** Check if R is installed before running tests that use R commands. */ + def isRInstalled: Boolean = { + try { + val builder = new ProcessBuilder(Arrays.asList("R", "--version")) + builder.start().waitFor() == 0 + } catch { + case e: Exception => false + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 56adc857d4ce..0c78613e406e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -20,7 +20,8 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable.WrappedArray /** * Utility functions to serialize, deserialize objects to / from R @@ -149,6 +150,10 @@ private[spark] object SerDe { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) case 'r' => readBytesArr(dis) + case 'l' => { + val len = readInt(dis) + (0 until len).map(_ => readList(dis)).toArray + } case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } @@ -165,7 +170,7 @@ private[spark] object SerDe { val valueType = readObjectType(in) readTypedObject(in, valueType) }) - mapAsJavaMap(keys.zip(values).toMap) + keys.zip(values).toMap.asJava } else { new java.util.HashMap[Object, Object]() } @@ -179,7 +184,9 @@ private[spark] object SerDe { // Int -> integer // String -> character // Boolean -> logical + // Float -> double // Double -> double + // Decimal -> double // Long -> double // Array[Byte] -> raw // Date -> Date @@ -198,78 +205,141 @@ private[spark] object SerDe { case "date" => dos.writeByte('D') case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') + // Array of primitive types + case "array" => dos.writeByte('a') + // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") } } - def writeObject(dos: DataOutputStream, value: Object): Unit = { - if (value == null) { + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + + def writeObject(dos: DataOutputStream, obj: Object): Unit = { + if (obj == null) { writeType(dos, "void") } else { - value.getClass.getName match { - case "java.lang.String" => + // Convert ArrayType collected from DataFrame to Java array + // Collected data of ArrayType from a DataFrame is observed to be of + // type "scala.collection.mutable.WrappedArray" + val value = + if (obj.isInstanceOf[WrappedArray[_]]) { + obj.asInstanceOf[WrappedArray[_]].toArray + } else { + obj + } + + value match { + case v: java.lang.Character => + writeType(dos, "character") + writeString(dos, v.toString) + case v: 
java.lang.String => writeType(dos, "character") - writeString(dos, value.asInstanceOf[String]) - case "long" | "java.lang.Long" => + writeString(dos, v) + case v: java.lang.Long => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "double" | "java.lang.Double" => + writeDouble(dos, v.toDouble) + case v: java.lang.Float => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Double]) - case "int" | "java.lang.Integer" => + writeDouble(dos, v.toDouble) + case v: java.math.BigDecimal => + writeType(dos, "double") + writeDouble(dos, scala.math.BigDecimal(v).toDouble) + case v: java.lang.Double => + writeType(dos, "double") + writeDouble(dos, v) + case v: java.lang.Byte => + writeType(dos, "integer") + writeInt(dos, v.toInt) + case v: java.lang.Short => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Int]) - case "boolean" | "java.lang.Boolean" => + writeInt(dos, v.toInt) + case v: java.lang.Integer => + writeType(dos, "integer") + writeInt(dos, v) + case v: java.lang.Boolean => writeType(dos, "logical") - writeBoolean(dos, value.asInstanceOf[Boolean]) - case "java.sql.Date" => + writeBoolean(dos, v) + case v: java.sql.Date => writeType(dos, "date") - writeDate(dos, value.asInstanceOf[Date]) - case "java.sql.Time" => + writeDate(dos, v) + case v: java.sql.Time => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Time]) - case "java.sql.Timestamp" => + writeTime(dos, v) + case v: java.sql.Timestamp => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Timestamp]) - case "[B" => - writeType(dos, "raw") - writeBytes(dos, value.asInstanceOf[Array[Byte]]) - // TODO: Types not handled right now include - // byte, char, short, float + writeTime(dos, v) // Handle arrays - case "[Ljava.lang.String;" => - writeType(dos, "list") - writeStringArr(dos, value.asInstanceOf[Array[String]]) - case "[I" => - writeType(dos, "list") - writeIntArr(dos, value.asInstanceOf[Array[Int]]) - case "[J" => - writeType(dos, "list") - writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) - case "[D" => - writeType(dos, "list") - writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) - case "[Z" => - writeType(dos, "list") - writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) - case "[[B" => + + // Array of primitive types + + // Special handling for byte array + case v: Array[Byte] => + writeType(dos, "raw") + writeBytes(dos, v) + + case v: Array[Char] => + writeType(dos, "array") + writeStringArr(dos, v.map(_.toString)) + case v: Array[Short] => + writeType(dos, "array") + writeIntArr(dos, v.map(_.toInt)) + case v: Array[Int] => + writeType(dos, "array") + writeIntArr(dos, v) + case v: Array[Long] => + writeType(dos, "array") + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Float] => + writeType(dos, "array") + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Double] => + writeType(dos, "array") + writeDoubleArr(dos, v) + case v: Array[Boolean] => + writeType(dos, "array") + writeBooleanArr(dos, v) + + // Array of objects, null objects use "void" type + case v: Array[Object] => writeType(dos, "list") - writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) - case otherName => - // Handle array of objects - if (otherName.startsWith("[L")) { - val objArr = value.asInstanceOf[Array[Object]] - writeType(dos, "list") - writeType(dos, "jobj") - dos.writeInt(objArr.length) - objArr.foreach(o => writeJObj(dos, o)) - } else { - writeType(dos, "jobj") - writeJObj(dos, value) + writeInt(dos, v.length) + v.foreach(elem => 
writeObject(dos, elem)) + + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) } + + case _ => + writeType(dos, "jobj") + writeJObj(dos, value) } } } @@ -299,12 +369,11 @@ private[spark] object SerDe { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } - // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { - val len = value.length - out.writeInt(len + 1) // For the \0 - out.writeBytes(value) - out.writeByte(0) + val utf8 = value.getBytes("UTF-8") + val len = utf8.length + out.writeInt(len) + out.write(utf8, 0, len) } def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { @@ -341,11 +410,6 @@ private[spark] object SerDe { value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { - writeType(out, "raw") - out.writeInt(value.length) - value.foreach(v => writeBytes(out, v)) - } } private[r] object SerializationFormats { diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 685313ac009b..fac6666bb341 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -42,7 +43,7 @@ private[spark] class BroadcastManager( conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject broadcastFactory.initialize(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index a0c9b5e63c74..7e3764d802fe 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,7 +20,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer -import scala.collection.JavaConversions.asJavaEnumeration +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.Random @@ -210,7 +210,7 @@ private object TorrentBroadcast extends Logging { compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( - asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = 
serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 848b62f9de71..f03875a3e8c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -18,17 +18,17 @@ package org.apache.spark.deploy import scala.collection.mutable.HashSet -import scala.concurrent._ +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import scala.util.{Failure, Success} -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} /** * Proxy that relays messages to the driver. @@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} * We currently don't support retry if submission fails. In HA mode, client will submit request to * all masters and see which one could handle it. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - private val masterActors = driverArgs.masters.map { m => - context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) - } - private val lostMasters = new HashSet[Address] - private var activeMasterActor: ActorSelection = null - - val timeout = RpcUtils.askTimeout(conf) - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - +private class ClientEndpoint( + override val rpcEnv: RpcEnv, + driverArgs: ClientArguments, + masterEndpoints: Seq[RpcEndpointRef], + conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging { + + // A scheduled executor used to send messages at the specified time. + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message") + // Used to provide the implicit parameter of `Future` methods. + private val forwardMessageExecutionContext = + ExecutionContext.fromExecutor(forwardMessageThread, + t => t match { + case ie: InterruptedException => // Exit normally + case e: Throwable => + logError(e.getMessage, e) + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + }) + + private val lostMasters = new HashSet[RpcAddress] + private var activeMasterEndpoint: RpcEndpointRef = null + + override def onStart(): Unit = { driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -82,44 +92,52 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.cores, driverArgs.supervise, command) - - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestSubmitDriver(driverDescription) - } + ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! 
RequestKillDriver(driverId) - } + ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + } + } + + /** + * Send the message to master and forward the reply to self asynchronously. + */ + private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + for (masterEndpoint <- masterEndpoints) { + masterEndpoint.ask[T](message).onComplete { + case Success(v) => self.send(v) + case Failure(e) => + logWarning(s"Error sending messages to master $masterEndpoint", e) + }(forwardMessageExecutionContext) } } /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println("... waiting before polling master for driver state") + // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread + // is fine. + logInfo("... waiting before polling master for driver state") Thread.sleep(5000) - println("... polling master for driver state") - val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + logInfo("... polling master for driver state") + val statusResponse = + activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => - println(s"ERROR: Cluster master did not recognize $driverId") + logError(s"ERROR: Cluster master did not recognize $driverId") System.exit(-1) case true => - println(s"State of $driverId is ${statusResponse.state.get}") + logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - println(s"Driver running on $hostPort ($id)") + logInfo(s"Driver running on $hostPort ($id)") case _ => } // Exception, if present statusResponse.exception.map { e => - println(s"Exception from cluster was: $e") + logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) } @@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { - case SubmitDriverResponse(success, driverId, message) => - println(message) + case SubmitDriverResponse(master, success, driverId, message) => + logInfo(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId.get) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } - case KillDriverResponse(driverId, success, message) => - println(message) + case KillDriverResponse(master, driverId, success, message) => + logInfo(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } + } - case DisassociatedEvent(_, remoteAddress, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") - lostMasters += remoteAddress - // Note that this heuristic does not account for the fact that a Master can recover within - // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This - // is not currently a concern, however, because this client does not retry submissions. 
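`ayncSendToMasterAndForwardReply` above asks every master asynchronously and forwards whichever reply arrives back into the endpoint's own receive loop, so mutable state such as `activeMasterEndpoint` is only ever touched from one thread. A minimal sketch of that forward-the-reply pattern using plain `scala.concurrent` futures, with a blocking queue standing in for the endpoint inbox (none of these names are Spark's):

```scala
import java.util.concurrent.{Executors, LinkedBlockingQueue}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object ForwardReplySketch {
  // Single-threaded "inbox" standing in for an RPC endpoint's receive loop.
  private val inbox = new LinkedBlockingQueue[Any]()
  private val forwardPool = Executors.newSingleThreadExecutor()
  private implicit val forwardEc: ExecutionContext = ExecutionContext.fromExecutor(forwardPool)

  // Ask every master; forward successful replies to the inbox, log failures.
  def askAllAndForward[T](masters: Seq[() => Future[T]]): Unit =
    masters.foreach { ask =>
      ask().onComplete {
        case Success(reply) => inbox.put(reply)
        case Failure(e)     => println(s"ask failed: ${e.getMessage}")
      }
    }

  def main(args: Array[String]): Unit = {
    askAllAndForward(Seq(
      () => Future.successful("SubmitDriverResponse(ok)"),
      () => Future.failed(new RuntimeException("master unreachable"))
    ))
    // Blocks until the first forwarded reply lands in the inbox.
    println(s"forwarded reply: ${inbox.take()}")
    forwardPool.shutdown()
  }
}
```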
- if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. + if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) } + } + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") - lostMasters += remoteAddress - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master ($remoteAddress).") + logError(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) } + } + } + + override def onError(cause: Throwable): Unit = { + logError(s"Error processing messages, exiting.") + cause.printStackTrace() + System.exit(-1) + } + + override def onStop(): Unit = { + forwardMessageThread.shutdownNow() } } @@ -179,10 +209,12 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) */ object Client { def main(args: Array[String]) { + // scalastyle:off println if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a future version of Spark") println("Use ./bin/spark-submit with \"--master spark://host:port\"") } + // scalastyle:on println val conf = new SparkConf() val driverArgs = new ClientArguments(args) @@ -194,15 +226,13 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val rpcEnv = + RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - for (m <- driverArgs.masters) { - Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) - } - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). 
+ map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 316e2d59f01b..72cc330a398d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) { cmd = "launch" if (!ClientArguments.isValidJarUrl(_jarUrl)) { + // scalastyle:off println println(s"Jar url '${_jarUrl}' is not in valid format.") println(s"Must be a jar file path in URL format " + "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)") + // scalastyle:on println printUsageAndExit(-1) } @@ -110,14 +112,16 @@ private[deploy] class ClientArguments(args: Array[String]) { | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin + // scalastyle:off println System.err.println(usage) + // scalastyle:on println System.exit(exitCode) } } private[deploy] object ClientArguments { val DEFAULT_CORES = 1 - val DEFAULT_MEMORY = 512 // MB + val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB val DEFAULT_SUPERVISE = false def isValidJarUrl(s: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 9db6fd1ac4db..d8084a57658a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable -/** Contains messages sent between Scheduler actor nodes. */ +/** Contains messages sent between Scheduler endpoint nodes. 
*/ private[deploy] object DeployMessages { // Worker to Master @@ -37,6 +38,7 @@ private[deploy] object DeployMessages { id: String, host: String, port: Int, + worker: RpcEndpointRef, cores: Int, memory: Int, webUiPort: Int, @@ -63,11 +65,11 @@ private[deploy] object DeployMessages { case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], driverIds: Seq[String]) - case class Heartbeat(workerId: String) extends DeployMessage + case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage @@ -92,22 +94,26 @@ private[deploy] object DeployMessages { // Worker internal - case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders case object ReregisterWithMaster // used when a worker attempts to reconnect to a master // AppClient to Master - case class RegisterApplication(appDescription: ApplicationDescription) + case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) extends DeployMessage case class UnregisterApplication(appId: String) case class MasterChangeAcknowledged(appId: String) + case class RequestExecutors(appId: String, requestedTotal: Int) + + case class KillExecutors(appId: String, executorIds: Seq[String]) + // Master to AppClient - case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage + case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -123,12 +129,14 @@ private[deploy] object DeployMessages { case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage - case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + case class SubmitDriverResponse( + master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String) extends DeployMessage case class RequestKillDriver(driverId: String) extends DeployMessage - case class KillDriverResponse(driverId: String, success: Boolean, message: String) + case class KillDriverResponse( + master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage case class RequestDriverStatus(driverId: String) extends DeployMessage @@ -142,7 +150,7 @@ private[deploy] object DeployMessages { // Master to Worker & AppClient - case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String) // MasterWebUI To Master diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 09973a0a2c99..6840a3ae831f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -19,14 +19,15 @@ package org.apache.spark.deploy import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import 
org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap -import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.util.TransportConf import org.apache.spark.util.Utils /** @@ -45,11 +46,16 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val useSasl: Boolean = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) - private val blockHandler = new ExternalShuffleBlockHandler(transportConf) + private val blockHandler = newShuffleBlockHandler(transportConf) private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) private var server: TransportServer = _ + /** Create a new shuffle block handler. Factored out for subclasses to override. */ + protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { + new ExternalShuffleBlockHandler(conf, null) + } + /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { @@ -61,13 +67,18 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana def start() { require(server == null, "Shuffle server already started") logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - val bootstraps = + val bootstraps: Seq[TransportServerBootstrap] = if (useSasl) { Seq(new SaslServerBootstrap(transportConf, securityManager)) } else { Nil } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(port, bootstraps.asJava) + } + + /** Clean up all shuffle files associated with an application that has exited. */ + def applicationRemoved(appId: String): Unit = { + blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */) } def stop() { @@ -88,6 +99,13 @@ object ExternalShuffleService extends Logging { private val barrier = new CountDownLatch(1) def main(args: Array[String]): Unit = { + main(args, (conf: SparkConf, sm: SecurityManager) => new ExternalShuffleService(conf, sm)) + } + + /** A helper main method that allows the caller to call this with a custom shuffle service. 
*/ + private[spark] def main( + args: Array[String], + newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = { val sparkConf = new SparkConf Utils.loadDefaultSparkProperties(sparkConf) val securityManager = new SecurityManager(sparkConf) @@ -95,7 +113,7 @@ object ExternalShuffleService extends Logging { // we override this value since this service is started from the command line // and we assume the user really wants it to be running sparkConf.set("spark.shuffle.service.enabled", "true") - server = new ExternalShuffleService(sparkConf, securityManager) + server = newShuffleService(sparkConf, securityManager) server.start() installShutdownHook() diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 2954f932b4f4..ccffb3665298 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -76,12 +76,13 @@ private[deploy] object JsonProtocol { } def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 0550f00a172a..83ccaadfe744 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master @@ -41,8 +40,8 @@ class LocalSparkCluster( extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterRpcEnvs = ArrayBuffer[RpcEnv]() + private val workerRpcEnvs = ArrayBuffer[RpcEnv]() // exposed for testing var masterWebUIPort = -1 @@ -55,18 +54,17 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, webUiPort, _) = - Master.startSystemAndActor(localHostname, 0, 0, _conf) + val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) masterWebUIPort = webUiPort - masterActorSystems += masterSystem - val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort + masterRpcEnvs += rpcEnv + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to 
numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum), _conf) - workerActorSystems += workerSystem + workerRpcEnvs += workerEnv } masters @@ -75,13 +73,9 @@ class LocalSparkCluster( def stop() { logInfo("Shutting down local Spark cluster.") // Stop the workers before the master so they don't get upset that it disconnected - // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! - // This is unfortunate, but for now we just comment it out. - workerActorSystems.foreach(_.shutdown()) - // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) - // masterActorSystems.foreach(_.awaitTermination()) - masterActorSystems.clear() - workerActorSystems.clear() + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnvs.foreach(_.shutdown()) + masterRpcEnvs.clear() + workerRpcEnvs.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index c2ed43a5397d..d85327603f64 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -21,9 +21,10 @@ import java.net.URI import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try +import org.apache.spark.SparkUserAppException import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} @@ -46,7 +47,20 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such val gatewayServer = new py4j.GatewayServer(null, 0) - gatewayServer.start() + val thread = new Thread(new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions { + gatewayServer.start() + } + }) + thread.setName("py4j-gateway-init") + thread.setDaemon(true) + thread.start() + + // Wait until the gateway server has started, so that we know which port is it bound to. + // `gatewayServer.start()` will start a new thread and run the server code there, after + // initializing the socket, so the thread started above will end as soon as the server is + // ready to serve connections. 
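The comment above is the key to this start-up dance: `gatewayServer.start()` returns to the wrapper thread only once the listening socket is initialized, so joining that thread is enough to know the port before exporting `PYSPARK_GATEWAY_PORT`. A small self-contained sketch of the same start-then-join hand-off, with a plain `ServerSocket` standing in for the Py4J gateway:

```scala
import java.net.ServerSocket

object BackgroundInitSketch {
  def main(args: Array[String]): Unit = {
    @volatile var server: ServerSocket = null

    // Run the blocking bind on a named daemon thread, as the runner does for
    // gatewayServer.start(); the thread finishes as soon as the socket is ready.
    val init = new Thread(new Runnable {
      override def run(): Unit = {
        server = new ServerSocket(0) // bind to an ephemeral port
      }
    })
    init.setName("gateway-init")
    init.setDaemon(true)
    init.start()

    // join() returns once the bind has happened, so the port is known and could
    // now be exported to a child process environment.
    init.join()
    println(s"bound to port ${server.getLocalPort}")
    server.close()
  }
}
```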
+ thread.join() // Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the // python directories in SPARK_HOME (if set), and any files in the pyFiles argument @@ -57,18 +71,25 @@ object PythonRunner { val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) // Launch Python process - val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) + val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize - val process = builder.start() + try { + val process = builder.start() - new RedirectThread(process.getInputStream, System.out, "redirect output").start() + new RedirectThread(process.getInputStream, System.out, "redirect output").start() - System.exit(process.waitFor()) + val exitCode = process.waitFor() + if (exitCode != 0) { + throw new SparkUserAppException(exitCode) + } + } finally { + gatewayServer.shutdown() + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala new file mode 100644 index 000000000000..4b28866dcaa7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.jar.JarFile +import java.util.logging.Level +import java.util.zip.{ZipEntry, ZipOutputStream} + +import scala.collection.JavaConverters._ + +import com.google.common.io.{ByteStreams, Files} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.api.r.RUtils +import org.apache.spark.util.{RedirectThread, Utils} + +private[deploy] object RPackageUtils extends Logging { + + /** The key in the MANIFEST.mf that we look for, in case a jar contains R code. */ + private final val hasRPackage = "Spark-HasRPackage" + + /** Base of the shell command used in order to install R packages. */ + private final val baseInstallCmd = Seq("R", "CMD", "INSTALL", "-l") + + /** R source code should exist under R/pkg in a jar. */ + private final val RJarEntries = "R/pkg" + + /** Documentation on how the R source file layout should be in the jar. */ + private[deploy] final val RJarDoc = + s"""In order for Spark to build R packages that are parts of Spark Packages, there are a few + |requirements. 
The R source code must be shipped in a jar, with additional Java/Scala + |classes. The jar must be in the following format: + | 1- The Manifest (META-INF/MANIFEST.mf) must contain the key-value: $hasRPackage: true + | 2- The standard R package layout must be preserved under R/pkg/ inside the jar. More + | information on the standard R package layout can be found in: + | http://cran.r-project.org/doc/contrib/Leisch-CreatingPackages.pdf + | An example layout is given below. After running `jar tf $$JAR_FILE | sort`: + | + |META-INF/MANIFEST.MF + |R/ + |R/pkg/ + |R/pkg/DESCRIPTION + |R/pkg/NAMESPACE + |R/pkg/R/ + |R/pkg/R/myRcode.R + |org/ + |org/apache/ + |... + """.stripMargin.trim + + /** Internal method for logging. We log to a printStream in tests, for debugging purposes. */ + private def print( + msg: String, + printStream: PrintStream, + level: Level = Level.FINE, + e: Throwable = null): Unit = { + if (printStream != null) { + // scalastyle:off println + printStream.println(msg) + // scalastyle:on println + if (e != null) { + e.printStackTrace(printStream) + } + } else { + level match { + case Level.INFO => logInfo(msg) + case Level.WARNING => logWarning(msg) + case Level.SEVERE => logError(msg, e) + case _ => logDebug(msg) + } + } + } + + /** + * Checks the manifest of the Jar whether there is any R source code bundled with it. + * Exposed for testing. + */ + private[deploy] def checkManifestForR(jar: JarFile): Boolean = { + val manifest = jar.getManifest.getMainAttributes + manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true" + } + + /** + * Runs the standard R package installation code to build the R package from source. + * Multiple runs don't cause problems. + */ + private def rPackageBuilder(dir: File, printStream: PrintStream, verbose: Boolean): Boolean = { + // this code should be always running on the driver. + val pathToSparkR = RUtils.localSparkRPackagePath.getOrElse( + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")) + val pathToPkg = Seq(dir, "R", "pkg").mkString(File.separator) + val installCmd = baseInstallCmd ++ Seq(pathToSparkR, pathToPkg) + if (verbose) { + print(s"Building R package with the command: $installCmd", printStream) + } + try { + val builder = new ProcessBuilder(installCmd.asJava) + builder.redirectErrorStream(true) + val env = builder.environment() + env.clear() + val process = builder.start() + new RedirectThread(process.getInputStream, printStream, "redirect R packaging").start() + process.waitFor() == 0 + } catch { + case e: Throwable => + print("Failed to build R package.", printStream, Level.SEVERE, e) + false + } + } + + /** + * Extracts the files under /R in the jar to a temporary directory for building. 
+ */ + private def extractRFolder(jar: JarFile, printStream: PrintStream, verbose: Boolean): File = { + val tempDir = Utils.createTempDir(null) + val jarEntries = jar.entries() + while (jarEntries.hasMoreElements) { + val entry = jarEntries.nextElement() + val entryRIndex = entry.getName.indexOf(RJarEntries) + if (entryRIndex > -1) { + val entryPath = entry.getName.substring(entryRIndex) + if (entry.isDirectory) { + val dir = new File(tempDir, entryPath) + if (verbose) { + print(s"Creating directory: $dir", printStream) + } + dir.mkdirs + } else { + val inStream = jar.getInputStream(entry) + val outPath = new File(tempDir, entryPath) + Files.createParentDirs(outPath) + val outStream = new FileOutputStream(outPath) + if (verbose) { + print(s"Extracting $entry to $outPath", printStream) + } + Utils.copyStream(inStream, outStream, closeStreams = true) + } + } + } + tempDir + } + + /** + * Extracts the files under /R in the jar to a temporary directory for building. + */ + private[deploy] def checkAndBuildRPackage( + jars: String, + printStream: PrintStream = null, + verbose: Boolean = false): Unit = { + jars.split(",").foreach { jarPath => + val file = new File(Utils.resolveURI(jarPath)) + if (file.exists()) { + val jar = new JarFile(file) + if (checkManifestForR(jar)) { + print(s"$file contains R source code. Now installing package.", printStream, Level.INFO) + val rSource = extractRFolder(jar, printStream, verbose) + try { + if (!rPackageBuilder(rSource, printStream, verbose)) { + print(s"ERROR: Failed to build R package in $file.", printStream) + print(RJarDoc, printStream) + } + } finally { + rSource.delete() // clean up + } + } else { + if (verbose) { + print(s"$file doesn't contain R source code, skipping...", printStream) + } + } + } else { + print(s"WARN: $file resolved as dependency, but not found.", printStream, Level.WARNING) + } + } + } + + private def listFilesRecursively(dir: File, excludePatterns: Seq[String]): Set[File] = { + if (!dir.exists()) { + Set.empty[File] + } else { + if (dir.isDirectory) { + val subDir = dir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + !excludePatterns.map(name.contains).reduce(_ || _) // exclude files with given pattern + } + }) + subDir.flatMap(listFilesRecursively(_, excludePatterns)).toSet + } else { + Set(dir) + } + } + } + + /** Zips all the libraries found with SparkR in the R/lib directory for distribution with Yarn. */ + private[deploy] def zipRLibraries(dir: File, name: String): File = { + val filesToBundle = listFilesRecursively(dir, Seq(".zip")) + // create a zip file from scratch, do not append to existing file. 
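The whole build flow in this file is gated by the manifest check near the top: `checkAndBuildRPackage` only extracts and installs when `checkManifestForR` finds `Spark-HasRPackage: true`. A self-contained sketch of writing and reading back such an attribute, assuming nothing beyond the attribute name shown in the diff (the temp-jar handling is illustrative):

```scala
import java.io.{File, FileOutputStream}
import java.util.jar.{Attributes, JarFile, JarOutputStream, Manifest}

object ManifestFlagSketch {
  private val HasRPackage = "Spark-HasRPackage"

  // Same shape as the check in the diff: attribute present and equal to "true".
  def hasRCode(jar: JarFile): Boolean = {
    val attrs = jar.getManifest.getMainAttributes
    attrs.getValue(HasRPackage) != null && attrs.getValue(HasRPackage).trim == "true"
  }

  def main(args: Array[String]): Unit = {
    // Build a tiny jar whose manifest carries the flag, then read it back.
    val manifest = new Manifest()
    manifest.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
    manifest.getMainAttributes.putValue(HasRPackage, "true")

    val jarPath = File.createTempFile("with-r-code", ".jar")
    val out = new JarOutputStream(new FileOutputStream(jarPath), manifest)
    out.close()

    val jar = new JarFile(jarPath)
    println(s"hasRCode = ${hasRCode(jar)}") // true
    jar.close()
    jarPath.delete()
  }
}
```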
+ val zipFile = new File(dir, name) + zipFile.delete() + val zipOutputStream = new ZipOutputStream(new FileOutputStream(zipFile, false)) + try { + filesToBundle.foreach { file => + // get the relative paths for proper naming in the zip file + val relPath = file.getAbsolutePath.replaceFirst(dir.getAbsolutePath, "") + val fis = new FileInputStream(file) + val zipEntry = new ZipEntry(relPath) + zipOutputStream.putNextEntry(zipEntry) + ByteStreams.copy(fis, zipOutputStream) + zipOutputStream.closeEntry() + fis.close() + } + } finally { + zipOutputStream.close() + } + zipFile + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index e99779f29978..05b954ce3699 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -20,11 +20,11 @@ package org.apache.spark.deploy import java.io._ import java.util.concurrent.{Semaphore, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path -import org.apache.spark.api.r.RBackend +import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.util.RedirectThread /** @@ -68,12 +68,13 @@ object RRunner { if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { // Launch R val returnCode = try { - val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) - val sparkHome = System.getenv("SPARK_HOME") + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir) env.put("R_PROFILE_USER", - Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() @@ -85,7 +86,9 @@ object RRunner { } System.exit(returnCode) } else { + // scalastyle:off println System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index b8d399354022..8d5e716e6aea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} import org.apache.curator.retry.ExponentialBackoffRetry @@ -57,7 +57,7 @@ private[spark] object SparkCuratorUtil extends Logging { def deleteRecursive(zk: CuratorFramework, path: String) { if (zk.checkExists().forPath(path) != null) { - for (child <- zk.getChildren.forPath(path)) { + for (child <- zk.getChildren.forPath(path).asScala) { zk.delete().forPath(path + "/" + child) } zk.delete().forPath(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 7fa75ac8c2b5..a0b7365df900 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -22,9 +22,10 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.util.{Arrays, Comparator} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration @@ -33,6 +34,8 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID} import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.annotation.DeveloperApi @@ -68,12 +71,12 @@ class SparkHadoopUtil extends Logging { } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { - for (token <- source.getTokens()) { + for (token <- source.getTokens.asScala) { dest.addToken(token) } } - @Deprecated + @deprecated("use newConfiguration with SparkConf argument", "1.2.0") def newConfiguration(): Configuration = newConfiguration(null) /** @@ -172,13 +175,13 @@ class SparkHadoopUtil extends Logging { } private def getFileSystemThreadStatistics(): Seq[AnyRef] = { - val stats = FileSystem.getAllStatistics() - stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + FileSystem.getAllStatistics.asScala.map( + Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") statisticsDataClass.getDeclaredMethod(methodName) } @@ -189,10 +192,26 @@ class SparkHadoopUtil extends Logging { * while it's interface in Hadoop 2.+. */ def getConfigurationFromJobContext(context: JobContext): Configuration = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getConfiguration") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[Configuration] } + /** + * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly + * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes + * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ + * while it's interface in Hadoop 2.+. + */ + def getTaskAttemptIDFromTaskAttemptContext( + context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + // scalastyle:off jobconfig + val method = context.getClass.getMethod("getTaskAttemptID") + // scalastyle:on jobconfig + method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] + } + /** * Get [[FileStatus]] objects for all leaf children (files) under the given base path. 
If the * given path points to a file, return a single-element collection containing [[FileStatus]] of @@ -238,6 +257,14 @@ class SparkHadoopUtil extends Logging { }.getOrElse(Seq.empty[Path]) } + def globPathIfNecessary(pattern: Path): Seq[Path] = { + if (pattern.toString.exists("{}[]*?\\".toSet.contains)) { + globPath(pattern) + } else { + Seq(pattern) + } + } + /** * Lists all the files in a directory with the specified prefix, and does not end with the * given suffix. The returned {{FileStatus}} instances are sorted by the modification times of @@ -248,19 +275,25 @@ class SparkHadoopUtil extends Logging { dir: Path, prefix: String, exclusionSuffix: String): Array[FileStatus] = { - val fileStatuses = remoteFs.listStatus(dir, - new PathFilter { - override def accept(path: Path): Boolean = { - val name = path.getName - name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + try { + val fileStatuses = remoteFs.listStatus(dir, + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + } + }) + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { + Longs.compare(o1.getModificationTime, o2.getModificationTime) } }) - Arrays.sort(fileStatuses, new Comparator[FileStatus] { - override def compare(o1: FileStatus, o2: FileStatus): Int = { - Longs.compare(o1.getModificationTime, o2.getModificationTime) - } - }) - fileStatuses + fileStatuses + } catch { + case NonFatal(e) => + logWarning("Error while attempting to list files from application staging dir", e) + Array.empty + } } /** @@ -277,12 +310,13 @@ class SparkHadoopUtil extends Logging { val renewalInterval = sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) - credentials.getAllTokens.filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + credentials.getAllTokens.asScala + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .map { t => - val identifier = new DelegationTokenIdentifier() - identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) - (identifier.getIssueDate + fraction * renewalInterval).toLong - now - }.foldLeft(0L)(math.max) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + (identifier.getIssueDate + fraction * renewalInterval).toLong - now + }.foldLeft(0L)(math.max) } @@ -334,6 +368,19 @@ class SparkHadoopUtil extends Logging { * Stop the thread that does the delegation token updates. */ private[spark] def stopExecutorDelegationTokenRenewer() {} + + /** + * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. + * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. 
+ */ + private[spark] def getConfBypassingFSCache( + hadoopConf: Configuration, + scheme: String): Configuration = { + val newConf = new Configuration(hadoopConf) + val confKey = s"fs.${scheme}.impl.disable.cache" + newConf.setBoolean(confKey, true) + newConf + } } object SparkHadoopUtil { @@ -343,7 +390,7 @@ object SparkHadoopUtil { System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") .newInstance() .asInstanceOf[SparkHadoopUtil] } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index abf222757a95..ad92f5635af3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.security.UserGroupInformation import org.apache.ivy.Ivy @@ -37,7 +38,9 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} -import org.apache.spark.SPARK_VERSION + +import org.apache.spark.{SparkUserAppException, SPARK_VERSION} +import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -79,9 +82,11 @@ object SparkSubmit { private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // scalastyle:off println // Exposed for testing private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err @@ -102,11 +107,14 @@ object SparkSubmit { printStream.println("Type --help for more information.") exitFn(0) } + // scalastyle:on println def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { + // scalastyle:off println printStream.println(appArgs) + // scalastyle:on println } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs) @@ -160,7 +168,9 @@ object SparkSubmit { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { + // scalastyle:off println printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println exitFn(1) } else { throw e @@ -178,7 +188,9 @@ object SparkSubmit { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { + // scalastyle:off println printStream.println("Running Spark using the REST application submission protocol.") + // scalastyle:on println doRunMain() } catch { // Fail over to use the legacy submission gateway @@ -254,29 +266,38 @@ object SparkSubmit { } } + // Update args.deployMode if it is null. 
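The getConfBypassingFSCache helper above relies on Hadoop consulting fs.<scheme>.impl.disable.cache before returning a cached FileSystem instance; with the flag set, FileSystem.get builds a fresh client instead of reusing one that may hold a stale delegation token. A hedged usage sketch (object and method names are illustrative):

import java.net.URI
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem

object FreshFsSketch {
  // Copy the configuration and disable the FileSystem cache for one scheme
  // only, so other schemes in the same JVM keep using cached clients.
  def uncachedFileSystem(uri: URI, base: Configuration): FileSystem = {
    val conf = new Configuration(base)
    conf.setBoolean(s"fs.${uri.getScheme}.impl.disable.cache", true)
    FileSystem.get(uri, conf)
  }
}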
It will be passed down as a Spark property later. + (args.deployMode, deployMode) match { + case (null, CLIENT) => args.deployMode = "client" + case (null, CLUSTER) => args.deployMode = "cluster" + case _ => + } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code - val resolvedMavenCoordinates = - SparkSubmitUtils.resolveMavenCoordinates( - args.packages, Option(args.repositories), Option(args.ivyRepoPath)) - if (!resolvedMavenCoordinates.trim.isEmpty) { - if (args.jars == null || args.jars.trim.isEmpty) { - args.jars = resolvedMavenCoordinates + val exclusions: Seq[String] = + if (!StringUtils.isBlank(args.packagesExclusions)) { + args.packagesExclusions.split(",") } else { - args.jars += s",$resolvedMavenCoordinates" + Nil } + val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, + Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions) + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) if (args.isPython) { - if (args.pyFiles == null || args.pyFiles.trim.isEmpty) { - args.pyFiles = resolvedMavenCoordinates - } else { - args.pyFiles += s",$resolvedMavenCoordinates" - } + args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) } } + // install any R packages that may have been passed through --jars or --packages. + // Spark Packages may contain R source code inside the jar. + if (args.isR && !StringUtils.isBlank(args.jars)) { + RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) + } + // Require all python files to be local, so we can add them to the PYTHONPATH // In YARN cluster mode, python files are distributed as regular files, which can be non-local if (args.isPython && !isYarnCluster) { @@ -298,8 +319,8 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + case (MESOS, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + @@ -339,6 +360,24 @@ object SparkSubmit { } } + // In YARN mode for an R app, add the SparkR package archive to archives + // that can be distributed with the job + if (args.isR && clusterManager == YARN) { + val rPackagePath = RUtils.localSparkRPackagePath + if (rPackagePath.isEmpty) { + printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + } + val rPackageFile = + RPackageUtils.zipRLibraries(new File(rPackagePath.get), SPARKR_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + } + val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + + // Assigns a symbol link name "sparkr" to the shipped package. 
+ args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + } + // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -367,6 +406,8 @@ object SparkSubmit { // All cluster managers OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.submit.deployMode"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), @@ -381,7 +422,8 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.instances"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -392,7 +434,6 @@ object SparkSubmit { OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), @@ -473,8 +514,14 @@ object SparkSubmit { } // Let YARN know it's a pyspark app, so it distributes needed libraries. 
- if (clusterManager == YARN && args.isPython) { - sysProps.put("spark.yarn.isPython", "true") + if (clusterManager == YARN) { + if (args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + if (args.principal != null) { + require(args.keytab != null, "Keytab must be specified when the principal is specified") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class @@ -504,7 +551,15 @@ object SparkSubmit { if (isMesosCluster) { assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" - childArgs += (args.primaryResource, args.mainClass) + if (args.isPython) { + // Second argument is main class + childArgs += (args.primaryResource, "") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } + } else { + childArgs += (args.primaryResource, args.mainClass) + } if (args.childArgs != null) { childArgs ++= args.childArgs } @@ -558,6 +613,7 @@ object SparkSubmit { sysProps: Map[String, String], childMainClass: String, verbose: Boolean): Unit = { + // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -565,6 +621,7 @@ object SparkSubmit { printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } + // scalastyle:on println val loader = if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { @@ -587,13 +644,15 @@ object SparkSubmit { var mainClass: Class[_] = null try { - mainClass = Class.forName(childMainClass, true, loader) + mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { + // scalastyle:off println printStream.println(s"Failed to load main class $childMainClass.") printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -621,7 +680,13 @@ object SparkSubmit { mainMethod.invoke(null, childArgs.toArray) } catch { case t: Throwable => - throw findCause(t) + findCause(t) match { + case SparkUserAppException(exitCode) => + System.exit(exitCode) + + case t: Throwable => + throw t + } } } @@ -691,7 +756,7 @@ object SparkSubmit { * no files, into a single comma-separated string.
*/ private def mergeFileLists(lists: String*): String = { - val merged = lists.filter(_ != null) + val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") if (merged == "") null else merged @@ -756,6 +821,22 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + // scalastyle:off println + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) localM2.setRoot(m2Path.toURI.toString) @@ -786,20 +867,6 @@ private[spark] object SparkSubmitUtils { sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -829,7 +896,9 @@ private[spark] object SparkSubmitUtils { val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) dd.addDependencyConfiguration(ivyConfName, ivyConfName) + // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println md.addDependency(dd) } } @@ -889,16 +958,18 @@ private[spark] object SparkSubmitUtils { // are supplied to spark-submit val alternateIvyCache = ivyPath.getOrElse("") val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { + if (alternateIvyCache == null || alternateIvyCache.trim.isEmpty) { new File(ivySettings.getDefaultIvyUserDir, "jars") } else { ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } + // scalastyle:off println printStream.println( s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers @@ -922,17 +993,24 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. 
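The mergeFileLists change above treats empty strings the same as null, which is what allows the resolved --packages coordinates to be appended to --jars and --py-files whether or not the user supplied those flags. A dependency-free sketch of the same merge, using a plain blank check in place of commons-lang3 StringUtils (names are illustrative):

object FileListSketch {
  private def isBlank(s: String): Boolean = s == null || s.trim.isEmpty

  // Merge comma-separated lists, skipping null or blank inputs; return null
  // when nothing is left so existing Option(...) call sites keep working.
  def mergeFileLists(lists: String*): String = {
    val merged = lists.filterNot(isBlank).flatMap(_.split(",")).mkString(",")
    if (merged.isEmpty) null else merged
  }
}

// mergeFileLists(null, "a.jar", "", "b.jar,c.jar") returns "a.jar,b.jar,c.jar"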
In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + md.setDefaultConf(ivyConfName) // Add exclusion rules for Spark and Scala Library addExclusionRules(ivySettings, ivyConfName, md) // add all supplied maven artifacts as dependencies addDependenciesToIvy(md, artifacts, ivyConfName) - exclusions.foreach { e => md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName)) } - // resolve dependencies val rr: ResolveReport = ivy.resolve(md, resolveOptions) if (rr.hasError) { @@ -950,7 +1028,7 @@ private[spark] object SparkSubmitUtils { } } - private def createExclusion( + private[deploy] def createExclusion( coords: String, ivySettings: IvySettings, ivyConfName: String): ExcludeRule = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b7429a901e16..18a1c52ae53f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -23,7 +23,7 @@ import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null @@ -79,6 +80,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() + // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => @@ -86,12 +88,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } + // scalastyle:on println defaultProperties } // Set parameters from command line arguments try { - parse(args.toList) + parse(args.asJava) } catch { case e: IllegalArgumentException => SparkSubmit.printErrorAndExit(e.getMessage()) @@ -162,6 +165,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) @@ -169,6 +173,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull + packagesExclusions = Option(packagesExclusions) + .orElse(sparkProperties.get("spark.jars.excludes")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) @@ -296,6 +303,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | childArgs [${childArgs.mkString(" ")}] | jars $jars | packages $packages + | packagesExclusions $packagesExclusions | repositories $repositories | verbose $verbose | @@ -388,6 +396,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case PACKAGES => packages = value + case PACKAGES_EXCLUDE => + packagesExclusions = value + case REPOSITORIES => repositories = value @@ -447,10 +458,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } override protected def handleExtraArgs(extra: JList[String]): Unit = { - childArgs ++= extra + childArgs ++= extra.asScala } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { + // scalastyle:off println val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) @@ -461,8 +473,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) outStream.println(command) + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB outStream.println( - """ + s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -477,6 +490,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | maven repo, then maven central and any additional remote | repositories given by --repositories. The format for the | coordinates should be groupId:artifactId:version. 
+ | --exclude-packages Comma-separated list of groupId:artifactId, to exclude while + | resolving the dependencies provided in --packages to avoid + | dependency conflicts. | --repositories Comma-separated list of additional remote repositories to | search for the maven coordinates given with --packages. | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place @@ -488,7 +504,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M). + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: ${mem_mb}M). | --driver-java-options Extra Java options to pass to the driver. | --driver-library-path Extra library path entries to pass to the driver. | --driver-class-path Extra class path entries to pass to the driver. Note that @@ -539,6 +555,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("CLI options:") outStream.println(getSqlShellOptions()) } + // scalastyle:on println SparkSubmit.exitFn(exitCode) } @@ -570,7 +587,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setSecurityManager(sm) try { - Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + Utils.classForName(mainClass).getMethod("main", classOf[Array[String]]) .invoke(null, Array(HELP)) } catch { case e: InvocationTargetException => @@ -594,5 +611,4 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setErr(currentErr) } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 43c8a934c311..25ea6925434a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. 
*/ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) + private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) - private val REGISTRATION_TIMEOUT = 20.seconds + private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var masterAddress: Address = null - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null private var appId: String = null - private var registered = false - private var activeMasterUrl: String = null + @volatile private var registered = false + + private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint + with Logging { + + private var master: Option[RpcEndpointRef] = None + // To avoid calling listener.disconnected() multiple times + private var alreadyDisconnected = false + @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times + @volatile private var registerMasterFutures: Array[JFuture[_]] = null + @volatile private var registrationRetryTimer: JScheduledFuture[_] = null - private class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + // A scheduled executor for scheduling the registration actions + private val registrationRetryThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") + + override def onStart(): Unit = { try { - registerWithMaster() + registerWithMaster(1) } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } - def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterApplication(appDescription) + /** + * Register with all masters asynchronously and returns an array `Future`s for cancellation. 
+ */ + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + for (masterAddress <- masterRpcAddresses) yield { + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = try { + if (registered) { + return + } + logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + val masterRef = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterRef.send(RegisterApplication(appDescription, self)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + }) } } - def registerWithMaster() { - tryRegisterAllMasters() - import context.dispatcher - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + /** + * Register with all masters asynchronously. It will call `registerWithMaster` every + * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times. + * Once we connect to a master successfully, all scheduling work and Futures will be cancelled. + * + * nthRetry means this is the nth attempt to register with master. + */ + private def registerWithMaster(nthRetry: Int) { + registerMasterFutures = tryRegisterAllMasters() + registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { Utils.tryOrExit { - retries += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() + } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - tryRegisterAllMasters() + registerMasterFutures.foreach(_.cancel(true)) + registerWithMaster(nthRetry + 1) } } } - } + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) } - def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. - activeMasterUrl = url - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => logWarning(s"Drop $message because has not yet connected to master") + } } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = { + masterRpcAddresses.contains(remoteAddress) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredApplication(appId_, masterUrl) => + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId_, masterRef) => + // FIXME How to handle the following cases? + // 1. A master receives multiple registrations and sends back multiple + // RegisteredApplications due to an unstable network. + // 2. Receive multiple RegisteredApplication from different masters because the master is + // changing. 
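The registration flow above is easier to see in isolation: one blocking attempt per master runs on a pool wide enough to contact all masters at once, while a single scheduled thread cancels outstanding attempts and either retries or gives up. The following is a simplified sketch under those assumptions, using a one-shot recursive schedule instead of the fixed-rate timer in the patch; the register and onGiveUp callbacks are placeholders for the real RPC and listener calls.

import java.util.concurrent.{Executors, Future => JFuture, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

class RegistrationSketch(masters: Seq[String],
                         register: String => Unit,
                         onGiveUp: () => Unit,
                         maxRetries: Int = 3,
                         timeoutSeconds: Long = 20) {
  private val registered = new AtomicBoolean(false)
  // Wide enough to contact every master concurrently, since each attempt blocks.
  private val pool = Executors.newFixedThreadPool(math.max(1, masters.size))
  private val retryScheduler = Executors.newSingleThreadScheduledExecutor()

  def markRegistered(): Unit = registered.set(true)

  def start(): Unit = attempt(1)

  private def attempt(nthTry: Int): Unit = {
    val futures: Seq[JFuture[_]] = masters.map { master =>
      pool.submit(new Runnable {
        override def run(): Unit = if (!registered.get) register(master)
      })
    }
    retryScheduler.schedule(new Runnable {
      override def run(): Unit = {
        // Cancel whatever is still pending, then stop, give up, or retry.
        futures.foreach(_.cancel(true))
        if (registered.get) {
          pool.shutdownNow()
          retryScheduler.shutdownNow()
        } else if (nthTry >= maxRetries) {
          onGiveUp()
        } else {
          attempt(nthTry + 1)
        }
      }
    }, timeoutSeconds, TimeUnit.SECONDS)
  }
}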
appId = appId_ registered = true - changeMaster(masterUrl) + master = Some(masterRef) listener.connected(appId) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not + // guaranteed), `ExecutorStateChanged` may be sent to a dead master. + sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -142,24 +184,48 @@ private[spark] class AppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + master = Some(masterRef) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) + masterRef.send(MasterChangeAcknowledged(appId)) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopAppClient => + markDead("Application has been stopped.") + sendToMaster(UnregisterApplication(appId)) + context.reply(true) + stop() - case DisassociatedEvent(_, address, _) if address == masterAddress => + case r: RequestExecutors => + master match { + case Some(m) => context.reply(m.askWithRetry[Boolean](r)) + case None => + logWarning("Attempted to request executors before registering with Master.") + context.reply(false) + } + + case k: KillExecutors => + master match { + case Some(m) => context.reply(m.askWithRetry[Boolean](k)) + case None => + logWarning("Attempted to kill executors before registering with Master.") + context.reply(false) + } + } + + override def onDisconnected(address: RpcAddress): Unit = { + if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = { + if (isPossibleMaster(address)) { logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - master ! UnregisterApplication(appId) - sender ! true - context.stop(self) + } } /** @@ -179,28 +245,61 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop(): Unit = { + if (registrationRetryTimer != null) { + registrationRetryTimer.cancel(true) + } + registrationRetryThread.shutdownNow() + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() } } def start() { - // Just launch an actor; it will call back into the listener. - actor = actorSystem.actorOf(Props(new ClientActor)) + // Just launch an rpcEndpoint; it will call back into the listener. 
+ endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) } def stop() { - if (actor != null) { + if (endpoint != null) { try { - val timeout = RpcUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + val timeout = RpcUtils.askRpcTimeout(conf) + timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - actor = null + endpoint = null + } + } + + /** + * Request executors from the Master by specifying the total number desired, + * including existing pending and running executors. + * + * @return whether the request is acknowledged. + */ + def requestTotalExecutors(requestedTotal: Int): Boolean = { + if (endpoint != null && appId != null) { + endpoint.askWithRetry[Boolean](RequestExecutors(appId, requestedTotal)) + } else { + logWarning("Attempted to request executors before driver fully initialized.") + false } } + + /** + * Kill the given list of executors through the Master. + * @return whether the kill request is acknowledged. + */ + def killExecutors(executorIds: Seq[String]): Boolean = { + if (endpoint != null && appId != null) { + endpoint.askWithRetry[Boolean](KillExecutors(appId, executorIds)) + } else { + logWarning("Attempted to kill executors before driver fully initialized.") + false + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 40835b955058..1c79089303e3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -46,13 +47,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, - conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala index c5ac45c6730d..a98b1fa8f83a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.client private[spark] object TestExecutor { def main(args: Array[String]) { + // scalastyle:off println println("Hello world!") + // scalastyle:on println while (true) { Thread.sleep(1000) } diff --git 
a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index db383b9823d3..a5755eac3639 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.util.UUID import java.util.concurrent.{ExecutorService, Executors, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -73,7 +74,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that // is already known. - private var lastModifiedTime = -1L + private var lastScanTime = -1L // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -83,12 +84,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - // Constants used to parse Spark 1.0.0 log directories. - private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -132,11 +127,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Disable the background thread during tests. if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. - pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. 
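The switch from scheduleAtFixedRate to scheduleWithFixedDelay above matters when a single scan can outlive the interval: the delay is measured from the end of the previous run, so a slow scan does not cause runs to queue up back to back. A tiny sketch of that scheduling call (names and interval are illustrative):

import java.util.concurrent.{Executors, TimeUnit}

object ScanSchedulingSketch {
  private val pool = Executors.newSingleThreadScheduledExecutor()

  // Each new scan starts a fixed delay after the previous one finishes, so a
  // long-running scan never produces overlapping or immediately queued runs.
  def start(scan: () => Unit, intervalSeconds: Long): Unit = {
    pool.scheduleWithFixedDelay(new Runnable {
      override def run(): Unit = scan()
    }, 0, intervalSeconds, TimeUnit.SECONDS)
  }
}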
- pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } } } @@ -146,7 +141,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { try { applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).map { attempt => + appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => val replayBus = new ReplayListenerBus() val ui = { val conf = this.conf.clone() @@ -155,20 +150,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() replayBus.addListener(appListener) val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - - appInfo.foreach { app => ui.setAppName(s"${app.name} ($appId)") } - - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui + appInfo.map { info => + ui.setAppName(s"${info.name} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, + appListener.viewAcls.getOrElse("")) + ui + } } } } catch { @@ -185,15 +180,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { + val newLastScanTime = getNewLastScanTime() val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - var newLastModifiedTime = lastModifiedTime val logInfos: Seq[FileStatus] = statusList .filter { entry => try { getModificationTime(entry).map { time => - newLastModifiedTime = math.max(newLastModifiedTime, time) - time >= lastModifiedTime + time >= lastScanTime }.getOrElse(false) } catch { case e: AccessControlException => @@ -210,18 +204,49 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) mod1 >= mod2 } - logInfos.sliding(20, 20).foreach { batch => - replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(batch) - }) - } + logInfos.grouped(20) + .map { batch => + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(batch) + }) + } + .foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } + } - lastModifiedTime = newLastModifiedTime + lastScanTime = newLastScanTime } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } + private def getNewLastScanTime(): Long = { + val fileName = "." + UUID.randomUUID().toString + val path = new Path(logDir, fileName) + val fos = fs.create(path) + + try { + fos.close() + fs.getFileStatus(path).getModificationTime + } catch { + case e: Exception => + logError("Exception encountered when attempting to update last scan time", e) + lastScanTime + } finally { + fs.delete(path) + } + } + override def writeEventLogs( appId: String, attemptId: Option[String], @@ -278,9 +303,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { - val bus = new ReplayListenerBus() val newAttempts = logs.flatMap { fileStatus => try { + val bus = new ReplayListenerBus() val res = replay(fileStatus, bus) res match { case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") @@ -413,8 +438,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Comparison function that defines the sort order for application attempts within the same - * application. Order is: running attempts before complete attempts, running attempts sorted - * by start time, completed attempts sorted by end time. + * application. Order is: attempts are sorted by descending start time. + * Most recent attempt state matches with current state of the app. * * Normally applications should have a single running attempt; but failure to call sc.stop() * may cause multiple running attempts to show up. @@ -424,11 +449,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private def compareAttemptInfo( a1: FsApplicationAttemptInfo, a2: FsApplicationAttemptInfo): Boolean = { - if (a1.completed == a2.completed) { - if (a1.completed) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime - } else { - !a1.completed - } + a1.startTime >= a2.startTime } /** @@ -451,17 +472,23 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted) - appListener.appId.map { appId => - new FsApplicationAttemptInfo( + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. Some old versions of Spark generate logs without an app ID, so let + // logs generated by those versions go through. + if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + Some(new FsApplicationAttemptInfo( logPath.getName(), appListener.appName.getOrElse(NOT_STARTED), - appId, + appListener.appId.getOrElse(logPath.getName()), appListener.appAttemptId, appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted) + appCompleted)) + } else { + None } } finally { logInput.close() @@ -537,10 +564,34 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Returns whether the version of Spark that generated logs records app IDs. 
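The getNewLastScanTime helper above takes its timestamp from a throwaway file rather than from System.currentTimeMillis, presumably so the cutoff is measured against the same clock that stamps the event logs and clock skew between the history server and the log file system cannot hide fresh logs. A small sketch of the idea, with error handling trimmed and fallback standing in for the previous scan time:

import java.util.UUID
import org.apache.hadoop.fs.{FileSystem, Path}

object ScanClockSketch {
  // Create an empty marker file, read its modification time, then delete it;
  // the timestamp comes from the log file system's clock, not the caller's.
  def fileSystemNow(fs: FileSystem, dir: Path, fallback: Long): Long = {
    val marker = new Path(dir, "." + UUID.randomUUID().toString)
    try {
      fs.create(marker).close()
      fs.getFileStatus(marker).getModificationTime
    } catch {
      case _: Exception => fallback
    } finally {
      fs.delete(marker, false)
    }
  }
}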
App IDs were added + * in Spark 1.1. + */ + private def sparkVersionHasAppId(entry: FileStatus): Boolean = { + if (isLegacyLogDirectory(entry)) { + fs.listStatus(entry.getPath()) + .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } + .map { status => + val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) + version != "1.0" && version != "1.1" + } + .getOrElse(true) + } else { + true + } + } + } -private object FsHistoryProvider { +private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + // Constants used to parse Spark 1.0.0 log directories. + val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 10638afb7490..d4f327cc588f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, Applica UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -228,7 +228,7 @@ object HistoryServer extends Logging { val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) - val provider = Class.forName(providerName) + val provider = Utils.classForName(providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] @@ -238,7 +238,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Utils.addShutdownHook { () => server.stop() } + ShutdownHookManager.addShutdownHook { () => server.stop() } // Wait until the end of the world... 
or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 4692d22651c9..18265df9faa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin Utils.loadDefaultSparkProperties(conf, propertiesFile) private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( """ |Usage: HistoryServer [options] @@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin | spark.history.fs.updateInterval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) + // scalastyle:on println System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 1620e95bea21..b40d20f9f786 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,8 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +31,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { @@ -44,6 +42,11 @@ private[spark] class ApplicationInfo( @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ + // A cap on the number of executors this application can have at any given time. + // By default, this is infinite. Only after the first allocation request is issued by the + // application will this be set to a finite value. This is used for dynamic allocation. + @transient private[master] var executorLimit: Int = _ + @transient private var nextExecutorId: Int = _ init() @@ -61,6 +64,7 @@ private[spark] class ApplicationInfo( appSource = new ApplicationSource(this) nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] + executorLimit = Integer.MAX_VALUE } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -117,6 +121,12 @@ private[spark] class ApplicationInfo( state != ApplicationState.WAITING && state != ApplicationState.RUNNING } + /** + * Return the limit on the number of executors this application can have. + * For testing only. 
+ */ + private[deploy] def getExecutorLimit: Int = executorLimit + def duration: Long = { if (endTime != -1) { endTime - startTime diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index f459ed5b3a1a..aa379d4cd61e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -21,9 +21,8 @@ import java.io._ import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} import org.apache.spark.util.Utils @@ -32,11 +31,11 @@ import org.apache.spark.util.Utils * Files are deleted when applications and workers are removed. * * @param dir Directory to store files. Created if non-existent (but not recursively). - * @param serialization Used to serialize our objects. + * @param serializer Used to serialize our objects. */ private[master] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serializer: Serializer) extends PersistenceEngine with Logging { new File(dir).mkdir() @@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) + val fileOut = new FileOutputStream(file) + var out: SerializationStream = null Utils.tryWithSafeFinally { - out.write(serialized) + out = serializer.newInstance().serializeStream(fileOut) + out.writeObject(value) } { - out.close() + fileOut.close() + if (out != null) { + out.close() + } } } private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { - val fileData = new Array[Byte](file.length().asInstanceOf[Int]) - val dis = new DataInputStream(new FileInputStream(file)) + val fileIn = new FileInputStream(file) + var in: DeserializationStream = null try { - dis.readFully(fileData) + in = serializer.newInstance().deserializeStream(fileIn) + in.readObject[T]() } finally { - dis.close() + fileIn.close() + if (in != null) { + in.close() + } } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index cf77c86d760c..70f21fbe0de8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait LeaderElectionAgent { - val masterActor: LeaderElectable + val masterInstance: LeaderElectable def stop() {} // to avoid noops in implementations. } @@ -37,7 +37,7 @@ trait LeaderElectable { } /** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. 
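The reworked FileSystemPersistenceEngine above streams objects directly to and from files instead of materialising a byte array first. The same read/write shape is shown below with plain Java serialization standing in for Spark's pluggable Serializer, purely so the sketch stays self-contained; the names are illustrative.

import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}

object FilePersistenceSketch {
  // Stream the value straight into the file and always close the stream.
  def writeObjectToFile(file: File, value: Serializable): Unit = {
    val out = new ObjectOutputStream(new FileOutputStream(file))
    try out.writeObject(value) finally out.close()
  }

  // Stream the value back out of the file, closing the stream on all paths.
  def readObjectFromFile[T](file: File): T = {
    val in = new ObjectInputStream(new FileInputStream(file))
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}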
*/ -private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable) +private[spark] class MonarchyLeaderAgent(val masterInstance: LeaderElectable) extends LeaderElectionAgent { - masterActor.electedLeader() + masterInstance.electedLeader() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fccceb3ea528..26904d39a9be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,20 +21,15 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import akka.serialization.Serialization -import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -46,24 +41,26 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} -private[master] class Master( - host: String, - port: Int, +private[deploy] class Master( + override val rpcEnv: RpcEnv, + address: RpcAddress, webUiPort: Int, val securityMgr: SecurityManager, val conf: SparkConf) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends ThreadSafeRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) @@ -75,10 +72,10 @@ private[master] class Master( val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] - private val addressToWorker = new HashMap[Address, WorkerInfo] + private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] - private val actorToApp = new HashMap[ActorRef, ApplicationInfo] - private val addressToApp = new HashMap[Address, ApplicationInfo] + 
private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 private val appIdToUI = new HashMap[String, SparkUI] @@ -89,21 +86,22 @@ private[master] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(address.host, "Expected hostname") private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) private val masterSource = new MasterSource(this) - private val webUi = new MasterWebUI(this, webUiPort) + // After onStart, webUi will be set + private var webUi: MasterWebUI = null private val masterPublicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host + if (envVar != null) envVar else address.host } - private val masterUrl = "spark://" + host + ":" + port + private val masterUrl = address.toSparkURL private var masterWebUiUrl: String = _ private var state = RecoveryState.STANDBY @@ -112,7 +110,9 @@ private[master] class Master( private var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ + + private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -127,23 +127,26 @@ private[master] class Master( // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) - private val restServer = - if (restServerEnabled) { - val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) - } else { - None - } - private val restServerBoundPort = restServer.map(_.start()) + private var restServer: Option[StandaloneRestServer] = None + private var restServerBoundPort: Option[Int] = None - override def preStart() { + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new MasterWebUI(this, webUiPort) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CheckForWorkerTimeOut) + } + }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) + + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) + } + restServerBoundPort = restServer.map(_.start()) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -153,20 +156,21 @@ private[master] class Master( 
masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) + val serializer = new JavaSerializer(conf) val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, serializer) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, serializer) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => - val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(context.system)) + val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) + val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer]) + .newInstance(conf, serializer) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -176,18 +180,17 @@ private[master] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! - logError("Master actor restarted due to exception", reason) - } - - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) + } + if (checkForWorkerTimeOutTask != null) { + checkForWorkerTimeOutTask.cancel(true) } + forwardMessageThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -197,16 +200,16 @@ private[master] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE } else { @@ -215,8 +218,11 @@ private[master] class Master( logInfo("I have been elected leader! 
New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = forwardMessageThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CompleteRecovery) + } + }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } } @@ -227,111 +233,42 @@ private[master] class Master( System.exit(0) } - case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) => - { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - sender, workerUiPort, publicAddress) + workerRef, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) - } - } - } - - case RequestSubmitDriver(description) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only accept driver submissions in ALIVE state." - sender ! SubmitDriverResponse(false, None, msg) - } else { - logInfo("Driver submitted " + description.command.mainClass) - val driver = createDriver(description) - persistenceEngine.addDriver(driver) - waitingDrivers += driver - drivers.add(driver) - schedule() - - // TODO: It might be good to instead have the submission client poll the master to determine - // the current status of the driver. For now it's simply "fire and forget". - - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") - } - } - - case RequestKillDriver(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - s"Can only kill drivers in ALIVE state." - sender ! KillDriverResponse(driverId, success = false, msg) - } else { - logInfo("Asked to kill driver " + driverId) - val driver = drivers.find(_.id == driverId) - driver match { - case Some(d) => - if (waitingDrivers.contains(d)) { - waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) - } else { - // We just notify the worker to kill the driver here. The final bookkeeping occurs - // on the return path when the worker submits a state change back to the master - // to notify it that the driver was successfully killed. - d.worker.foreach { w => - w.actor ! KillDriver(driverId) - } - } - // TODO: It would be nice for this to be a synchronous response - val msg = s"Kill request for $driverId submitted" - logInfo(msg) - sender ! 
KillDriverResponse(driverId, success = true, msg) - case None => - val msg = s"Driver $driverId has already finished or does not exist" - logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) - } - } - } - - case RequestDriverStatus(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only request driver status in ALIVE state." - sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) - } else { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } - case RegisterApplication(description) => { + case RegisterApplication(description, driver) => { + // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response } else { logInfo("Registering app " + description.name) - val app = createApplication(description, sender) + val app = createApplication(description, driver) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + driver.send(RegisteredApplication(app.id, self)) schedule() } } @@ -343,7 +280,7 @@ private[master] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -384,7 +321,7 @@ private[master] class Master( } } - case Heartbeat(workerId) => { + case Heartbeat(workerId, worker) => { idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -392,7 +329,7 @@ private[master] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + worker.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -444,28 +381,108 @@ private[master] class Master( logInfo(s"Received unregister request from application $applicationId") idToApp.get(applicationId).foreach(finishApplication) - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() } + } - case RequestMasterState => { - sender ! 
MasterStateResponse( - host, port, restServerBoundPort, - workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." + context.reply(SubmitDriverResponse(self, false, None, msg)) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + context.reply(SubmitDriverResponse(self, true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) + } } - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. + d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + context.reply(KillDriverResponse(self, driverId, success = true, msg)) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } + } + } + + case RequestDriverStatus(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + context.reply( + DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + context.reply(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) + case None => + context.reply(DriverStatusResponse(found = false, None, None, None, None)) + } + } + } + + case RequestMasterState => { + context.reply(MasterStateResponse( + address.host, address.port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } case BoundPortsRequest => { - sender ! 
BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) + context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) } + + case RequestExecutors(appId, requestedTotal) => + context.reply(handleRequestExecutors(appId, requestedTotal)) + + case KillExecutors(appId, executorIds) => + val formattedExecutorIds = formatExecutorIds(executorIds) + context.reply(handleKillExecutors(appId, formattedExecutorIds)) + } + + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } } private def canCompleteRecovery = @@ -479,7 +496,7 @@ private[master] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -496,7 +513,7 @@ private[master] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.endpoint.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -505,10 +522,8 @@ private[master] class Master( private def completeRecovery() { // Ensure "only-once" recovery semantics using a short synchronization period. - synchronized { - if (state != RecoveryState.RECOVERING) { return } - state = RecoveryState.COMPLETING_RECOVERY - } + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) @@ -533,6 +548,7 @@ private[master] class Master( /** * Schedule executors to be launched on the workers. + * Returns an array containing number of cores assigned to each worker. * * There are two modes of launching executors. The first attempts to spread out an application's * executors on as many workers as possible, while the second does the opposite (i.e. launch them @@ -543,39 +559,97 @@ private[master] class Master( * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the * worker by default, in which case only one executor may be launched on each worker. + * + * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core + * at a time). Consider the following example: cluster has 4 workers with 16 cores each. + * User requests 3 executors (spark.cores.max = 48, spark.executor.cores = 16). If 1 core is + * allocated at a time, 12 cores from each worker would be assigned to each executor. + * Since 12 < 16, no executors would launch [SPARK-8881]. */ - private def startExecutorsOnWorkers(): Unit = { - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app - // in the queue, then the second app, etc. 
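The scaladoc above carries the key scheduling argument (hand out `coresPerExecutor` cores per step, not one core), and the SPARK-8881 arithmetic is easy to verify by running it. The following is a self-contained illustration only, not part of the patch; the object name and the `assign` helper are invented for the sketch.

```scala
// Reproduces the SPARK-8881 example from the scaladoc: 4 workers x 16 cores,
// spark.cores.max = 48, spark.executor.cores = 16.
object CoreGranularityDemo {
  val numWorkers = 4
  val coresPerWorker = 16
  val coresMax = 48
  val coresPerExecutor = 16

  // Round-robin `step` cores at a time across the workers until the quota runs out.
  def assign(step: Int): Array[Int] = {
    val assigned = Array.fill(numWorkers)(0)
    var toAssign = coresMax
    var pos = 0
    while (toAssign >= step && assigned.exists(a => coresPerWorker - a >= step)) {
      if (coresPerWorker - assigned(pos) >= step) {
        assigned(pos) += step
        toAssign -= step
      }
      pos = (pos + 1) % numWorkers
    }
    assigned
  }

  def main(args: Array[String]): Unit = {
    val oneCore = assign(1)                 // Array(12, 12, 12, 12): no worker reaches 16 cores.
    val perExec = assign(coresPerExecutor)  // Array(16, 16, 16, 0):  three full executors fit.
    println(s"1 core per step:  ${oneCore.mkString(", ")} -> " +
      s"${oneCore.count(_ >= coresPerExecutor)} executors")
    println(s"$coresPerExecutor cores per step: ${perExec.mkString(", ")} -> " +
      s"${perExec.count(_ >= coresPerExecutor)} executors")
  }
}
```

This is the behaviour the new `scheduleExecutorsOnWorkers` below enforces by always stepping in units of `minCoresPerExecutor`.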
- if (spreadOutApps) { - // Try to spread out each app among all the workers, until it has all its cores - for (app <- waitingApps if app.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val numUsable = usableWorkers.length - val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) - var pos = 0 - while (toAssign > 0) { - if (usableWorkers(pos).coresFree - assigned(pos) > 0) { - toAssign -= 1 - assigned(pos) += 1 + private def scheduleExecutorsOnWorkers( + app: ApplicationInfo, + usableWorkers: Array[WorkerInfo], + spreadOutApps: Boolean): Array[Int] = { + val coresPerExecutor = app.desc.coresPerExecutor + val minCoresPerExecutor = coresPerExecutor.getOrElse(1) + val oneExecutorPerWorker = coresPerExecutor.isEmpty + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val numUsable = usableWorkers.length + val assignedCores = new Array[Int](numUsable) // Number of cores to give to each worker + val assignedExecutors = new Array[Int](numUsable) // Number of new executors on each worker + var coresToAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) + + /** Return whether the specified worker can launch an executor for this app. */ + def canLaunchExecutor(pos: Int): Boolean = { + val keepScheduling = coresToAssign >= minCoresPerExecutor + val enoughCores = usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor + + // If we allow multiple executors per worker, then we can always launch new executors. + // Otherwise, if there is already an executor on this worker, just give it more cores. + val launchingNewExecutor = !oneExecutorPerWorker || assignedExecutors(pos) == 0 + if (launchingNewExecutor) { + val assignedMemory = assignedExecutors(pos) * memoryPerExecutor + val enoughMemory = usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor + val underLimit = assignedExecutors.sum + app.executors.size < app.executorLimit + keepScheduling && enoughCores && enoughMemory && underLimit + } else { + // We're adding cores to an existing executor, so no need + // to check memory and executor limits + keepScheduling && enoughCores + } + } + + // Keep launching executors until no more workers can accommodate any + // more executors, or if we have reached this application's limits + var freeWorkers = (0 until numUsable).filter(canLaunchExecutor) + while (freeWorkers.nonEmpty) { + freeWorkers.foreach { pos => + var keepScheduling = true + while (keepScheduling && canLaunchExecutor(pos)) { + coresToAssign -= minCoresPerExecutor + assignedCores(pos) += minCoresPerExecutor + + // If we are launching one executor per worker, then every iteration assigns 1 core + // to the executor. Otherwise, every iteration assigns cores to a new executor. + if (oneExecutorPerWorker) { + assignedExecutors(pos) = 1 + } else { + assignedExecutors(pos) += 1 + } + + // Spreading out an application means spreading out its executors across as + // many workers as possible. If we are not spreading out, then we should keep + // scheduling executors on this worker until we use all of its resources. + // Otherwise, just move on to the next worker. 
+ if (spreadOutApps) { + keepScheduling = false } - pos = (pos + 1) % numUsable - } - // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable if assigned(pos) > 0) { - allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } - } else { - // Pack each app into as few workers as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { - for (app <- waitingApps if app.coresLeft > 0) { - allocateWorkerResourceToExecutors(app, app.coresLeft, worker) - } + freeWorkers = freeWorkers.filter(canLaunchExecutor) + } + assignedCores + } + + /** + * Schedule and launch executors on workers + */ + private def startExecutorsOnWorkers(): Unit = { + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app + // in the queue, then the second app, etc. + for (app <- waitingApps if app.coresLeft > 0) { + val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) } } } @@ -583,19 +657,22 @@ private[master] class Master( /** * Allocate a worker's resources to one or more executors. * @param app the info of the application which the executors belong to - * @param coresToAllocate cores on this worker to be allocated to this application + * @param assignedCores number of cores on this worker for this application + * @param coresPerExecutor number of cores per executor * @param worker the worker info */ private def allocateWorkerResourceToExecutors( app: ApplicationInfo, - coresToAllocate: Int, + assignedCores: Int, + coresPerExecutor: Option[Int], worker: WorkerInfo): Unit = { - val memoryPerExecutor = app.desc.memoryPerExecutorMB - val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) - var coresLeft = coresToAllocate - while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { - val exec = app.addExecutor(worker, coresPerExecutor) - coresLeft -= coresPerExecutor + // If the number of cores per executor is specified, we divide the cores assigned + // to this worker evenly among the executors with no remainder. + // Otherwise, we launch a single executor that grabs all the assignedCores on this worker. + val numExecutors = coresPerExecutor.map { assignedCores / _ }.getOrElse(1) + val coresToAssign = coresPerExecutor.getOrElse(assignedCores) + for (i <- 1 to numExecutors) { + val exec = app.addExecutor(worker, coresToAssign) launchExecutor(worker, exec) app.state = ApplicationState.RUNNING } @@ -623,10 +700,10 @@ private[master] class Master( private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! 
LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.endpoint.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { @@ -638,7 +715,7 @@ private[master] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -661,11 +738,11 @@ private[master] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.endpoint.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -687,14 +764,15 @@ private[master] class Master( schedule() } - private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): + ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -703,7 +781,7 @@ private[master] class Master( applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(app.driver) = app + endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app } @@ -717,8 +795,8 @@ private[master] class Master( logInfo("Removing app " + app.id) apps -= app idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address + endpointToApp -= app.driver + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -734,24 +812,103 @@ private[master] class Master( rebuildSparkUI(app) for (exec <- app.executors.values) { - exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) - exec.state = ExecutorState.KILLED + killExecutor(exec) } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! 
ApplicationFinished(app.id) + w.endpoint.send(ApplicationFinished(app.id)) + } + } + } + + /** + * Handle a request to set the target number of executors for this application. + * + * If the executor limit is adjusted upwards, new executors will be launched provided + * that there are workers with sufficient resources. If it is adjusted downwards, however, + * we do not kill existing executors until we explicitly receive a kill request. + * + * @return whether the application has previously registered with this Master. + */ + private def handleRequestExecutors(appId: String, requestedTotal: Int): Boolean = { + idToApp.get(appId) match { + case Some(appInfo) => + logInfo(s"Application $appId requested to set total executors to $requestedTotal.") + appInfo.executorLimit = requestedTotal + schedule() + true + case None => + logWarning(s"Unknown application $appId requested $requestedTotal total executors.") + false + } + } + + /** + * Handle a kill request from the given application. + * + * This method assumes the executor limit has already been adjusted downwards through + * a separate [[RequestExecutors]] message, such that we do not launch new executors + * immediately after the old ones are removed. + * + * @return whether the application has previously registered with this Master. + */ + private def handleKillExecutors(appId: String, executorIds: Seq[Int]): Boolean = { + idToApp.get(appId) match { + case Some(appInfo) => + logInfo(s"Application $appId requests to kill executors: " + executorIds.mkString(", ")) + val (known, unknown) = executorIds.partition(appInfo.executors.contains) + known.foreach { executorId => + val desc = appInfo.executors(executorId) + appInfo.removeExecutor(desc) + killExecutor(desc) + } + if (unknown.nonEmpty) { + logWarning(s"Application $appId attempted to kill non-existent executors: " + + unknown.mkString(", ")) + } + schedule() + true + case None => + logWarning(s"Unregistered application $appId requested us to kill executors!") + false + } + } + + /** + * Cast the given executor IDs to integers and filter out the ones that fail. + * + * All executors IDs should be integers since we launched these executors. However, + * the kill interface on the driver side accepts arbitrary strings, so we need to + * handle non-integer executor IDs just to be safe. + */ + private def formatExecutorIds(executorIds: Seq[String]): Seq[Int] = { + executorIds.flatMap { executorId => + try { + Some(executorId.toInt) + } catch { + case e: NumberFormatException => + logError(s"Encountered executor with a non-integer ID: $executorId. Ignoring") + None } } } + /** + * Ask the worker on which the specified executor is launched to kill the executor. + */ + private def killExecutor(exec: ExecutorDesc): Unit = { + exec.worker.removeExecutor(exec) + exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id)) + exec.state = ExecutorState.KILLED + } + /** * Rebuild a new SparkUI from the given application's event logs. 
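The `handleRequestExecutors` / `handleKillExecutors` pair above is the Master's half of standalone dynamic allocation, and the ordering contract spelled out in the `handleKillExecutors` scaladoc is easiest to see from the caller's side. The sketch below is hedged and illustrative only: the `downscale` helper is invented, the real caller is the driver-side app client rather than anything in this file, and it assumes `RequestExecutors`/`KillExecutors` are the `DeployMessages` case classes matched in `receiveAndReply` above.

```scala
import org.apache.spark.deploy.DeployMessages.{KillExecutors, RequestExecutors}
import org.apache.spark.rpc.RpcEndpointRef

// Lower the executor limit first, then name the executors to kill. If the kill were sent
// without lowering the limit, the Master's schedule() would immediately launch replacements.
def downscale(
    master: RpcEndpointRef,
    appId: String,
    newTotal: Int,
    executorIdsToKill: Seq[String]): Boolean = {
  val limitLowered = master.askWithRetry[Boolean](RequestExecutors(appId, newTotal))
  // Both handlers reply with whether the application was known to this Master.
  limitLowered && master.askWithRetry[Boolean](KillExecutors(appId, executorIdsToKill))
}
```

Note that the Master only caps future scheduling through `executorLimit`; it never kills executors on its own, which keeps the two paths (limit changes and explicit kills) independent.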
* Return the UI if successful, else None @@ -768,7 +925,7 @@ private[master] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, None, app.desc.eventLogCodec) + eventLogDir, app.id, app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) @@ -832,14 +989,14 @@ private[master] class Master( private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray for (worker <- toRemove) { if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) + worker.id, WORKER_TIMEOUT_MS / 1000)) removeWorker(worker) } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } @@ -862,7 +1019,7 @@ private[master] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.endpoint.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -891,57 +1048,33 @@ private[master] class Master( } private[deploy] object Master extends Logging { - val systemName = "sparkMaster" - private val actorName = "Master" + val SYSTEM_NAME = "sparkMaster" + val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() - } - - /** - * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaUrl(sparkUrl: String, protocol: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - AkkaUtils.address(protocol, systemName, host, port, actorName) - } - - /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. 
- * - * @throws SparkException if the url is invalid - */ - def toAkkaAddress(sparkUrl: String, protocol: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address(protocol, systemName, host, port) + val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** - * Start the Master and return a four tuple of: - * (1) The Master actor system - * (2) The bound port - * (3) The web UI bound port - * (4) The REST server bound port, if any + * Start the Master and return a three tuple of: + * (1) The Master RpcEnv + * (2) The web UI bound port + * (3) The REST server bound port, if any */ - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { + conf: SparkConf): (RpcEnv, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf( - Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = RpcUtils.askTimeout(conf) - val portsRequest = actor.ask(BoundPortsRequest)(timeout) - val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] - (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) + val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) + val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, + new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) + val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 435b9b12f83b..44cefbc77f08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
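`startRpcEnvAndEndpoint` replaces `startSystemAndActor` as the single entry point for bringing up a Master, so a short usage sketch may help. This is illustrative only: `Master` is `private[deploy]`, so such code can live only inside the `org.apache.spark.deploy` package (as `main` above does), and the host and ports are placeholders.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.Master

val conf = new SparkConf()
// Returns (RpcEnv, web UI bound port, optional REST server bound port); the RPC port itself
// is now available as rpcEnv.address.port rather than as a fourth tuple element.
val (rpcEnv, webUiBoundPort, restBoundPort) =
  Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf)
// ... use the ports, then block until the Master's RpcEnv terminates.
rpcEnv.awaitTermination()
```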
*/ private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Master [options]\n" + "\n" + @@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8080)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 15c6296888f7..a952cee36eb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,7 +28,7 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to Master + // Master to itself case object CheckForWorkerTimeOut @@ -38,5 +38,5 @@ private[master] object MasterMessages { case object BoundPortsRequest - case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int]) + case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index a03d460509e0..58a00bceee6a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEnv import scala.reflect.ClassTag @@ -80,8 +81,11 @@ abstract class PersistenceEngine { * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + final def readPersistedData( + rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + rpcEnv.deserialize { () => + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } } def close() {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 351db8fab204..c4c3283fb73f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer /** * ::DeveloperApi:: @@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. 
Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") @@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: } } -private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) { def createPersistenceEngine(): PersistenceEngine = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 9b3d48c6edc8..f75196660520 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val endpoint: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { @@ -107,4 +105,6 @@ private[spark] class WorkerInfo( def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 52758d6a7c4b..d317206a614f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,15 +17,12 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} import org.apache.spark.deploy.SparkCuratorUtil -private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, +private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" @@ -76,10 +73,10 @@ private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElecta private def updateLeadershipStatus(isLeader: Boolean) { if (isLeader && status == LeadershipStatus.NOT_LEADER) { status = LeadershipStatus.LEADER - masterActor.electedLeader() + masterInstance.electedLeader() } else if (!isLeader && status == LeadershipStatus.LEADER) { status = LeadershipStatus.NOT_LEADER - masterActor.revokedLeadership() + masterInstance.revokedLeadership() } } diff --git 
a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 328d95a7a0c6..540e802420ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,9 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization +import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.curator.framework.CuratorFramework @@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.serializer.Serializer -private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer) extends PersistenceEngine with Logging { @@ -48,8 +49,8 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat } override def read[T: ClassTag](prefix: String): Seq[T] = { - val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) - file.map(deserializeFromFile[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala + .filter(_.startsWith(prefix)).map(deserializeFromFile[T]).flatten } override def close() { @@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat } private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.newInstance().serialize(value) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) } private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 06e265f99e23..e28e7e379ac9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask - import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorDesc @@ -32,14 +29,12 @@ import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + 
private val master = parent.masterEndpointRef /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 6a7c74020bac..c3e20ebf8d6e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - Await.result(stateFuture, timeout) + master.askWithRetry[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { @@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } def handleDriverKillRequest(request: HttpServletRequest): Unit = { - handleKillRequest(request, id => { master ! RequestKillDriver(id) }) + handleKillRequest(request, id => { + master.ask[KillDriverResponse](RequestKillDriver(id)) + }) } private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 2111a8581f2e..6174fc11f83d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -23,7 +23,6 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. 
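Both UI pages above now share the same query pattern: a blocking `askWithRetry` on the Master endpoint replaces the old `(master ? RequestMasterState)(timeout)` plus `Await.result` pair, so the pages no longer carry their own `RpcUtils.askTimeout`. A minimal sketch of that shared pattern follows; the helper name is invented.

```scala
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.rpc.RpcEndpointRef

// One synchronous round trip; retries and the timeout come from the RPC layer's own settings.
def masterState(master: RpcEndpointRef): MasterStateResponse =
  master.askWithRetry[MasterStateResponse](RequestMasterState)

// e.g. masterState(parent.masterEndpointRef).activeApps.map(_.id)
```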
@@ -33,8 +32,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { - val masterActorRef = master.self - val timeout = RpcUtils.askTimeout(master.conf) + val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) val masterPage = new MasterPage(this) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 894cb78d8591..5accaf78d0a5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -54,7 +54,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { + // scalastyle:off println System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + // scalastyle:on println System.exit(1) } masterUrl = value.stripPrefix("mesos://") @@ -73,7 +75,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case Nil => { if (masterUrl == null) { + // scalastyle:off println System.err.println("--master is required") + // scalastyle:on println printUsageAndExit(1) } } @@ -83,6 +87,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: } private def printUsageAndExit(exitCode: Int): Unit = { + // scalastyle:off println System.err.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + @@ -96,6 +101,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala new file mode 100644 index 000000000000..12337a940a41 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import java.net.SocketAddress + +import scala.collection.mutable + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver +import org.apache.spark.network.util.TransportConf + +/** + * An RPC endpoint that receives registration requests from Spark drivers running on Mesos. + * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. + */ +private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) + extends ExternalShuffleBlockHandler(transportConf, null) with Logging { + + // Stores a map of driver socket addresses to app ids + private val connectedApps = new mutable.HashMap[SocketAddress, String] + + protected override def handleMessage( + message: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): Unit = { + message match { + case RegisterDriverParam(appId) => + val address = client.getSocketAddress + logDebug(s"Received registration request from app $appId (remote address $address).") + if (connectedApps.contains(address)) { + val existingAppId = connectedApps(address) + if (!existingAppId.equals(appId)) { + logError(s"A new app '$appId' has connected to existing address $address, " + + s"removing previously registered app '$existingAppId'.") + applicationRemoved(existingAppId, true) + } + } + connectedApps(address) = appId + callback.onSuccess(new Array[Byte](0)) + case _ => super.handleMessage(message, client, callback) + } + } + + /** + * On connection termination, clean up shuffle files written by the associated application. + */ + override def connectionTerminated(client: TransportClient): Unit = { + val address = client.getSocketAddress + if (connectedApps.contains(address)) { + val appId = connectedApps(address) + logInfo(s"Application $appId disconnected (address was $address).") + applicationRemoved(appId, true /* cleanupLocalDirs */) + connectedApps.remove(address) + } else { + logWarning(s"Unknown $address disconnected.") + } + } + + /** An extractor object for matching [[RegisterDriver]] message. */ + private object RegisterDriverParam { + def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId) + } +} + +/** + * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers + * to associate with. This allows the shuffle service to detect when a driver is terminated + * and can clean up the associated shuffle files. 
+ */ +private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager) + extends ExternalShuffleService(conf, securityManager) { + + protected override def newShuffleBlockHandler( + conf: TransportConf): ExternalShuffleBlockHandler = { + new MesosExternalShuffleBlockHandler(conf) + } +} + +private[spark] object MesosExternalShuffleService extends Logging { + + def main(args: Array[String]): Unit = { + ExternalShuffleService.main(args, + (conf: SparkConf, sm: SecurityManager) => new MesosExternalShuffleService(conf, sm)) + } +} + + diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 1fe956320a1b..957a928bc402 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -392,15 +392,14 @@ private[spark] object RestSubmissionClient { mainClass: String, appArgs: Array[String], conf: SparkConf, - env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + env: Map[String, String] = Map()): SubmitRestProtocolResponse = { val master = conf.getOption("spark.master").getOrElse { throw new IllegalArgumentException("'spark.master' must be set.") } val sparkProperties = conf.getAll.toMap - val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } val client = new RestSubmissionClient(master) val submitRequest = client.constructSubmitRequest( - appResource, mainClass, appArgs, sparkProperties, environmentVariables) + appResource, mainClass, appArgs, sparkProperties, env) client.createSubmission(submitRequest) } @@ -413,6 +412,16 @@ private[spark] object RestSubmissionClient { val mainClass = args(1) val appArgs = args.slice(2, args.size) val conf = new SparkConf - run(appResource, mainClass, appArgs, conf) + val env = filterSystemEnvironment(sys.env) + run(appResource, mainClass, appArgs, conf, env) + } + + /** + * Filter non-spark environment variables from any environment. 
+ */ + private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { + env.filter { case (k, _) => + (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_") + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 502b9bb701cc..d5b9bcab1423 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import akka.actor.ActorRef import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** @@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to * @param masterConf the conf used by the Master - * @param masterActor reference to the Master actor to which requests can be sent + * @param masterEndpoint reference to the Master endpoint to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { protected override val submitRequestServlet = - new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = - new StandaloneKillRequestServlet(masterActor, masterConf) + new StandaloneKillRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = - new StandaloneStatusRequestServlet(masterActor, masterConf) + new StandaloneStatusRequestServlet(masterEndpoint, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( - DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion k.message = response.message @@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. 
*/ -private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( - DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion @@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ private[rest] class StandaloneSubmitRequestServlet( - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String, conf: SparkConf) extends SubmitRequestServlet { @@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion submitResponse.message = response.message diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index e6615a3174ce..ef5a7e35ad56 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage { */ def fromJson(json: String): SubmitRestProtocolMessage = { val className = parseAction(json) - val clazz = Class.forName(packagePrefix + "." + className) + val clazz = Utils.classForName(packagePrefix + "." 
+ className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 8198296eeb34..868cc35d06ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -59,7 +59,7 @@ private[mesos] class MesosSubmitRequestServlet( extends SubmitRequestServlet { private val DEFAULT_SUPERVISE = false - private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // mb private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 45a3f4304543..ce02ee203a4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -18,9 +18,8 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} -import java.lang.System._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import org.apache.spark.Logging @@ -62,7 +61,7 @@ object CommandUtils extends Logging { // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows val cmd = new WorkerCommandBuilder(sparkHome, memory, command).buildCommand() - cmd.toSeq ++ Seq(command.mainClass) ++ command.arguments + cmd.asScala ++ Seq(command.mainClass) ++ command.arguments } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 1386055eb8c4..89159ff5e2b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -19,9 +19,8 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{Utils, Clock, SystemClock} /** @@ -43,7 +43,7 @@ private[deploy] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String, val securityManager: SecurityManager) extends Logging { @@ -107,7 +107,7 @@ private[deploy] class DriverRunner( finalState = Some(state) - worker ! 
DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } @@ -172,8 +172,8 @@ private[deploy] class DriverRunner( CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(baseDir, "stderr") - val header = "Launch Command: %s\n%s\n\n".format( - builder.command.mkString("\"", "\" \"", "\""), "=" * 40) + val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"") + val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40) Files.append(header, stderr, UTF_8) CommandUtils.redirectStream(process.getErrorStream, stderr) } @@ -229,6 +229,6 @@ private[deploy] trait ProcessBuilderLike { private[deploy] object ProcessBuilderLike { def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { override def start(): Process = processBuilder.start() - override def command: Seq[String] = processBuilder.command() + override def command: Seq[String] = processBuilder.command().asScala } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index d1a12b01e78f..6799f78ec0c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -53,14 +53,16 @@ object DriverWrapper { Thread.currentThread.setContextClassLoader(loader) // Delegate to supplied main class - val clazz = Class.forName(mainClass, true, loader) + val clazz = Utils.classForName(mainClass) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) rpcEnv.shutdown() case _ => + // scalastyle:off println System.err.println("Usage: DriverWrapper [options]") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index fff17e109504..3aef0515cbf6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -19,16 +19,16 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender /** @@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val webUiPort: Int, @@ -70,7 +70,8 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. 
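// --- Editor's illustration (not part of the patch) -------------------------
// The `formattedCommand` value introduced in DriverRunner above (and again in
// ExecutorRunner below) only wraps every token of the launch command in double
// quotes before logging it. A minimal standalone sketch of that formatting;
// the ProcessBuilder arguments here are made up for the example:
import scala.collection.JavaConverters._

val command = new ProcessBuilder("java", "-cp", "my app.jar", "Main").command()
val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"")
println(s"Launch Command: $formattedCommand")
// prints: Launch Command: "java" "-cp" "my app.jar" "Main"
// ---------------------------------------------------------------------------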
- shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } + shutdownHook = ShutdownHookManager.addShutdownHook { () => + killProcess(Some("Worker shutting down")) } } /** @@ -91,7 +92,7 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -102,7 +103,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } @@ -128,7 +129,8 @@ private[deploy] class ExecutorRunner( val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() - logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) + val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") + logInfo(s"Launch command: $formattedCommand") builder.directory(executorDir) builder.environment.put("SPARK_EXECUTOR_DIRS", appLocalDirs.mkString(File.pathSeparator)) @@ -144,7 +146,7 @@ private[deploy] class ExecutorRunner( process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) + formattedCommand, "=" * 40) // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") @@ -159,7 +161,7 @@ private[deploy] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! 
ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ebc6cd76c6af..770927c80f7a 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,15 +21,13 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, HashSet} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} +import scala.concurrent.ExecutionContext import scala.util.Random - -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} @@ -38,32 +36,39 @@ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} - -/** - * @param masterAkkaUrls Each url should be a valid akka url. - */ -private[worker] class Worker( - host: String, - port: Int, +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} + +private[deploy] class Worker( + override val rpcEnv: RpcEnv, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], - actorSystemName: String, - actorName: String, + masterRpcAddresses: Array[RpcAddress], + systemName: String, + endpointName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends ThreadSafeRpcEndpoint with Logging { + + private val host = rpcEnv.address.host + private val port = rpcEnv.address.port Utils.checkHost(host, "Expected hostname") assert (port > 0) + // A scheduled executor used to send messages at the specified time. + private val forwordMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + + // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` + // methods. 
+ private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) + // For worker and executor IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -79,32 +84,26 @@ private[worker] class Worker( val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * - REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 - * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)) + private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)) private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") - private var master: ActorSelection = null - private var masterAddress: Address = null + private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val akkaUrl = AkkaUtils.address( - AkkaUtils.protocol(context.system), - actorSystemName, - host, - port, - actorName) - @volatile private var registered = false - @volatile private var connected = false + private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private var registered = false + private var connected = false private val workerId = generateWorkerId() private val sparkHome = if (testing) { @@ -115,13 +114,18 @@ private[worker] class Worker( } var workDir: File = null - val finishedExecutors = new HashMap[String, ExecutorRunner] + val finishedExecutors = new LinkedHashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] val executors = new HashMap[String, ExecutorRunner] - val finishedDrivers = new HashMap[String, DriverRunner] + val finishedDrivers = new LinkedHashMap[String, DriverRunner] val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] + val retainedExecutors = conf.getInt("spark.worker.ui.retainedExecutors", + WorkerWebUI.DEFAULT_RETAINED_EXECUTORS) + val retainedDrivers = conf.getInt("spark.worker.ui.retainedDrivers", + WorkerWebUI.DEFAULT_RETAINED_DRIVERS) + // The shuffle service is not actually started unless configured. 
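// --- Editor's illustration (not part of the patch) -------------------------
// Why the registration pool above is sized to `masterRpcAddresses.size`: a
// SynchronousQueue buffers nothing, so every concurrently submitted (blocking)
// registration attempt needs its own thread. A self-contained sketch of that
// sizing; `addresses` is a made-up stand-in for masterRpcAddresses, and
// println stands in for the blocking endpoint setup:
import java.util.concurrent.{SynchronousQueue, ThreadPoolExecutor, TimeUnit}

val addresses = Seq("host-a:7077", "host-b:7077")
val pool = new ThreadPoolExecutor(
  0,                                 // keep no idle core threads around
  addresses.size,                    // allow one thread per master at once
  60L, TimeUnit.SECONDS,
  new SynchronousQueue[Runnable]())  // hand-off only, no task buffering
addresses.foreach { addr =>
  pool.submit(new Runnable {
    override def run(): Unit = println(s"connecting to master $addr ...")
  })
}
pool.shutdown()
// ---------------------------------------------------------------------------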
private val shuffleService = new ExternalShuffleService(conf, securityMgr) @@ -136,7 +140,18 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - private var registrationRetryTimer: Option[Cancellable] = None + private var registerMasterFutures: Array[JFuture[_]] = null + private var registrationRetryTimer: Option[JScheduledFuture[_]] = None + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) var coresUsed = 0 var memoryUsed = 0 @@ -162,14 +177,13 @@ private[worker] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -181,38 +195,45 @@ private[worker] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(url: String, uiUrl: String) { + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. - activeMasterUrl = url + activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) + master = Some(masterRef) connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } - private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + masterRpcAddresses.map { masterAddress => + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + }) } } /** * Re-register with the master because a network failure or a master failure has occurred. * If the re-registration attempt threshold is exceeded, the worker exits with error. 
- * Note that for thread-safety this should only be called from the actor. + * Note that for thread-safety this should only be called from the rpcEndpoint. */ private def reregisterWithMaster(): Unit = { Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") /** @@ -235,21 +256,48 @@ private[worker] class Worker( * still not safe if the old master recovers within this interval, but this is a much * less likely scenario. */ - if (master != null) { - master ! RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) - } else { - // We are retrying the initial registration - tryRegisterAllMasters() + master match { + case Some(masterRef) => + // registered == false && master != None means we lost the connection to master, so + // masterRef cannot be used and we need to recreate it again. Note: we must not set + // master to None due to the above comments. + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + val masterAddress = masterRef.address + registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + })) + case None => + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + // We are retrying the initial registration + registerMasterFutures = tryRegisterAllMasters() } // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = Some( + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) } } else { logError("All masters are unresponsive! Giving up.") @@ -258,41 +306,68 @@ private[worker] class Worker( } } + /** + * Cancel last registeration retry, or do nothing if no retry + */ + private def cancelLastRegistrationRetry(): Unit = { + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures = null + } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = None + } + private def registerWithMaster() { - // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // onDisconnected may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. 
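// --- Editor's illustration (not part of the patch) -------------------------
// The Akka scheduler calls are replaced above by a single-thread
// ScheduledExecutorService that re-posts messages to the endpoint itself
// (Spark's ThreadUtils helper additionally makes the thread a daemon). A
// standalone sketch of that retry-timer pattern; println stands in for
// `self.send(ReregisterWithMaster)` and the 10-second period is only an
// example value:
import java.util.concurrent.{Executors, TimeUnit}

val forwardScheduler = Executors.newSingleThreadScheduledExecutor()
val retryTimer = forwardScheduler.scheduleAtFixedRate(new Runnable {
  override def run(): Unit = println("self.send(ReregisterWithMaster)")
}, 10, 10, TimeUnit.SECONDS)

// Once registration succeeds, the worker cancels the outstanding timer:
retryTimer.cancel(true)
forwardScheduler.shutdownNow()
// ---------------------------------------------------------------------------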
registrationRetryTimer match { case None => registered = false - tryRegisterAllMasters() + registerMasterFutures = tryRegisterAllMasters() connectionAttemptCount = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + " attempt scheduled already.") } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterUrl, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterUrl) + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) registered = true - changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + changeMaster(masterRef, masterWebUiUrl) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(SendHeartbeat) + } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! Heartbeat(workerId) } + if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => - // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker + // rpcEndpoint. + // Copy ids so that it can be used in the cleanup thread. 
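// --- Editor's illustration (not part of the patch) -------------------------
// The `appIds` copy above exists so that the cleanup Future never reads the
// mutable `executors` map owned by the endpoint thread. The same
// snapshot-before-async idiom in a self-contained form; the map contents are
// invented for the example:
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}

val executors = scala.collection.mutable.HashMap("app-1/0" -> "app-1")
val cleanupThreadExecutor =
  ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

val appIds = executors.values.toSet   // snapshot taken on the calling thread
Future {
  println(s"cleaning up directories for apps not in $appIds")  // uses the copy
}(cleanupThreadExecutor)
cleanupThreadExecutor.shutdown()      // lets the queued cleanup run, then exits
// ---------------------------------------------------------------------------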
+ val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { val appDirs = workDir.listFiles() if (appDirs == null) { @@ -302,27 +377,27 @@ private[worker] class Worker( // the directory is used by an application - check that the application is not running // when cleaning up val appIdFromDir = dir.getName - val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + val isAppStillRunning = appIds.contains(appIdFromDir) dir.isDirectory && !isAppStillRunning && - !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => logInfo(s"Removing directory: ${dir.getPath}") Utils.deleteRecursively(dir) } - } + }(cleanupThreadExecutor) - cleanupFuture onFailure { + cleanupFuture.onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) - } + }(cleanupThreadExecutor) - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl, masterWebUiUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + changeMaster(masterRef, masterWebUiUrl) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case RegisterWorkerFailed(message) => if (!registered) { @@ -352,7 +427,9 @@ private[worker] class Worker( // application finishes. val appLocalDirs = appDirectories.get(appId).getOrElse { Utils.getOrCreateLocalRootDirs(conf).map { dir => - Utils.createDirectory(dir, namePrefix = "executor").getAbsolutePath() + val appDir = Utils.createDirectory(dir, namePrefix = "executor") + Utils.chmod700(appDir) + appDir.getAbsolutePath() }.toSeq } appDirectories(appId) = appLocalDirs @@ -369,14 +446,14 @@ private[worker] class Worker( publicAddress, sparkHome, executorDir, - akkaUrl, + workerUri, conf, appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -384,32 +461,14 @@ private[worker] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! 
ExecutorStateChanged(appId, execId, state, message, exitStatus) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - executors.get(fullId) match { - case Some(executor) => - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - executors -= fullId - finishedExecutors(fullId) = executor - coresUsed -= executor.cores - memoryUsed -= executor.memory - case None => - logInfo("Unknown Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - } - maybeCleanupApplication(appId) - } + case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => + handleExecutorStateChanged(executorStateChanged) case KillExecutor(masterUrl, appId, execId) => if (masterUrl != activeMasterUrl) { @@ -434,7 +493,7 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl, + workerUri, securityMgr) drivers(driverId) = driver driver.start() @@ -453,36 +512,10 @@ private[worker] class Worker( } } - case DriverStateChanged(driverId, state, exception) => { - state match { - case DriverState.ERROR => - logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") - case DriverState.FAILED => - logWarning(s"Driver $driverId exited with failure") - case DriverState.FINISHED => - logInfo(s"Driver $driverId exited successfully") - case DriverState.KILLED => - logInfo(s"Driver $driverId was killed by user") - case _ => - logDebug(s"Driver $driverId changed state to $state") - } - master ! DriverStateChanged(driverId, state, exception) - val driver = drivers.remove(driverId).get - finishedDrivers(driverId) = driver - memoryUsed -= driver.driverDesc.mem - coresUsed -= driver.driverDesc.cores + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + handleDriverStateChanged(driverStateChanged) } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, drivers.values.toList, - finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) - case ReregisterWithMaster => reregisterWithMaster() @@ -491,6 +524,21 @@ private[worker] class Worker( maybeCleanupApplication(id) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestWorkerState => + context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (master.exists(_.address == remoteAddress)) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false @@ -507,6 +555,20 @@ private[worker] class Worker( Utils.deleteRecursively(new File(dir)) } } + shuffleService.applicationRemoved(id) + } + } + + /** + * Send a message to the current master. 
If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => + logWarning( + s"Dropping $message because the connection to master has not yet been established") } } @@ -514,28 +576,106 @@ private[worker] class Worker( "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { + cleanupThreadExecutor.shutdownNow() metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + cancelLastRegistrationRetry() + forwordMessageScheduler.shutdownNow() + registerMasterThreadPool.shutdownNow() executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() webUi.stop() metricsSystem.stop() } + + private def trimFinishedExecutorsIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedExecutors.size > retainedExecutors) { + finishedExecutors.take(math.max(finishedExecutors.size / 10, 1)).foreach { + case (executorId, _) => finishedExecutors.remove(executorId) + } + } + } + + private def trimFinishedDriversIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedDrivers.size > retainedDrivers) { + finishedDrivers.take(math.max(finishedDrivers.size / 10, 1)).foreach { + case (driverId, _) => finishedDrivers.remove(driverId) + } + } + } + + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { + val driverId = driverStateChanged.driverId + val exception = driverStateChanged.exception + val state = driverStateChanged.state + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FAILED => + logWarning(s"Driver $driverId exited with failure") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + case _ => + logDebug(s"Driver $driverId changed state to $state") + } + sendToMaster(driverStateChanged) + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + trimFinishedDriversIfNecessary() + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + + private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged): + Unit = { + sendToMaster(executorStateChanged) + val state = executorStateChanged.state + if (ExecutorState.isFinished(state)) { + val appId = executorStateChanged.appId + val fullId = appId + "/" + executorStateChanged.execId + val message = executorStateChanged.message + val exitStatus = executorStateChanged.exitStatus + executors.get(fullId) match { + case Some(executor) => + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + executors -= fullId + finishedExecutors(fullId) = executor + trimFinishedExecutorsIfNecessary() + coresUsed -= executor.cores + memoryUsed -= executor.memory + case None => + logInfo("Unknown Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + } + 
maybeCleanupApplication(appId) + } + } } private[deploy] object Worker extends Logging { + val SYSTEM_NAME = "sparkWorker" + val ENDPOINT_NAME = "Worker" + def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, @@ -544,18 +684,16 @@ private[deploy] object Worker extends Logging { masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None, - conf: SparkConf = new SparkConf): (ActorSystem, Int) = { + conf: SparkConf = new SparkConf): RpcEnv = { - // The LocalSparkCluster runs multiple local sparkWorkerX actor systems - val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") - val actorName = "Worker" + // The LocalSparkCluster runs multiple local sparkWorkerX RPC Environments + val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("") val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory, + masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr)) + rpcEnv } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { @@ -579,5 +717,4 @@ private[deploy] object Worker extends Logging { cmd } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 9678631da9f6..5181142c5f80 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
*/ def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Worker [options] \n" + "\n" + @@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8081)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } @@ -147,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val ibmVendor = System.getProperty("java.vendor").contains("IBM") var totalMb = 0 try { + // scalastyle:off classforname val bean = ManagementFactory.getOperatingSystemMXBean() if (ibmVendor) { val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") @@ -157,14 +160,17 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt } + // scalastyle:on classforname } catch { case e: Exception => { totalMb = 2*1024 + // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + // scalastyle:on println } } // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) + math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) } def checkWorkerMemory(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 83fb991891a4..735c4f092715 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.worker import org.apache.spark.Logging -import org.apache.spark.deploy.DeployMessages.SendHeartbeat import org.apache.spark.rpc._ /** @@ -44,7 +43,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin private[deploy] def setTesting(testing: Boolean) = isTesting = testing private var isTesting = false - // Lets us filter events only from the worker's actor system + // Lets filter events only from the worker's rpc system private val expectedAddress = RpcAddress.fromURIString(workerUrl) private def isWorker(address: RpcAddress) = expectedAddress == address @@ -63,7 +62,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (isWorker(remoteAddress)) { // This log message will never be seen - logError(s"Lost connection to worker actor $workerUrl. Exiting.") + logError(s"Lost connection to worker rpc endpoint $workerUrl. 
Exiting.") exitNonZero() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 9f9f27d71e1a..fd905feb97e9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - private val workerActor = parent.worker.self - private val timeout = parent.timeout + private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index b3bb5f911dbd..1a0598e50dcf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,9 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -38,7 +37,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - private[ui] val timeout = RpcUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) initialize() @@ -49,10 +48,14 @@ class WorkerWebUI( attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", - (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + (request: HttpServletRequest) => logPage.renderLog(request), + worker.securityMgr, + worker.conf)) } } -private[ui] object WorkerWebUI { +private[worker] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR + val DEFAULT_RETAINED_DRIVERS = 1000 + val DEFAULT_RETAINED_EXECUTORS = 1000 } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index f3a26f54a81f..fcd76ec52742 100644 --- 
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -66,7 +66,10 @@ private[spark] class CoarseGrainedExecutorBackend( case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } - case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + case Failure(e) => { + logError(s"Cannot register with driver: $driverUrl", e) + System.exit(1) + } }(ThreadUtils.sameThread) } @@ -232,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { argv = tail case Nil => case tail => + // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + // scalastyle:on println printUsageAndExit() } } @@ -246,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } private def printUsageAndExit() = { + // scalastyle:off println System.err.println( """ |"Usage: CoarseGrainedExecutorBackend [options] @@ -259,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { | --worker-url | --user-class-path |""".stripMargin) + // scalastyle:on println System.exit(1) } diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f47d7ef511da..7d84889a2def 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -26,8 +26,8 @@ private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, - attemptID: Int) + attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8f916e0502ec..c3491bb8b1cf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,13 +17,13 @@ package org.apache.spark.executor -import java.io.File +import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -147,7 +147,7 @@ private[spark] class Executor( /** Returns the total amount of time this JVM process has spent in garbage collection. */ private def computeTotalGcTime(): Long = { - ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum } class TaskRunner( @@ -209,15 +209,19 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
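// --- Editor's illustration (not part of the patch) -------------------------
// A recurring change in this patch is JavaConversions -> JavaConverters: the
// implicit Java/Scala collection conversions become an explicit .asScala call.
// The computeTotalGcTime helper above, reproduced as a standalone snippet:
import java.lang.management.ManagementFactory
import scala.collection.JavaConverters._

val totalGcTimeMillis =
  ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum
println(s"This JVM has spent $totalGcTimeMillis ms in garbage collection")
// ---------------------------------------------------------------------------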
taskStart = System.currentTimeMillis() - val value = try { - task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + var threwException = true + val (value, accumUpdates) = try { + val res = task.run( + taskAttemptId = taskId, + attemptNumber = attemptNumber, + metricsSystem = env.metricsSystem) + threwException = false + res } finally { - // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; - // when changing this, make sure to update both copies. val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" - if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { throw new SparkException(errMsg) } else { logError(errMsg) @@ -245,9 +249,9 @@ private[spark] class Executor( m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) + m.updateAccumulators() } - val accumUpdates = Accumulators.values val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit @@ -297,11 +301,20 @@ private[spark] class Executor( task.metrics.map { m => m.setExecutorRunTime(System.currentTimeMillis() - taskStart) m.setJvmGCTime(computeTotalGcTime() - startGCTime) + m.updateAccumulators() m } } - val taskEndReason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, metrics)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, metrics, false)) + } + } + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
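// --- Editor's illustration (not part of the patch) -------------------------
// The change above serializes the full ExceptionFailure and falls back to a
// reduced form when the wrapped Throwable is not serializable. The same
// try/fallback pattern in isolation, using plain Java serialization instead
// of Spark's serializer; the helper names are invented for the sketch:
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

def serialize(value: AnyRef): Array[Byte] = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  try out.writeObject(value) finally out.close()
  bytes.toByteArray
}

def serializeWithFallback(full: AnyRef, reduced: AnyRef): Array[Byte] =
  try serialize(full)
  catch { case _: NotSerializableException => serialize(reduced) }

// `new Object` is not Serializable, so this falls back to the String payload:
val payload = serializeWithFallback(new Object, "reduced failure description")
println(s"serialized ${payload.length} bytes")
// ---------------------------------------------------------------------------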
@@ -310,12 +323,6 @@ private[spark] class Executor( } } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() - // Release memory used by this thread for accumulators - Accumulators.clear() runningTasks.remove(taskId) } } @@ -356,7 +363,7 @@ private[spark] class Executor( logInfo("Using REPL class URI: " + classUri) try { val _userClassPathFirst: java.lang.Boolean = userClassPathFirst - val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) @@ -418,12 +425,13 @@ private[spark] class Executor( val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() val curGCTime = computeTotalGcTime() - for (taskRunner <- runningTasks.values()) { + for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => metrics.updateShuffleReadMetrics() metrics.updateInputMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + metrics.updateAccumulators() if (isLocal) { // JobProgressListener will hold an reference of it during @@ -443,7 +451,7 @@ private[spark] class Executor( try { val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") + logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 293c512f8b70..d16f4a1fc4e3 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.util.concurrent.ThreadPoolExecutor -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem @@ -30,7 +30,7 @@ private[spark] class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source { private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme)) + FileSystem.getAllStatistics.asScala.find(s => s.getScheme.equals(scheme)) private def registerFileSystemStat[T]( scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index cfd672e1d8a9..0474fd2ccc12 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} @@ -28,7 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, 
_} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} +import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData import org.apache.spark.util.{SignalLogger, Utils} private[spark] class MesosExecutorBackend @@ -55,7 +55,7 @@ private[spark] class MesosExecutorBackend slaveInfo: SlaveInfo) { // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. - val cpusPerTask = executorInfo.getResourcesList + val cpusPerTask = executorInfo.getResourcesList.asScala .find(_.getName == "cpus") .map(_.getScalar.getValue.toInt) .getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index a3b4561b07e7..42207a955359 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,11 +17,15 @@ package org.apache.spark.executor +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -210,10 +214,42 @@ class TaskMetrics extends Serializable { private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + // Get the hostname from cached data, since hostname is the order of number of nodes in + // cluster, so using cached hostname will decrease the object number and alleviate the GC + // overhead. + _hostname = TaskMetrics.getCachedHostName(_hostname) + } + + private var _accumulatorUpdates: Map[Long, Any] = Map.empty + @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + + private[spark] def updateAccumulators(): Unit = synchronized { + _accumulatorUpdates = _accumulatorsUpdater() + } + + /** + * Return the latest updates of accumulators in this task. 
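// Illustrative sketch (not part of the patch itself) of the string-interning idiom behind
// TaskMetrics.getCachedHostName above: ConcurrentHashMap.putIfAbsent returns the value
// already in the map (or null), so every deserialized TaskMetrics ends up sharing one
// canonical hostname String instead of allocating a new copy per task. Names are hypothetical.
import java.util.concurrent.ConcurrentHashMap

object HostNames {
  private val cache = new ConcurrentHashMap[String, String]()

  def intern(host: String): String = {
    val canonical = cache.putIfAbsent(host, host)
    if (canonical != null) canonical else host
  }
}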
+ */ + def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates + + private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { + _accumulatorsUpdater = accumulatorsUpdater + } } private[spark] object TaskMetrics { + private val hostNameCache = new ConcurrentHashMap[String, String]() + def empty: TaskMetrics = new TaskMetrics + + def getCachedHostName(host: String): String = { + val canonicalHost = hostNameCache.putIfAbsent(host, host) + if (canonicalHost != null) canonicalHost else host + } } /** diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index c219d21fbefa..532850dd5771 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil /** @@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat { } private[spark] class FixedLengthBinaryInputFormat - extends FileInputFormat[LongWritable, BytesWritable] { + extends FileInputFormat[LongWritable, BytesWritable] + with Logging { private var recordLength = -1 @@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) } if (recordLength <= 0) { - println("record length is less than 0, file cannot be split") + logDebug("record length is less than 0, file cannot be split") false } else { true diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 6cda7772f77b..e2ffc3b64e5d 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.input import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration @@ -44,12 +44,9 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum - - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong super.setMaxSplitSize(maxSplitSize) } @@ -134,8 +131,8 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat */ @Experimental class PortableDataStream( - @transient isplit: CombineFileSplit, - @transient context: TaskAttemptContext, + isplit: CombineFileSplit, + context: TaskAttemptContext, index: Integer) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala 
b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index aaef7c74eea3..1ba34a11414a 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -17,7 +17,7 @@ package org.apache.spark.input -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit @@ -52,10 +52,8 @@ private[spark] class WholeTextFileInputFormat * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0d8ac1f80a9f..9dc36704a676 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -63,8 +63,7 @@ private[spark] object CompressionCodec { def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) val codec = try { - val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) - .getConstructor(classOf[SparkConf]) + val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { case e: ClassNotFoundException => None @@ -149,7 +148,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { try { Snappy.getNativeLibraryVersion } catch { - case e: Error => throw new IllegalArgumentException + case e: Error => throw new IllegalArgumentException(e) } override def compressedOutputStream(s: OutputStream): OutputStream = { diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 9be98723aed1..0c096656f923 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.launcher import java.io.File import java.util.{HashMap => JHashMap, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.deploy.Command @@ -32,7 +32,7 @@ import org.apache.spark.deploy.Command private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, command: Command) extends AbstractCommandBuilder { - childEnv.putAll(command.environment) + childEnv.putAll(command.environment.asJava) childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sparkHome) override def buildCommand(env: JMap[String, String]): JList[String] = { diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 818f7a4c8d42..f7298e8d5c62 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ 
b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -24,8 +24,10 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.util.{Utils => SparkUtils} private[spark] trait SparkHadoopMapRedUtil { @@ -64,10 +66,10 @@ trait SparkHadoopMapRedUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + SparkUtils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + SparkUtils.classForName(second) } } } @@ -89,10 +91,9 @@ object SparkHadoopMapRedUtil extends Logging { committer: MapReduceOutputCommitter, mrTaskContext: MapReduceTaskAttemptContext, jobId: Int, - splitId: Int, - attemptId: Int): Unit = { + splitId: Int): Unit = { - val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) // Called after we have decided to commit def performCommit(): Unit = { @@ -120,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + val taskAttemptNumber = TaskContext.get().attemptNumber() + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -130,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, attemptId) + throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination @@ -141,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") } } - - def commitTask( - committer: MapReduceOutputCommitter, - mrTaskContext: MapReduceTaskAttemptContext, - sparkTaskContext: TaskContext): Unit = { - commitTask( - committer, - mrTaskContext, - sparkTaskContext.stageId(), - sparkTaskContext.partitionId(), - sparkTaskContext.attemptNumber()) - } } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 390d148bc97f..943ebcb7bd0a 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} +import org.apache.spark.util.Utils private[spark] trait SparkHadoopMapReduceUtil { @@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil { isMap: Boolean, taskId: Int, attemptId: Int): TaskAttemptID = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") + val klass = 
Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap // (not available in YARN) @@ -57,7 +58,7 @@ trait SparkHadoopMapReduceUtil { } catch { case exc: NoSuchMethodException => { // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( taskTypeClass, if (isMap) "MAP" else "REDUCE") @@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + Utils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + Utils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index d7495551ad23..dd2d325d8703 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -20,6 +20,7 @@ package org.apache.spark.metrics import java.io.{FileInputStream, InputStream} import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex @@ -58,25 +59,20 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging { propertyCategories = subProperties(properties, INSTANCE_REGEX) if (propertyCategories.contains(DEFAULT_PREFIX)) { - import scala.collection.JavaConversions._ - - val defaultProperty = propertyCategories(DEFAULT_PREFIX) - for { (inst, prop) <- propertyCategories - if (inst != DEFAULT_PREFIX) - (k, v) <- defaultProperty - if (prop.getProperty(k) == null) } { - prop.setProperty(k, v) + val defaultProperty = propertyCategories(DEFAULT_PREFIX).asScala + for((inst, prop) <- propertyCategories if (inst != DEFAULT_PREFIX); + (k, v) <- defaultProperty if (prop.get(k) == null)) { + prop.put(k, v) } } } def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { val subProperties = new mutable.HashMap[String, Properties] - import scala.collection.JavaConversions._ - prop.foreach { kv => - if (regex.findPrefixOf(kv._1).isDefined) { - val regex(prefix, suffix) = kv._1 - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) + prop.asScala.foreach { kv => + if (regex.findPrefixOf(kv._1.toString).isDefined) { + val regex(prefix, suffix) = kv._1.toString + subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2.toString) } } subProperties diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index ed5131c79fdc..48afe3ae3511 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,6 +20,8 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit +import org.apache.spark.util.Utils + import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} @@ -86,7 +88,7 @@ private[spark] class MetricsSystem private ( */ def getServletHandlers: Array[ServletContextHandler] = { 
require(running, "Can only call getServletHandlers on a running MetricsSystem") - metricsServlet.map(_.getHandlers).getOrElse(Array()) + metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) } metricsConfig.initialize() @@ -140,6 +142,9 @@ private[spark] class MetricsSystem private ( } else { defaultName } } + def getSourcesByName(sourceName: String): Seq[Source] = + sources.filter(_.sourceName == sourceName) + def registerSource(source: Source) { sources += source try { @@ -166,7 +171,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Class.forName(classPath).newInstance() + val source = Utils.classForName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) @@ -182,7 +187,7 @@ private[spark] class MetricsSystem private ( val classPath = kv._2.getProperty("class") if (null != classPath) { try { - val sink = Class.forName(classPath) + val sink = Utils.classForName(classPath) .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 0c2e212a3307..4193e1d21d3c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.SecurityManager +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( @@ -49,10 +49,10 @@ private[spark] class MetricsServlet( val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers: Array[ServletContextHandler] = { + def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) ) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b089da8596e2..76968249fb62 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager @@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} * is equivalent to one Spark-level shuffle block. 
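// Rough sketch of the reflective-loading idiom that the Class.forName -> Utils.classForName
// replacements in this patch converge on: resolve the class name against the thread's
// context class loader when one is set, falling back to the caller's loader, so that
// user-configured sources, sinks and codecs load under custom class loaders. The helper
// name and the exact fallback below are illustrative, not Spark's implementation verbatim.
object ReflectiveLoad {
  def loadClass(className: String): Class[_] = {
    val loader = Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)
    Class.forName(className, true, loader)
  }
  // e.g. the shape of MetricsSystem.registerSources for each configured "class" property:
  //   val source = loadClass(classPath).newInstance().asInstanceOf[Source]
}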
*/ class NettyBlockRpcServer( + appId: String, serializer: Serializer, blockManager: BlockDataManager) extends RpcHandler with Logging { @@ -55,7 +56,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index d650d5fe7308..70a42f9045e6 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,7 +17,7 @@ package org.apache.spark.network.netty -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} import org.apache.spark.{SecurityManager, SparkConf} @@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { @@ -58,7 +58,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage securityManager.isSaslEncryptionEnabled())) } transportContext = new TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory(clientBootstrap.toList) + clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId logInfo("Server created on " + server.getPort) @@ -67,7 +67,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage /** Creates and binds the TransportServer, possibly trying multiple ports. 
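// Minimal sketch of the scala.collection.JavaConversions -> JavaConverters migration made
// throughout this patch: cross-language collection conversions become explicit .asScala /
// .asJava calls (as in clientBootstrap.toSeq.asJava above) rather than silent implicits.
// The sample values are illustrative only.
import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConverters._

object ConvertersDemo extends App {
  val javaList: JList[String] = new JArrayList[String]()
  javaList.add("shuffle_0_0_0")

  val scalaSeq: Seq[String] = javaList.asScala.toSeq   // Java -> Scala, explicitly
  val backToJava: JList[String] = scalaSeq.asJava      // Scala -> Java, explicitly
  println(backToJava)
}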
*/ private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { - val server = transportContext.createServer(port, bootstraps) + val server = transportContext.createServer(port, bootstraps.asJava) (server, server.getPort) } @@ -137,7 +137,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") - result.success() + result.success((): Unit) } override def onFailure(e: Throwable): Unit = { logError(s"Error while uploading block $blockId", e) @@ -149,7 +149,11 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage } override def close(): Unit = { - server.close() - clientFactory.close() + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala deleted file mode 100644 index 67a376102994..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -// private[spark] because we need to register them in Kryo -private[spark] case class GetBlock(id: BlockId) -private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) -private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) - -private[nio] class BlockMessage() { - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: BlockId = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - typ = buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = BlockId(idBuilder.toString) - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } - - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType: Int = typ - def getId: BlockId = id - def getData: ByteBuffer = data - def getLevel: StorageLevel = level - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) - buffer.putInt(typ).putInt(id.name.length) - id.name.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } - - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - Message.createBufferMessage(buffers) - } - - override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -private[nio] object 
BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala deleted file mode 100644 index 7d0806f0c258..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark._ -import org.apache.spark.storage.{StorageLevel, TestBlockId} - -import scala.collection.mutable.ArrayBuffer - -private[nio] -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) - extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int): BlockMessage = blockMessages(i) - - def iterator: Iterator[BlockMessage] = blockMessages.iterator - - def length: Int = blockMessages.length - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - while (buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + - (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - Message.createBufferMessage(buffers) - } -} - -private[nio] object BlockMessageArray { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map { i => - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear() - BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, - StorageLevel.MEMORY_ONLY_SER)) - } else { - BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) - } - } - val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() - bufferMessage.buffers.foreach(buffer => { - assert (0 == buffer.position()) - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray 
= BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - println(gB) - } - } - }) - } -} - - diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala deleted file mode 100644 index 9a9e22b0c236..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.storage.BlockManager - - -private[nio] -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) - extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size: Int = initialSize - - def currentSize(): Int = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - val security = if (isSecurityNeg) 1 else 0 - if (size == 0 && !gotChunkForSendingOnce) { - val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, - hasError, security, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - val security = if (isSecurityNeg) 1 else 0 - if (buffer.remaining > 
0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId(): Boolean = ackId != 0 - - def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - - override def toString: String = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala deleted file mode 100644 index 1499da07bb83..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ /dev/null @@ -1,619 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.LinkedList - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.control.NonFatal - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} - -private[nio] -abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, - val securityMgr: SecurityManager) - extends Logging { - - var sparkSaslServer: SparkSaslServer = null - var sparkSaslClient: SparkSaslClient = null - - def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId, - securityMgr_ : SecurityManager) = { - this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), - id_, securityMgr_) - } - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /* channel.socket.setReceiveBufferSize(32768) */ - - @volatile private var closed = false - var onCloseCallback: Connection => Unit = null - val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - val remoteAddress = getRemoteAddress() - - def isSaslComplete(): Boolean - - def resetForceReregister(): Boolean - - // Read channels typically do not register for write and write does not for read - // Now, we do have write registering for read too (temporarily), but this is to detect - // channel close NOT to actually read/consume data on it ! - // How does this work if/when we move to SSL ? - - // What is the interest to register with selector for when we want this connection to be selected - def registerInterest() - - // What is the interest to register with selector for when we want this connection to - // be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, - // it will be SelectionKey.OP_READ (until we fix it properly) - def unregisterInterest() - - // On receiving a read event, should we change the interest for this channel or not ? - // Will be true for ReceivingConnection, false for SendingConnection. - def changeInterestForRead(): Boolean - - private def disposeSasl() { - if (sparkSaslServer != null) { - sparkSaslServer.dispose() - } - - if (sparkSaslClient != null) { - sparkSaslClient.dispose() - } - } - - // On receiving a write event, should we change the interest for this channel or not ? - // Will be false for ReceivingConnection, true for SendingConnection. - // Actually, for now, should not get triggered for ReceivingConnection - def changeInterestForWrite(): Boolean - - def getRemoteConnectionManagerId(): ConnectionManagerId = { - socketRemoteConnectionManagerId - } - - def key(): SelectionKey = channel.keyFor(selector) - - def getRemoteAddress(): InetSocketAddress = { - channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - } - - // Returns whether we have to register for further reads or not. - def read(): Boolean = { - throw new UnsupportedOperationException( - "Cannot read on connection of type " + this.getClass.toString) - } - - // Returns whether we have to register for further writes or not. 
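// Standalone sketch (hypothetical, not from this patch) of the selector bookkeeping the
// removed Connection classes above describe: a non-blocking SocketChannel is registered
// with a Selector once, and read/write readiness is toggled by rewriting the key's
// interest ops rather than re-registering the channel.
import java.nio.channels.{SelectionKey, Selector, SocketChannel}

object SelectorSketch extends App {
  val selector = Selector.open()
  val channel = SocketChannel.open()
  channel.configureBlocking(false)
  channel.socket.setTcpNoDelay(true)

  val key = channel.register(selector, SelectionKey.OP_READ)      // default interest: reads
  key.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE)   // also wake up when writable
  key.interestOps(SelectionKey.OP_READ)                           // back to read-only interest

  channel.close()
  selector.close()
}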
- def write(): Boolean = { - throw new UnsupportedOperationException( - "Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - closed = true - val k = key() - if (k != null) { - k.cancel() - } - channel.close() - disposeSasl() - callOnCloseCallback() - } - - protected def isClosed: Boolean = closed - - def onClose(callback: Connection => Unit) { - onCloseCallback = callback - } - - def onException(callback: (Connection, Throwable) => Unit) { - onExceptionCallbacks.add(callback) - } - - def onKeyInterestChange(callback: (Connection, Int) => Unit) { - onKeyInterestChangeCallback = callback - } - - def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks foreach { - callback => - try { - callback(this, e) - } catch { - case NonFatal(e) => { - logWarning("Ignored error in onExceptionCallback", e) - } - } - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + getRemoteConnectionManagerId() + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.length + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - buffer.position(curPosition) - } -} - - -private[nio] -class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId, id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslClient != null) sparkSaslClient.isComplete() else false - } - - private class Outbox { - val messages = new LinkedList[Message]() - val defaultChunkSize = 65536 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message) { - messages.synchronized { - messages.add(message) - logDebug("Added [" + message + "] to outbox for sending to " + - "[" + getRemoteConnectionManagerId() + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - messages.synchronized { - while (!messages.isEmpty) { - /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /* val message = messages(nextMessageToBeUsed) */ - - val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) { - // only allow sending of security messages until sasl is complete - var pos = 0 - var securityMsg: Message = null - while (pos < messages.size() && securityMsg == null) { - if (messages.get(pos).isSecurityNeg) { - securityMsg = messages.remove(pos) - } - pos = pos + 1 - } - // didn't find any security messages and auth isn't completed so return - if (securityMsg == null) return None - securityMsg - } else { - messages.removeFirst() - } - - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.add(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug( - 
"Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace( - "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - } - - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which - // lock it are invoked in context of other locks) - private val outbox = new Outbox() - /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we - do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of - data as detailed in https://github.com/mesos/spark/pull/791 - */ - private var needForceReregister = false - - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /* channel.socket.setSendBufferSize(256 * 1024) */ - - override def getRemoteAddress(): InetSocketAddress = address - - val DEFAULT_INTEREST = SelectionKey.OP_READ - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(DEFAULT_INTEREST) - } - - def registerAfterAuth(): Unit = { - outbox.synchronized { - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - // return previous value after resetting it. - def resetForceReregister(): Boolean = { - outbox.synchronized { - val result = needForceReregister - needForceReregister = false - result - } - } - - // MUST be called within the selector loop - def connect() { - try { - channel.register(selector, SelectionKey.OP_CONNECT) - channel.connect(address) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => - logError("Error connecting to " + address, e) - callOnExceptionCallbacks(e) - } - } - - def finishConnect(force: Boolean): Boolean = { - try { - // Typically, this should finish immediately since it was triggered by a connect - // selection - though need not necessarily always complete successfully. 
- val connected = channel.finishConnect - if (!force && !connected) { - logInfo( - "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") - return false - } - - // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times - // (10 or so) - // Is highly unlikely unless there was an unclean close of socket, etc - registerInterest() - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallbacks(e) - } - } - true - } - - override def write(): Boolean = { - try { - while (true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk() match { - case Some(chunk) => { - val buffers = chunk.buffers - // If we have 'seen' pending messages, then reset flag - since we handle that as - // normal registering of event (below) - if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() - - currentBuffers ++= buffers - } - case None => { - // changeConnectionKeyInterest(0) - /* key.interestOps(0) */ - return false - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - // re-register for write. - return true - } - } - } - } catch { - case e: Exception => { - logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - // This is a hack to determine if remote socket was closed or not. - // SendingConnection DOES NOT expect to receive any data - if it does, it is an error - // For a bunch of cases, read will return -1 in case remote socket is closed : hence we - // register for reads to determine that. - override def read(): Boolean = { - // We don't expect the other side to send anything; so, we just read to detect an error or EOF. - try { - val length = channel.read(ByteBuffer.allocate(1)) - if (length == -1) { // EOF - close() - } else if (length > 0) { - logWarning( - "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) - } - } catch { - case e: Exception => - logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), - e) - callOnExceptionCallbacks(e) - close() - } - - false - } - - override def changeInterestForRead(): Boolean = false - - override def changeInterestForWrite(): Boolean = ! 
isClosed -} - - -// Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection( - channel_ : SocketChannel, - selector_ : Selector, - id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(channel_, selector_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslServer != null) sparkSaslServer.isComplete() else false - } - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - newMessage.isSecurityNeg = header.securityNeg == 1 - logDebug( - "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace( - "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - @volatile private var inferredRemoteManagerId: ConnectionManagerId = null - - override def getRemoteConnectionManagerId(): ConnectionManagerId = { - val currId = inferredRemoteManagerId - if (currId != null) currId else super.getRemoteConnectionManagerId() - } - - // The receiver's remote address is the local socket on remote side : which is NOT - // the connection manager id of the receiver. - // We infer that from the messages we receive on the receiver socket. - private def processConnectionManagerId(header: MessageChunkHeader) { - val currId = inferredRemoteManagerId - if (header.address == null || currId != null) return - - val managerId = ConnectionManagerId.fromSocketAddress(header.address) - - if (managerId != null) { - inferredRemoteManagerId = managerId - } - } - - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection, Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read(): Boolean = { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return false - } - if (headerBuffer.remaining > 0) { - // re-register for read event ... - return true - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception( - "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - - processConnectionManagerId(header) - - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - // re-register for read event ... 
- return true - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - // re-register for read event ... - return true - } else if (bytesRead == -1) { - close() - return false - } - - /* logDebug("Read " + bytesRead + " bytes for the buffer") */ - - if (currentChunk.buffer.remaining == 0) { - /* println("Filled buffer at " + System.currentTimeMillis) */ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip() - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from " + - "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} - - // override def changeInterestForRead(): Boolean = ! isClosed - override def changeInterestForRead(): Boolean = true - - override def changeInterestForWrite(): Boolean = { - throw new IllegalStateException("Unexpected invocation right now") - } - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_READ) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(0) - } - - // For read conn, always false. - override def resetForceReregister(): Boolean = false -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala deleted file mode 100644 index c0bca2c4bc99..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ /dev/null @@ -1,1153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.lang.ref.WeakReference -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} -import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future, Promise} -import scala.language.postfixOps - -import com.google.common.base.Charsets.UTF_8 -import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.{ThreadUtils, Utils} - -import scala.util.Try -import scala.util.control.NonFatal - -private[nio] class ConnectionManager( - port: Int, - conf: SparkConf, - securityManager: SecurityManager, - name: String = "Connection manager") - extends Logging { - - /** - * Used by sendMessageReliably to track messages being sent. - * @param message the message that was sent - * @param connectionManagerId the connection manager that sent this message - * @param completionHandler callback that's invoked when the send has completed or failed - */ - class MessageStatus( - val message: Message, - val connectionManagerId: ConnectionManagerId, - completionHandler: Try[Message] => Unit) { - - def success(ackMessage: Message) { - if (ackMessage == null) { - failure(new NullPointerException) - } - else { - completionHandler(scala.util.Success(ackMessage)) - } - } - - def failWithoutAck() { - completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) - } - - def failure(e: Throwable) { - completionHandler(scala.util.Failure(e)) - } - } - - private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = - new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) - - private val ackTimeout = - conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", - conf.get("spark.network.timeout", "120s")) - - // Get the thread counts from the Spark Configuration. - // - // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, - // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is - // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" - // parameter is necessary. 
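The deleted comment above explains why only the `*.threads.min` settings are read: with an unbounded `LinkedBlockingDeque`, a `ThreadPoolExecutor` never grows past `corePoolSize`. A minimal sketch of that JDK behaviour (pool sizes and names here are illustrative, not Spark configuration):

```scala
import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

object UnboundedQueuePoolDemo {
  def main(args: Array[String]): Unit = {
    // core = 2, max = 8: with an unbounded queue the pool never grows past core.
    val pool = new ThreadPoolExecutor(
      2, 8, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable]())
    (1 to 6).foreach { _ =>
      pool.execute(new Runnable {
        override def run(): Unit = Thread.sleep(500)
      })
    }
    Thread.sleep(100)
    println(s"pool size = ${pool.getPoolSize}") // prints 2: extra tasks just queue up
    pool.shutdown()
  }
}
```

Even with six tasks submitted and a maximum of 8 threads, the pool stays at 2, which is why the executors below pass the same value for both the core and maximum sizes.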
- private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) - private val ioThreadCount = conf.getInt("spark.core.connection.io.threads.min", 4) - private val connectThreadCount = conf.getInt("spark.core.connection.connect.threads.min", 1) - - private val handleMessageExecutor = new ThreadPoolExecutor( - handlerThreadCount, - handlerThreadCount, - conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-message-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleMessageExecutor is not handled properly", t) - } - } - } - - private val handleReadWriteExecutor = new ThreadPoolExecutor( - ioThreadCount, - ioThreadCount, - conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-read-write-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleReadWriteExecutor is not handled properly", t) - } - } - } - - // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : - // which should be executed asap - private val handleConnectExecutor = new ThreadPoolExecutor( - connectThreadCount, - connectThreadCount, - conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-connect-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleConnectExecutor is not handled properly", t) - } - } - } - - private val serverChannel = ServerSocketChannel.open() - // used to track the SendingConnections waiting to do SASL negotiation - private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] - with SynchronizedMap[ConnectionId, SendingConnection] - private val connectionsByKey = - new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] - with SynchronizedMap[ConnectionManagerId, SendingConnection] - // Tracks sent messages for which we are awaiting acknowledgements. 
Entries are added to this - // map when messages are sent and are removed when acknowledgement messages are received or when - // acknowledgement timeouts expire - private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] - private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - private val registerRequests = new SynchronizedQueue[SendingConnection] - - implicit val futureExecContext = ExecutionContext.fromExecutor( - ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) - - @volatile - private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null - - private val authEnabled = securityManager.isAuthenticationEnabled() - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - private def startService(port: Int): (ServerSocketChannel, Int) = { - serverChannel.socket.bind(new InetSocketAddress(port)) - (serverChannel, serverChannel.socket.getLocalPort) - } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - // used in combination with the ConnectionManagerId to create unique Connection ids - // to be able to track asynchronous messages - private val idCount: AtomicInteger = new AtomicInteger(1) - - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - @volatile private var isActive = true - private val selectorThread = new Thread("connection-manager-thread") { - override def run(): Unit = ConnectionManager.this.run() - } - selectorThread.setDaemon(true) - // start this thread last, since it invokes run(), which accesses members above - selectorThread.start() - - private def triggerWrite(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - writeRunnableStarted.synchronized { - // So that we do not trigger more write events while processing this one. - // The write method will re-register when done. - if (conn.changeInterestForWrite()) conn.unregisterInterest() - if (writeRunnableStarted.contains(key)) { - // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) - return - } - - writeRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - - private def triggerRead(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - readRunnableStarted.synchronized { - // So that we do not trigger more read events while processing this one. - // The read method will re-register when done. 
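The `triggerWrite`/`triggerRead` methods above keep at most one handler in flight per selection key: the key is recorded in a synchronized set before work is handed to the executor and removed in a `finally` block afterwards. A hedged, simplified sketch of that pattern with hypothetical names:

```scala
import java.util.concurrent.Executors
import scala.collection.mutable

object OneHandlerPerKey {
  private val inFlight = mutable.HashSet[String]()
  private val executor = Executors.newFixedThreadPool(4)

  def trigger(key: String)(handler: => Unit): Unit = {
    // add() returns false if a handler for this key is already queued or running
    val alreadyScheduled = inFlight.synchronized { !inFlight.add(key) }
    if (!alreadyScheduled) {
      executor.execute(new Runnable {
        override def run(): Unit =
          try handler
          finally inFlight.synchronized { inFlight -= key } // allow the next trigger
      })
    }
  }
}
```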
- if (conn.changeInterestForRead())conn.unregisterInterest() - if (readRunnableStarted.contains(key)) { - return - } - - readRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - private def triggerConnect(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] - if (conn == null) return - - // prevent other events from being triggered - // Since we are still trying to connect, we do not need to do the additional steps in - // triggerWrite - conn.changeConnectionKeyInterest(0) - - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 - } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) - } catch { - case NonFatal(e) => { - logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - // MUST be called within selector loop - else deadlock. - private def triggerForceCloseByException(key: SelectionKey, e: Exception) { - try { - key.interestOps(0) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - // Pushing to connect threadpool - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - conn.callOnExceptionCallbacks(e) - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - try { - conn.close() - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - } - }) - } - - - def run() { - try { - while (isActive) { - while (!registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue() - addListeners(conn) - conn.connect() - addConnection(conn) - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue() - - try { - if (key.isValid) { - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. 
- if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + - connection.getRemoteConnectionManagerId() + "] changed from [" + - intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - - val selectedKeysCount = - try { - selector.select() - } catch { - // Explicitly only dealing with CancelledKeyException here since other exceptions - // should be dealt with differently. - case e: CancelledKeyException => - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - 0 - - case e: ClosedSelectorException => - logDebug("Failed select() as selector is closed.", e) - return - } - - if (selectedKeysCount == 0) { - logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + - " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } - - if (0 != selectedKeysCount) { - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, - // key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - - var newChannel = serverChannel.accept() - - // accept them all in a tight loop. 
non blocking accept with no processing, should be fine - while (newChannel != null) { - try { - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId, - securityManager) - newConnection.onReceive(receiveMessage) - addListeners(newConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress + "]") - } catch { - // might happen in case of issues with registering with selector - case e: Exception => logError("Error in accept loop", e) - } - - newChannel = serverChannel.accept() - } - } - - private def addListeners(connection: Connection) { - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - } - - def removeConnection(connection: Connection) { - connectionsByKey -= connection.key - - try { - connection match { - case sendingConnection: SendingConnection => - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - connectionsAwaitingSasl -= connection.connectionId - - messageStatuses.synchronized { - messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) - .foreach(status => { - logInfo("Notifying " + status) - status.failWithoutAck() - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case receivingConnection: ReceivingConnection => - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (!sendingConnectionOpt.isDefined) { - logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") - return - } - - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() - - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - - assert(sendingConnectionManagerId == remoteConnectionManagerId) - - messageStatuses.synchronized { - for (s <- messageStatuses.values - if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.failWithoutAck() - } - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case _ => logError("Unsupported type of connection.") - } - } finally { - // So that the selection keys can be removed. - wakeupSelector() - } - } - - def handleConnectionError(connection: Connection, e: Throwable) { - logInfo("Handling connection error on connection to " + - connection.getRemoteConnectionManagerId()) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - // so that registrations happen ! 
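The accept loop above relies on a non-blocking `ServerSocketChannel` returning `null` once its backlog is drained, so a single `OP_ACCEPT` wake-up can accept every queued connection. A small sketch of the same drain pattern (helper names are made up for illustration):

```scala
import java.net.InetSocketAddress
import java.nio.channels.{ServerSocketChannel, SocketChannel}

object AcceptDrainSketch {
  // Accept until the non-blocking channel reports nothing pending (returns null).
  def drain(server: ServerSocketChannel)(register: SocketChannel => Unit): Int = {
    var accepted = 0
    var ch = server.accept()
    while (ch != null) {
      register(ch)            // hand off; no per-connection processing in the loop
      accepted += 1
      ch = server.accept()
    }
    accepted
  }

  def main(args: Array[String]): Unit = {
    val server = ServerSocketChannel.open()
    server.configureBlocking(false)
    server.socket.bind(new InetSocketAddress(0))
    println(s"drained ${drain(server)(_.close())} pending connections") // likely 0 here
    server.close()
  }
}
```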
- wakeupSelector() - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - try { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } catch { - case NonFatal(e) => { - logError("Error when handling messages from " + - connection.getRemoteConnectionManagerId(), e) - connection.callOnExceptionCallbacks(e) - } - } - } - } - handleMessageExecutor.execute(runnable) - /* handleMessage(connection, message) */ - } - - private def handleClientAuthentication( - waitingConn: SendingConnection, - securityMsg: SecurityMessage, - connectionId : ConnectionId) { - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } else { - var replyToken : Array[Byte] = null - try { - replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId.toString) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new IOException("Error creating security message") - sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => - logError("Error handling sasl client authentication", e) - waitingConn.close() - throw new IOException("Error evaluating sasl response: ", e) - } - } - } - - private def handleServerAuthentication( - connection: Connection, - securityMsg: SecurityMessage, - connectionId: ConnectionId) { - if (!connection.isSaslComplete()) { - logDebug("saslContext not established") - var replyToken : Array[Byte] = null - try { - connection.synchronized { - if (connection.sparkSaslServer == null) { - logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) - } - } - replyToken = connection.sparkSaslServer.response(securityMsg.getToken) - if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId + - " for: " + connectionId) - } else { - logDebug("Server sasl not completed: " + connection.connectionId + - " for: " + connectionId) - } - if (replyToken != null) { - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security Message") - sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) - } - } catch { - case e: Exception => { - logError("Error in server auth negotiation: " + e) - // It would probably be better to send an error message telling other side auth failed - // but for now just close - connection.close() - } - } - } else { - logDebug("connection already established for this connection id: " + 
connection.connectionId) - } - } - - - private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { - if (bufferMessage.isSecurityNeg) { - logDebug("This is security neg message") - - // parse as SecurityMessage - val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) - val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) - - connectionsAwaitingSasl.get(connectionId) match { - case Some(waitingConn) => { - // Client - this must be in response to us doing Send - logDebug("Client handleAuth for id: " + waitingConn.connectionId) - handleClientAuthentication(waitingConn, securityMsg, connectionId) - } - case None => { - // Server - someone sent us something and we haven't authenticated yet - logDebug("Server handleAuth for id: " + connectionId) - handleServerAuthentication(conn, securityMsg, connectionId) - } - } - return true - } else { - if (!conn.isSaslComplete()) { - // We could handle this better and tell the client we need to do authentication - // negotiation, but for now just ignore them. - logError("message sent that is not security negotiation message on connection " + - "not authenticated yet, ignoring it!!") - return true - } - } - false - } - - private def handleMessage( - connectionManagerId: ConnectionManagerId, - message: Message, - connection: Connection) { - logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (authEnabled) { - val res = handleAuthentication(connection, bufferMessage) - if (res) { - // message was security negotiation so skip the rest - logDebug("After handleAuth result was true, returning") - return - } - } - if (bufferMessage.hasAckId()) { - messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status.success(message) - } - case None => { - /** - * We can fall down on this code because of following 2 cases - * - * (1) Invalid ack sent due to buggy code. 
- * - * (2) Late-arriving ack for a SendMessageStatus - * To avoid unwilling late-arriving ack - * caused by long pause like GC, you can set - * larger value than default to spark.core.connection.ack.wait.timeout - */ - logWarning(s"Could not find reference for received ack Message ${message.id}") - } - } - } - } else { - var ackMessage : Option[Message] = None - try { - ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - } catch { - case e: Exception => { - logError(s"Exception was thrown while processing message", e) - ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) - } - } finally { - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - } - case _ => throw new Exception("Unknown type message received") - } - } - - private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { - // see if we need to do sasl before writing - // this should only be the first negotiation as the Client!!! - if (!conn.isSaslComplete()) { - conn.synchronized { - if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) - var firstResponse: Array[Byte] = null - try { - firstResponse = conn.sparkSaslClient.firstToken() - val securityMsg = SecurityMessage.fromResponse(firstResponse, - conn.connectionId.toString()) - val message = securityMsg.toBufferMessage - if (message == null) throw new Exception("Error creating security message") - connectionsAwaitingSasl += ((conn.connectionId, conn)) - sendSecurityMessage(connManagerId, message) - logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId + - " to: " + connManagerId) - } catch { - case e: Exception => { - logError("Error getting first response from the SaslClient.", e) - conn.close() - throw new Exception("Error getting first response from the SaslClient") - } - } - } - } - } else { - logDebug("Sasl already established ") - } - } - - // allow us to add messages to the inbox for doing sasl negotiating - private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, - newConnectionId, securityManager) - logInfo("creating new sending connection for security! " + newConnectionId ) - registerRequests.enqueue(newConnection) - - newConnection - } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? - // We did not find it useful in our test-env ... - // If we do re-add it, we should consistently use it everywhere I guess ? 
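`handleMessage` above correlates acks with senders through the message id: each send registers a status keyed by its id, an incoming ack carries that id back so the right waiter is completed, and unknown or late acks are only logged. A hedged sketch of that correlation using a promise-backed map (names are illustrative, not the `MessageStatus` API):

```scala
import scala.collection.mutable
import scala.concurrent.Promise

object AckCorrelationSketch {
  private val pending = mutable.HashMap[Int, Promise[String]]()

  // Called when a message is sent: remember who is waiting for its ack.
  def registerSend(messageId: Int): Promise[String] = pending.synchronized {
    val p = Promise[String]()
    pending += messageId -> p
    p
  }

  // Called when an ack arrives: the ack carries the original message id back.
  def onAck(ackId: Int, payload: String): Unit = pending.synchronized {
    pending.remove(ackId) match {
      case Some(p) => p.trySuccess(payload)
      case None    => println(s"no pending status for ack $ackId (late or duplicate)")
    }
  }
}
```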
- message.senderAddress = id.toSocketAddress() - logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") - val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) - - // send security message until going connection has been authenticated - connection.send(message) - - wakeupSelector() - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, - connectionManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, - newConnectionId, securityManager) - newConnection.onException { - case (conn, e) => { - logError("Exception while sending message.", e) - reportSendingMessageFailure(message.id, e) - } - } - logTrace("creating new sending connection: " + newConnectionId) - registerRequests.enqueue(newConnection) - - newConnection - } - val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - - message.senderAddress = id.toSocketAddress() - logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + - "connectionid: " + connection.connectionId) - - if (authEnabled) { - try { - checkSendAuthFirst(connectionManagerId, connection) - } catch { - case NonFatal(e) => { - reportSendingMessageFailure(message.id, e) - } - } - } - logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - connection.send(message) - wakeupSelector() - } - - private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { - // need to tell sender it failed - messageStatuses.synchronized { - val s = messageStatuses.get(messageId) - s match { - case Some(msgStatus) => { - messageStatuses -= messageId - logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.failure(e) - } - case None => { - logError("no messageStatus for failed message id: " + messageId) - } - } - } - } - - private def wakeupSelector() { - selector.wakeup() - } - - /** - * Send a message and block until an acknowledgment is received or an error occurs. - * @param connectionManagerId the message's destination - * @param message the message being sent - * @return a Future that either returns the acknowledgment message or captures an exception. - */ - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Message] = { - val promise = Promise[Message]() - - // It's important that the TimerTask doesn't capture a reference to `message`, which can cause - // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time - // at which they would originally be scheduled to run. Therefore, extract the message id - // from outside of the TimerTask closure (see SPARK-4393 for more context). 
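The comment above (SPARK-4393) describes the leak-avoidance trick implemented just below it in the diff: the timeout task captures only the message id and a `WeakReference` to the promise, never the message itself, so a cancelled task cannot pin a large payload until the wheel timer eventually drops it. A minimal sketch of the same idea, with a plain `ScheduledExecutorService` standing in for the `HashedWheelTimer`:

```scala
import java.io.IOException
import java.lang.ref.WeakReference
import java.util.concurrent.{Executors, TimeUnit}
import scala.concurrent.Promise

object AckTimeoutSketch {
  private val timer = Executors.newSingleThreadScheduledExecutor()

  def failAfter(messageId: Int, promise: Promise[String], timeoutSec: Long): Unit = {
    // Weak reference: the caller keeps the strong reference via promise.future,
    // so a cancelled or late task never keeps the promise (or its message) alive.
    val promiseRef = new WeakReference(promise)
    timer.schedule(new Runnable {
      override def run(): Unit = {
        val p = promiseRef.get()
        if (p != null) {
          p.tryFailure(new IOException(
            s"message $messageId was not acknowledged within $timeoutSec s"))
        }
      }
    }, timeoutSec, TimeUnit.SECONDS)
  }
}
```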
- val messageId = message.id - // Keep a weak reference to the promise so that the completed promise may be garbage-collected - val promiseReference = new WeakReference(promise) - val timeoutTask: TimerTask = new TimerTask { - override def run(timeout: Timeout): Unit = { - messageStatuses.synchronized { - messageStatuses.remove(messageId).foreach { s => - val e = new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec") - val p = promiseReference.get - if (p != null) { - // Attempt to fail the promise with a Timeout exception - if (!p.tryFailure(e)) { - // If we reach here, then someone else has already signalled success or failure - // on this promise, so log a warning: - logError("Ignore error because promise is completed", e) - } - } else { - // The WeakReference was empty, which should never happen because - // sendMessageReliably's caller should have a strong reference to promise.future; - logError("Promise was garbage collected; this should never happen!", e) - } - } - } - } - } - - val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) - - val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTaskHandle.cancel() - s match { - case scala.util.Failure(e) => - // Indicates a failure where we either never sent or never got ACK'd - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - case scala.util.Success(ackMessage) => - if (ackMessage.hasError) { - val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head - val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) - errorMsgByteBuf.get(errorMsgBytes) - val errorMsg = new String(errorMsgBytes, UTF_8) - val e = new IOException( - s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - } else { - if (!promise.trySuccess(ackMessage)) { - logWarning("Drop ackMessage because promise is completed") - } - } - } - }) - messageStatuses.synchronized { - messageStatuses += ((message.id, status)) - } - - sendMessage(connectionManagerId, message) - promise.future - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - isActive = false - ackTimeoutMonitor.stop() - selector.close() - selectorThread.interrupt() - selectorThread.join() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - handleReadWriteExecutor.shutdown() - handleConnectExecutor.shutdown() - logInfo("ConnectionManager stopped") - } -} - - -private[spark] object ConnectionManager { - import scala.concurrent.ExecutionContext.Implicits.global - - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - /* testSequentialSending(manager) */ - /* System.gc() */ - - /* testParallelSending(manager) */ - /* System.gc() */ - - /* testParallelDecreasingSending(manager) */ - /* System.gc() */ - - testContinuousSending(manager) - System.gc() - } - - def testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - 
println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms " + - "(" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count) { i => - val bufferLen = size * (i + 1) - val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte) - ByteBuffer.allocate(bufferLen).put(bufferContent) - } - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /* println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 
/ ms - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala deleted file mode 100644 index 85d2fe2bf9c2..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.InetSocketAddress -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Charsets.UTF_8 - -import org.apache.spark.util.Utils - -private[nio] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - var isSecurityNeg = false - var hasError = false - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString: String = { - this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" - } -} - - -private[nio] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId(): Int = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - /** - * Create a "negative acknowledgment" to notify a sender that an error occurred - * while processing its message. The exception's stacktrace will be formatted - * as a string, serialized into a byte array, and sent as the message payload. 
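`createErrorMessage` above ships a failure back to the sender as UTF-8 bytes in the ack payload with `hasError` set; the waiting side decodes the bytes into a string before failing its promise. A self-contained sketch of that round trip (object and method names are illustrative):

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

object ErrorAckRoundTrip {
  def encode(e: Exception): ByteBuffer =
    ByteBuffer.wrap(e.toString.getBytes(UTF_8)) // error text becomes the payload

  def decode(payload: ByteBuffer): String = {
    val bytes = new Array[Byte](payload.remaining())
    payload.get(bytes)
    new String(bytes, UTF_8)
  }

  def main(args: Array[String]): Unit =
    // prints: java.lang.IllegalStateException: remote handler failed
    println(decode(encode(new IllegalStateException("remote handler failed"))))
}
```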
- */ - def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { - val exceptionString = Utils.exceptionString(exception) - val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8)) - val errorMessage = createBufferMessage(serializedExceptionString, ackId) - errorMessage.hasError = true - errorMessage - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.hasError = header.hasError - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala deleted file mode 100644 index 7b3da4bb9d5e..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.{InetAddress, InetSocketAddress} -import java.nio.ByteBuffer - -private[nio] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val hasError: Boolean, - val securityNeg: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). - putInt(securityNeg). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString: String = { - "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg - } - -} - - -private[nio] object MessageChunkHeader { - val HEADER_SIZE = 45 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val hasError = buffer.get() != 0 - val securityNeg = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, - new InetSocketAddress(ip, port)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala deleted file mode 100644 index b2aec160635c..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} - -import scala.concurrent.Future - - -/** - * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom - * implementation using Java NIO. - */ -final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) - extends BlockTransferService with Logging { - - private var cm: ConnectionManager = _ - - private var blockDataManager: BlockDataManager = _ - - /** - * Port number the service is listening on, available only after [[init]] is invoked. - */ - override def port: Int = { - checkInit() - cm.id.port - } - - /** - * Host name the service is listening on, available only after [[init]] is invoked. - */ - override def hostName: String = { - checkInit() - cm.id.host - } - - /** - * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch - * local blocks or put local blocks. 
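`MessageChunkHeader` above writes its fields into a fixed-size `ByteBuffer` in a known order, flips it for sending, and `create` reads them back in the same order after checking the remaining size. The sketch below mirrors that put/flip/get discipline with a reduced, illustrative field set rather than the exact 45-byte Spark layout:

```scala
import java.nio.ByteBuffer

object HeaderRoundTrip {
  val HeaderSize = 21 // 8 (typ) + 4 (id) + 4 (totalSize) + 4 (chunkSize) + 1 (hasError)

  def write(typ: Long, id: Int, totalSize: Int, chunkSize: Int, hasError: Boolean): ByteBuffer =
    ByteBuffer.allocate(HeaderSize)
      .putLong(typ).putInt(id).putInt(totalSize).putInt(chunkSize)
      .put(if (hasError) 1.toByte else 0.toByte)
      .flip().asInstanceOf[ByteBuffer]           // ready for reading or sending

  def read(buf: ByteBuffer): (Long, Int, Int, Int, Boolean) = {
    require(buf.remaining == HeaderSize, "truncated header")
    (buf.getLong(), buf.getInt(), buf.getInt(), buf.getInt(), buf.get() != 0)
  }

  def main(args: Array[String]): Unit =
    println(read(write(typ = 1111111111L, id = 7, totalSize = 4096, chunkSize = 1024,
      hasError = false)))
}
```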
- */ - override def init(blockDataManager: BlockDataManager): Unit = { - this.blockDataManager = blockDataManager - cm = new ConnectionManager( - conf.getInt("spark.blockManager.port", 0), - conf, - securityManager, - "Connection manager for block manager") - cm.onReceiveMessage(onBlockMessageReceive) - } - - /** - * Tear down the transfer service. - */ - override def close(): Unit = { - if (cm != null) { - cm.stop() - } - } - - override def fetchBlocks( - host: String, - port: Int, - execId: String, - blockIds: Array[String], - listener: BlockFetchingListener): Unit = { - checkInit() - - val cmId = new ConnectionManagerId(host, port) - val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => - BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) - }) - - val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - - // Register the listener on success/failure future callback. - future.onSuccess { case message => - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - - // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. - if (blockMessageArray.isEmpty) { - blockIds.foreach { id => - listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) - } - } else { - for (blockMessage: BlockMessage <- blockMessageArray) { - val msgType = blockMessage.getType - if (msgType != BlockMessage.TYPE_GOT_BLOCK) { - if (blockMessage.getId != null) { - listener.onBlockFetchFailure(blockMessage.getId.toString, - new SparkException(s"Unexpected message $msgType received from $cmId")) - } - } else { - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - listener.onBlockFetchSuccess( - blockId.toString, new NioManagedBuffer(blockMessage.getData)) - } - } - } - }(cm.futureExecContext) - - future.onFailure { case exception => - blockIds.foreach { blockId => - listener.onBlockFetchFailure(blockId, exception) - } - }(cm.futureExecContext) - } - - /** - * Upload a single block to a remote node, available only after [[init]] is invoked. - * - * This call blocks until the upload completes, or throws an exception upon failures. 
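`fetchBlocks` above fans a single reply future out into per-block listener callbacks, and treats an empty reply as a failure for every requested id (the SPARK-4064 guard). A hedged sketch of that fan-out; the `FetchListener` trait and map-shaped reply are assumptions for illustration, not Spark's API:

```scala
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

trait FetchListener {
  def onSuccess(blockId: String, data: Array[Byte]): Unit
  def onFailure(blockId: String, cause: Throwable): Unit
}

object FanOutSketch {
  def deliver(
      blockIds: Seq[String],
      reply: Future[Map[String, Array[Byte]]],
      listener: FetchListener)(implicit ec: ExecutionContext): Unit = {
    reply.onComplete {
      case Success(results) if results.isEmpty =>
        // empty reply: fail every requested block rather than dropping the callbacks
        blockIds.foreach(id => listener.onFailure(id, new RuntimeException("empty response")))
      case Success(results) =>
        results.foreach { case (id, bytes) => listener.onSuccess(id, bytes) }
      case Failure(e) =>
        blockIds.foreach(id => listener.onFailure(id, e))
    }
  }
}
```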
- */ - override def uploadBlock( - hostname: String, - port: Int, - execId: String, - blockId: BlockId, - blockData: ManagedBuffer, - level: StorageLevel) - : Future[Unit] = { - checkInit() - val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) - val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) - val remoteCmId = new ConnectionManagerId(hostName, port) - val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) - reply.map(x => ())(cm.futureExecContext) - } - - private def checkInit(): Unit = if (cm == null) { - throw new IllegalStateException(getClass.getName + " has not been initialized") - } - - private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => - logError("Exception handling buffer message", e) - Some(Message.createErrorMessage(e, msg.id)) - } - - case otherMessage: Any => - val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" - logError(errorMsg) - Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) - } - } - - private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => - val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + msg + "]") - putBlock(msg.id, msg.data, msg.level) - None - - case BlockMessage.TYPE_GET_BLOCK => - val msg = new GetBlock(blockMessage.getId) - logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) - - case _ => None - } - } - - private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) - logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(blockId: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId) - logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer.nioByteBuffer() - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala deleted file mode 100644 index 232c552f9865..000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -import org.apache.spark._ - -/** - * SecurityMessage is class that contains the connectionId and sasl token - * used in SASL negotiation. SecurityMessage has routines for converting - * it to and from a BufferMessage so that it can be sent by the ConnectionManager - * and easily consumed by users when received. - * The api was modeled after BlockMessage. - * - * The connectionId is the connectionId of the client side. Since - * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * - * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side - * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This - * is where the connectionId field is used. node_0 can lookup the connectionId to see if - * it is in response to it being a client or if its in response to someone sending other data. - * - * The format of a SecurityMessage as its sent is: - * - Length of the ConnectionId - * - ConnectionId - * - Length of the token - * - Token - */ -private[nio] class SecurityMessage extends Logging { - - private var connectionId: String = null - private var token: Array[Byte] = null - - def set(byteArr: Array[Byte], newconnectionId: String) { - if (byteArr == null) { - token = new Array[Byte](0) - } else { - token = byteArr - } - connectionId = newconnectionId - } - - /** - * Read the given buffer and set the members of this class. 
- */ - def set(buffer: ByteBuffer) { - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - connectionId = idBuilder.toString() - - val tokenLength = buffer.getInt() - token = new Array[Byte](tokenLength) - if (tokenLength > 0) { - buffer.get(token, 0, tokenLength) - } - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getConnectionId: String = { - return connectionId - } - - def getToken: Array[Byte] = { - return token - } - - /** - * Create a BufferMessage that can be sent by the ConnectionManager containing - * the security information from this class. - * @return BufferMessage - */ - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes - // 4 bytes for the length of token - // token is a byte buffer so just take the length - var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) - buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) - buffer.putInt(token.length) - - if (token.length > 0) { - buffer.put(token) - } - buffer.flip() - buffers += buffer - - var message = Message.createBufferMessage(buffers) - logDebug("message total size is : " + message.size) - message.isSecurityNeg = true - return message - } - - override def toString: String = { - "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" - } -} - -private[nio] object SecurityMessage { - - /** - * Convert the given BufferMessage to a SecurityMessage by parsing the contents - * of the BufferMessage and populating the SecurityMessage fields. - * @param bufferMessage is a BufferMessage that was received - * @return new SecurityMessage - */ - def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(bufferMessage) - newSecurityMessage - } - - /** - * Create a SecurityMessage to send from a given saslResponse. 
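The wire format documented above is [id length][id as 2-byte chars][token length][token bytes], which is why `toBufferMessage` sizes its buffer as `4 + connectionId.length * 2 + 4 + token.length`. A compact encode/decode sketch of that framing (illustrative names, not the `SecurityMessage` API):

```scala
import java.nio.ByteBuffer

object SecurityFrame {
  def encode(connectionId: String, token: Array[Byte]): ByteBuffer = {
    // 4 bytes for the id length, 2 bytes per id char, 4 bytes for the token length
    val buf = ByteBuffer.allocate(4 + connectionId.length * 2 + 4 + token.length)
    buf.putInt(connectionId.length)
    connectionId.foreach(c => buf.putChar(c))
    buf.putInt(token.length)
    buf.put(token)
    buf.flip()
    buf
  }

  def decode(buf: ByteBuffer): (String, Array[Byte]) = {
    val idLen = buf.getInt()
    val id = new StringBuilder
    (1 to idLen).foreach(_ => id += buf.getChar())
    val token = new Array[Byte](buf.getInt())
    buf.get(token)
    (id.toString, token)
  }
}
```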
- * @param response is the response to a challenge from the SaslClient or Saslserver - * @param connectionId the client connectionId we are negotiation authentication for - * @return a new SecurityMessage - */ - def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(response, connectionId) - newSecurityMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8ae76c5f72f2..7515aad09db7 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.5.0-SNAPSHOT" + val SPARK_VERSION = "1.6.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 91b07ce3af1b..5afce75680f9 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag @@ -48,9 +48,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf if (outputsMerged == totalOutputs) { val result = new JHashMap[T, BoundedDouble](sums.size) sums.foreach { case (key, sum) => - result(key) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -64,9 +64,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf val stdev = math.sqrt(variance) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(key) = new BoundedDouble(mean, confidence, low, high) + result.put(key, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala index af26c3d59ac0..a16404068480 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub while (iter.hasNext) { val entry = iter.next() val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) + result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -72,9 +72,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub val confFactor = studentTCacher.get(counter.count) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - 
result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala index 442fb86227d8..54a1beab3514 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl while (iter.hasNext) { val entry = iter.next() val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -80,9 +80,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl val confFactor = studentTCacher.get(counter.count) val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 1f755db48581..6fec00dcd0d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -28,7 +28,7 @@ private[spark] class BinaryFileRDD[T]( inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { @@ -36,10 +36,10 @@ private[spark] class BinaryFileRDD[T]( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(getConf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(getConf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 922030263756..fc1710fbad0a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -28,7 +28,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P } private[spark] -class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) @@ -64,7 +64,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val 
blockIds */ private[spark] def removeBlocks() { blockIds.foreach { blockId => - sc.env.blockManager.master.removeBlock(blockId) + sparkContext.env.blockManager.master.removeBlock(blockId) } _isValid = false } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index c1d697178757..18e8cddbc40d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -27,8 +27,8 @@ import org.apache.spark.util.Utils private[spark] class CartesianPartition( idx: Int, - @transient rdd1: RDD[_], - @transient rdd2: RDD[_], + @transient private val rdd1: RDD[_], + @transient private val rdd2: RDD[_], s1Index: Int, s2Index: Int ) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 33e6998b2cb1..b0364623af4c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -17,157 +17,31 @@ package org.apache.spark.rdd -import java.io.IOException - import scala.reflect.ClassTag -import org.apache.hadoop.fs.Path - -import org.apache.spark._ -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Partition, SparkContext, TaskContext} -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} +/** + * An RDD partition used to recover checkpointed data. + */ +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** - * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + * An RDD that recovers checkpointed data from storage. */ -private[spark] -class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) +private[spark] abstract class CheckpointRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) - - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) - - override def getPartitions: Array[Partition] = { - val cpath = new Path(checkpointPath) - val numPartitions = - // listStatus can throw exception if path does not exist. - if (fs.exists(cpath)) { - val dirContents = fs.listStatus(cpath).map(_.getPath) - val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.length - if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! 
partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } - numPart - } else 0 - - Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) - } - - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - - override def getPreferredLocations(split: Partition): Seq[String] = { - val status = fs.getFileStatus(new Path(checkpointPath, - CheckpointRDD.splitIdToFile(split.index))) - val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") - } - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) - CheckpointRDD.readFromFile(file, broadcastedConf, context) - } - - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } -} - -private[spark] object CheckpointRDD extends Logging { - def splitIdToFile(splitId: Int): String = { - "part-%05d".format(splitId) - } - - def writeToFile[T: ClassTag]( - path: String, - broadcastedConf: Broadcast[SerializableConfiguration], - blockSize: Int = -1 - )(ctx: TaskContext, iterator: Iterator[T]) { - val env = SparkEnv.get - val outputDir = new Path(path) - val fs = outputDir.getFileSystem(broadcastedConf.value.value) - - val finalOutputName = splitIdToFile(ctx.partitionId) - val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = - new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber) - - if (fs.exists(tempOutputPath)) { - throw new IOException("Checkpoint failed: temporary path " + - tempOutputPath + " already exists") - } - val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - - val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) - } else { - // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) - } - val serializer = env.serializer.newInstance() - val serializeStream = serializer.serializeStream(fileOutputStream) - Utils.tryWithSafeFinally { - serializeStream.writeAll(iterator) - } { - serializeStream.close() - } - - if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.exists(finalOutputPath)) { - logInfo("Deleting tempOutputPath " + tempOutputPath) - fs.delete(tempOutputPath, false) - throw new IOException("Checkpoint failed: failed to save output of task: " - + ctx.attemptNumber + " and final output path does not exist") - } else { - // Some other copy of this task must've finished before us and renamed it - logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") - fs.delete(tempOutputPath, false) - } - } - } - - def readFromFile[T]( - path: Path, - broadcastedConf: Broadcast[SerializableConfiguration], - context: TaskContext - ): Iterator[T] = { - val env = SparkEnv.get - val fs = path.getFileSystem(broadcastedConf.value.value) - val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) - val serializer = env.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - - // Register an on-task-completion callback to close the input stream. 
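[Editor's note: the removed `writeToFile` above shows a pattern worth calling out: each partition is written to a temporary, attempt-specific file and then promoted with a rename so concurrent task attempts cannot clobber each other. Below is a minimal local-filesystem sketch of the same idea using `java.nio.file`; the `writePartition` helper and its arguments are illustrative, not Spark APIs.]

```
import java.nio.file.{Files, Path}

// Write to ".part-00000-attempt-N" first, then promote it to "part-00000".
def writePartition(outputDir: Path, partitionId: Int, attempt: Int, bytes: Array[Byte]): Unit = {
  val finalName = "part-%05d".format(partitionId)
  val tempPath  = outputDir.resolve(s".$finalName-attempt-$attempt")
  val finalPath = outputDir.resolve(finalName)
  Files.write(tempPath, bytes)
  if (Files.exists(finalPath)) {
    // Another attempt already committed an equivalent copy; discard ours.
    Files.deleteIfExists(tempPath)
  } else {
    // A concurrent attempt could still race between the check and the move;
    // the removed Spark code handled a failed rename in the same spirit.
    Files.move(tempPath, finalPath)
  }
}
```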
- context.addTaskCompletionListener(context => deserializeStream.close()) - - deserializeStream.asIterator.asInstanceOf[Iterator[T]] - } + // CheckpointRDD should not be checkpointed again + override def doCheckpoint(): Unit = { } + override def checkpoint(): Unit = { } + override def localCheckpoint(): this.type = this - // Test whether CheckpointRDD generate expected number of partitions despite - // each split file having multiple blocks. This needs to be run on a - // cluster (mesos or standalone) using HDFS. - def main(args: Array[String]) { - import org.apache.spark._ + // Note: There is a bug in MiMa that complains about `AbstractMethodProblem`s in the + // base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods. + // scalastyle:off + protected override def getPartitions: Array[Partition] = ??? + override def compute(p: Partition, tc: TaskContext): Iterator[T] = ??? + // scalastyle:on - val Array(cluster, hdfsPath) = args - val env = SparkEnv.get - val sc = new SparkContext(cluster, "CheckpointRDD Test") - val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) - val path = new Path(hdfsPath, "temp") - val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) - val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf)) - sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) - val cpRDD = new CheckpointRDD[Int](sc, path.toString) - assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") - assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") - fs.delete(path, true) - } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 658e8c8b8931..7bad749d5832 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -22,9 +22,9 @@ import scala.language.existentials import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag -import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} -import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils @@ -75,7 +75,9 @@ private[spark] class CoGroupPartition( * @param part partitioner used to partition the shuffle output */ @DeveloperApi -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) +class CoGroupedRDD[K: ClassTag]( + @transient var rdds: Seq[RDD[_ <: Product2[K, _]]], + part: Partitioner) extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs). 
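[Editor's note: the comment above describes the output shape of `cogroup`, which CoGroupedRDD implements. A short usage sketch, assuming an existing SparkContext `sc` and illustrative data:]

```
val orders  = sc.parallelize(Seq(("alice", 1), ("alice", 2), ("bob", 3)))
val refunds = sc.parallelize(Seq(("alice", 9)))

// cogroup keeps every key from either side, with one Iterable per input RDD:
//   ("alice", (values 1, 2 | values 9))
//   ("bob",   (values 3    | empty))
val grouped = orders.cogroup(refunds)
grouped.collect().foreach(println)
```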
@@ -94,13 +96,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_ <: Product2[K, _]] => + rdds.map { rdd: RDD[_] => if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner]( + rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer) } } } @@ -133,7 +136,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- dependencies.zipWithIndex) dep match { - case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked => val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) @@ -168,8 +171,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 663eebb8e419..90d9735cb3f6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -69,7 +69,7 @@ private[spark] case class CoalescedRDDPartition( * the preferred location of each new partition overlaps with as many preferred locations of its * parent partitions * @param prev RDD to be coalesced - * @param maxPartitions number of desired partitions in the coalesced RDD + * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive) * @param balanceSlack used to trade-off balance and locality. 
1.0 is all locality, 0 is all balance */ private[spark] class CoalescedRDD[T: ClassTag]( @@ -78,6 +78,9 @@ private[spark] class CoalescedRDD[T: ClassTag]( balanceSlack: Double = 0.10) extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies + require(maxPartitions > 0 || maxPartitions == prev.partitions.length, + s"Number of partitions ($maxPartitions) must be positive.") + override def getPartitions: Array[Partition] = { val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index bee59a437f12..8f2655d63b79 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,14 +44,14 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel /** * A Spark split class that wraps around a Hadoop InputSplit. */ -private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) +private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) extends Partition { val inputSplit = new SerializableWritable[InputSplit](s) @@ -99,7 +99,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp */ @DeveloperApi class HadoopRDD[K, V]( - @transient sc: SparkContext, + sc: SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -109,7 +109,7 @@ class HadoopRDD[K, V]( extends RDD[(K, V)](sc, Nil) with Logging { if (initLocalJobConfFuncOpt.isDefined) { - sc.clean(initLocalJobConfFuncOpt.get) + sparkContext.clean(initLocalJobConfFuncOpt.get) } def this( @@ -137,7 +137,7 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() - private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. 
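[Editor's note: the new `require` above makes a non-positive `maxPartitions` fail fast instead of silently producing a degenerate RDD. Usage of `coalesce` is otherwise unchanged; a sketch assuming an existing SparkContext `sc`:]

```
val wide   = sc.parallelize(1 to 1000, 100)
val narrow = wide.coalesce(4)   // 4 partitions, no shuffle

// wide.coalesce(0)             // now throws IllegalArgumentException:
//                              // "Number of partitions (0) must be positive."
```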
protected def getJobConf(): JobConf = { @@ -274,7 +274,7 @@ class HadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } @@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging { private[spark] class SplitInfoReflections { val inputSplitWithLocationInfo = - Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") - val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit") val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") - val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo") val isInMemory = splitLocationInfo.getMethod("isInMemory") val getLocation = splitLocationInfo.getMethod("getLocation") } diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala new file mode 100644 index 000000000000..bfe19195fcd3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext} +import org.apache.spark.storage.RDDBlockId + +/** + * A dummy CheckpointRDD that exists to provide informative error messages during failures. + * + * This is simply a placeholder because the original checkpointed RDD is expected to be + * fully cached. Only if an executor fails or if the user explicitly unpersists the original + * RDD will Spark ever attempt to compute this CheckpointRDD. When this happens, however, + * we must provide an informative error message. + * + * @param sc the active SparkContext + * @param rddId the ID of the checkpointed RDD + * @param numPartitions the number of partitions in the checkpointed RDD + */ +private[spark] class LocalCheckpointRDD[T: ClassTag]( + sc: SparkContext, + rddId: Int, + numPartitions: Int) + extends CheckpointRDD[T](sc) { + + def this(rdd: RDD[T]) { + this(rdd.context, rdd.id, rdd.partitions.size) + } + + protected override def getPartitions: Array[Partition] = { + (0 until numPartitions).toArray.map { i => new CheckpointRDDPartition(i) } + } + + /** + * Throw an exception indicating that the relevant block is not found. 
+ * + * This should only be called if the original RDD is explicitly unpersisted or if an + * executor is lost. Under normal circumstances, however, the original RDD (our child) + * is expected to be fully cached and so all partitions should already be computed and + * available in the block storage. + */ + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + throw new SparkException( + s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " + + s"that originally checkpointed this partition is no longer alive, or the original RDD is " + + s"unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " + + s"instead, which is slower than local checkpointing but more fault-tolerant.") + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala new file mode 100644 index 000000000000..c115e0ff74d3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.util.Utils + +/** + * An implementation of checkpointing implemented on top of Spark's caching layer. + * + * Local checkpointing trades off fault tolerance for performance by skipping the expensive + * step of saving the RDD data to a reliable and fault-tolerant storage. Instead, the data + * is written to the local, ephemeral block storage that lives in each executor. This is useful + * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX). + */ +private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) + extends RDDCheckpointData[T](rdd) with Logging { + + /** + * Ensure the RDD is fully cached so the partitions can be recovered later. + */ + protected override def doCheckpoint(): CheckpointRDD[T] = { + val level = rdd.getStorageLevel + + // Assume storage level uses disk; otherwise memory eviction may cause data loss + assume(level.useDisk, s"Storage level $level is not appropriate for local checkpointing") + + // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we + // must cache any missing partitions. TODO: avoid running another job here (SPARK-8582). 
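[Editor's note: the comment above explains why `doCheckpoint` runs a job over only the partitions that are not already in the block manager. The same targeted-`runJob` pattern is usable directly, and this patch itself relies on the `runJob(rdd, func, partitions)` overload; a small sketch, assuming an existing SparkContext `sc`:]

```
val rdd = sc.parallelize(1 to 100, 4)

// Count elements of partitions 0 and 3 only; the other partitions are never computed.
val counts: Array[Long] = sc.runJob(
  rdd,
  (iter: Iterator[Int]) => iter.size.toLong,
  Seq(0, 3))
```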
+ val action = (tc: TaskContext, iterator: Iterator[T]) => Utils.getIteratorSize(iterator) + val missingPartitionIndices = rdd.partitions.map(_.index).filter { i => + !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i)) + } + if (missingPartitionIndices.nonEmpty) { + rdd.sparkContext.runJob(rdd, action, missingPartitionIndices) + } + + new LocalCheckpointRDD[T](rdd) + } + +} + +private[spark] object LocalRDDCheckpointData { + + val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK + + /** + * Transform the specified storage level to one that uses disk. + * + * This guarantees that the RDD can be recomputed multiple times correctly as long as + * executors do not fail. Otherwise, if the RDD is cached in memory only, for instance, + * the checkpoint data will be lost if the relevant block is evicted from memory. + * + * This method is idempotent. + */ + def transformStorageLevel(level: StorageLevel): StorageLevel = { + // If this RDD is to be cached off-heap, fail fast since we cannot provide any + // correctness guarantees about subsequent computations after the first one + if (level.useOffHeap) { + throw new SparkException("Local checkpointing is not compatible with off-heap caching.") + } + + StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index a838aac6e8d1..4312d3a41775 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -21,6 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} +/** + * An RDD that applies the provided function to every partition of the parent RDD. + */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala new file mode 100644 index 000000000000..417ff5278db2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, Partitioner, TaskContext} + +/** + * An RDD that applies a user provided function to every partition of the parent RDD, and + * additionally allows the user to prepare each partition before computing the parent partition. 
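[Editor's note: before moving on to MapPartitionsWithPreparationRDD below, a quick illustration of the `transformStorageLevel` helper added earlier in this hunk. It only flips the disk flag, so a memory-only request becomes MEMORY_AND_DISK and an already disk-backed level passes through unchanged; off-heap levels are rejected. The sketch mirrors the helper's own `StorageLevel(...)` call rather than invoking the private object:]

```
import org.apache.spark.storage.StorageLevel

// MEMORY_ONLY (disk=false, memory=true, deserialized=true, replication=1)
// becomes     (disk=true,  memory=true, deserialized=true, replication=1) == MEMORY_AND_DISK
val adapted = StorageLevel(
  useDisk = true,
  StorageLevel.MEMORY_ONLY.useMemory,
  StorageLevel.MEMORY_ONLY.deserialized,
  StorageLevel.MEMORY_ONLY.replication)

println(adapted == StorageLevel.MEMORY_AND_DISK)  // expected: true
```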
+ */ +private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag]( + prev: RDD[T], + preparePartition: () => M, + executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner: Option[Partitioner] = { + if (preservesPartitioning) firstParent[T].partitioner else None + } + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + // In certain join operations, prepare can be called on the same partition multiple times. + // In this case, we need to ensure that each call to compute gets a separate prepare argument. + private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] + + /** + * Prepare a partition for a single call to compute. + */ + def prepare(): Unit = { + preparedArguments += preparePartition() + } + + /** + * Prepare a partition before computing it from its parent. + */ + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + val prepared = + if (preparedArguments.isEmpty) { + preparePartition() + } else { + preparedArguments.remove(0) + } + val parentIterator = firstParent[T].iterator(partition, context) + executePartition(context, partition.index, prepared, parentIterator) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f827270ee6a4..174979aaeb23 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,14 +33,14 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel private[spark] class NewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -68,14 +68,14 @@ class NewHadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - @transient conf: Configuration) + @transient private val _conf: Configuration) extends RDD[(K, V)](sc, Nil) with SparkHadoopMapReduceUtil with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) - // private val serializableConf = new SerializableWritable(conf) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) + // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -88,10 +88,10 @@ class NewHadoopRDD[K, V]( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ 
-128,7 +128,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -141,6 +141,12 @@ class NewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } havePair = !finished } !finished @@ -159,23 +165,28 @@ class NewHadoopRDD[K, V]( private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + // Close reader and release it + reader.close() + reader = null + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } @@ -251,7 +262,7 @@ private[spark] class WholeTextFileRDD( inputFormatClass: Class[_ <: WholeTextFileInputFormat], keyClass: Class[String], valueClass: Class[String], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { @@ -259,10 +270,10 @@ private[spark] class WholeTextFileRDD( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(getConf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(getConf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 91a6a2d03985..a981b63942e6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -22,7 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} import scala.collection.{Map, mutable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import 
scala.reflect.ClassTag import scala.util.DynamicVariable @@ -57,7 +57,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) with SparkHadoopMapReduceUtil with Serializable { + /** + * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C * Note that V and C can be different -- for example, one might group an RDD of type @@ -70,12 +72,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). */ - def combineByKey[C](createCombiner: V => C, + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializer: Serializer = null): RDD[(K, C)] = self.withScope { + serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (keyClass.isArray) { if (mapSideCombine) { @@ -103,13 +107,50 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] */ - def combineByKey[C](createCombiner: V => C, + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true, + serializer: Serializer = null): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + partitioner, mapSideCombine, serializer)(null) + } + + /** + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + * This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, numPartitions: Int): RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, numPartitions)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. 
+ */ + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + numPartitions: Int)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + new HashPartitioner(numPartitions)) } /** @@ -133,7 +174,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // We will clean the combiner closure later in `combineByKey` val cleanedSeqOp = self.context.clean(seqOp) - combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner) + combineByKeyWithClassTag[U]((v: V) => cleanedSeqOp(createZero(), v), + cleanedSeqOp, combOp, partitioner) } /** @@ -182,7 +224,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) val cleanedFunc = self.context.clean(func) - combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner) + combineByKeyWithClassTag[V]((v: V) => cleanedFunc(createZero(), v), + cleanedFunc, cleanedFunc, partitioner) } /** @@ -268,7 +311,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { - combineByKey[V]((v: V) => v, func, func, partitioner) + combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) } /** @@ -312,14 +355,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } : Iterator[JHashMap[K, V]] val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { - m2.foreach { pair => + m2.asScala.foreach { pair => val old = m1.get(pair._1) m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] - self.mapPartitions(reducePartition).reduce(mergeMaps) + self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } /** Alias for reduceByKeyLocally */ @@ -392,7 +435,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) h1 } - combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.cardinality()) + combineByKeyWithClassTag(createHLL, mergeValueHLL, mergeHLL, partitioner) + .mapValues(_.cardinality()) } /** @@ -466,7 +510,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createCombiner = (v: V) => CompactBuffer(v) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 - val bufs = combineByKey[CompactBuffer[V]]( + val bufs = combineByKeyWithClassTag[CompactBuffer[V]]( createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -565,12 +609,30 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. This method is here for backward compatibility. It + * does not provide combiner classtag information to the shuffle. 
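[Editor's note: the `combineByKey` family above, and its new `combineByKeyWithClassTag` variants, is easiest to see with the classic per-key average. The sketch uses the stable three-argument `combineByKey` signature and assumes an existing SparkContext `sc`:]

```
val scores = sc.parallelize(Seq(("math", 90), ("math", 70), ("art", 100)))

// createCombiner, mergeValue, mergeCombiners -> per-key (sum, count)
val sumCount = scores.combineByKey(
  (v: Int) => (v, 1),
  (acc: (Int, Int), v: Int) => (acc._1 + v, acc._2 + 1),
  (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2))

val averages = sumCount.mapValues { case (sum, count) => sum.toDouble / count }
// ("math", 80.0), ("art", 100.0)
```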
+ * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the * existing partitioner/parallelism level. */ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } /** @@ -881,7 +943,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } buf } : Seq[V] - val res = self.context.runJob(self, process, Array(index), false) + val res = self.context.runJob(self, process, Array(index)) res(0) case None => self.filter(_._1 == key).map(_._2).collect() @@ -934,8 +996,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - job.getConfiguration.set("mapred.output.dir", path) - saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(jobConfiguration) } /** @@ -955,6 +1018,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsHadoopFile( path: String, @@ -967,10 +1035,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) hadoopConf.setOutputValueClass(valueClass) - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) + conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) hadoopConf.set("mapred.output.compress", "true") @@ -984,6 +1049,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) } + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = hadoopConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. 
Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + FileOutputFormat.setOutputPath(hadoopConf, SparkHadoopWriter.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) @@ -994,6 +1072,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Configuration object for that storage system. The Conf should set an OutputFormat and any * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). @@ -1002,7 +1085,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableConfiguration(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1051,6 +1135,20 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobCommitter.getClass.getSimpleName + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + jobCommitter.setupJob(jobTaskContext) self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) @@ -1065,7 +1163,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
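[Editor's note: both save paths above now emit the same warning when speculation is combined with a "Direct" output committer. If you want to check your own job for that combination up front, the relevant settings are plain configuration reads; a sketch assuming an existing SparkContext `sc`, using the old `mapred` committer key shown above:]

```
val speculationEnabled = sc.getConf.getBoolean("spark.speculation", false)
val committerClass = sc.hadoopConfiguration.get("mapred.output.committer.class", "")

if (speculationEnabled && committerClass.contains("Direct")) {
  // Prefer a committer that writes to a temporary location first, e.g. FileOutputCommitter.
  println(s"Warning: $committerClass may lose data when speculation is enabled (SPARK-10063).")
}
```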
val hadoopConf = conf - val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1093,7 +1190,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e2394e28f8d2..582fa93afe34 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -83,8 +83,8 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( } private[spark] class ParallelCollectionRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient data: Seq[T], + sc: SparkContext, + @transient private val data: Seq[T], numSlices: Int, locationPrefs: Map[Int, Seq[String]]) extends RDD[T](sc, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index a00f4c1cdff9..d6a37e8cc5da 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -32,7 +32,7 @@ private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Par * Represents a dependency between the PartitionPruningRDD and its parent. In this * case, the child RDD contains a subset of partitions of the parents'. 
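[Editor's note: PartitionPruningRDD, touched in the hunk below, is the developer API for reading only a subset of a parent's partitions without scheduling tasks for the rest. A short usage sketch, assuming an existing SparkContext `sc`:]

```
import org.apache.spark.rdd.PartitionPruningRDD

val base = sc.parallelize(1 to 100, 10)

// Keep only the first two partitions; the other eight are never computed.
val pruned = PartitionPruningRDD.create(base, partitionIndex => partitionIndex < 2)
println(pruned.partitions.length)  // 2
println(pruned.count())            // 20 (elements 1 to 20)
```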
*/ -private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) +private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient @@ -55,8 +55,8 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF */ @DeveloperApi class PartitionPruningRDD[T: ClassTag]( - @transient prev: RDD[T], - @transient partitionFilterFunc: Int => Boolean) + prev: RDD[T], + partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index a637d6f15b7e..3b1acacf409b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -47,8 +47,8 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], - @transient preservesPartitioning: Boolean, - @transient seed: Long = Utils.random.nextLong) + preservesPartitioning: Boolean, + @transient private val seed: Long = Utils.random.nextLong) extends RDD[U](prev) { @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index dc60d4892762..afbe566b7656 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -23,7 +23,7 @@ import java.io.IOException import java.io.PrintWriter import java.util.StringTokenizer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -72,7 +72,7 @@ private[spark] class PipedRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[String] = { - val pb = new ProcessBuilder(command) + val pb = new ProcessBuilder(command.asJava) // Add the environmental variables to the process. 
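[Editor's note: the recurring JavaConversions -> JavaConverters change (here in PipedRDD, and in the partial evaluators and PairRDDFunctions above) replaces implicit conversions with explicit `.asJava` / `.asScala` calls, keeping each conversion visible at the call site. A minimal illustration outside of Spark:]

```
import scala.collection.JavaConverters._

val command = Seq("grep", "-i", "error")
val pb = new ProcessBuilder(command.asJava)        // Seq[String] -> java.util.List[String]

val env: collection.mutable.Map[String, String] =
  pb.environment().asScala                         // java.util.Map -> mutable Scala view
env.get("PATH").foreach(println)
```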
val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -81,7 +81,7 @@ private[spark] class PipedRDD[T: ClassTag]( // so the user code can access the input filename if (split.isInstanceOf[HadoopPartition]) { val hadoopSplit = split.asInstanceOf[HadoopPartition] - currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava) } // When spark.worker.separated.working.directory option is turned on, each @@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread("stderr reader for " + command) { override def run() { for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + // scalastyle:off println System.err.println(line) + // scalastyle:on println } } }.start() @@ -131,8 +133,10 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { + TaskContext.setTaskContext(context) val out = new PrintWriter(proc.getOutputStream) + // scalastyle:off println // input the pipe context firstly if (printPipeContext != null) { printPipeContext(out.println(_)) @@ -144,6 +148,7 @@ private[spark] class PipedRDD[T: ClassTag]( out.println(elem) } } + // scalastyle:on println out.close() } }.start() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 10610f4b6f1f..a56e542242d5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -149,23 +149,43 @@ abstract class RDD[T: ClassTag]( } /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. This can only be used to assign a new storage level if the RDD does not - * have a storage level set yet.. + * Mark this RDD for persisting using the specified level. + * + * @param newLevel the target storage level + * @param allowOverride whether to override any existing level with the new one */ - def persist(newLevel: StorageLevel): this.type = { + private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = { // TODO: Handle changes of StorageLevel - if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { + if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) { throw new UnsupportedOperationException( "Cannot change storage level of an RDD after it was already assigned a level") } - sc.persistRDD(this) - // Register the RDD with the ContextCleaner for automatic GC-based cleanup - sc.cleaner.foreach(_.registerRDDForCleanup(this)) + // If this is the first time this RDD is marked for persisting, register it + // with the SparkContext for cleanups and accounting. Do this only once. + if (storageLevel == StorageLevel.NONE) { + sc.cleaner.foreach(_.registerRDDForCleanup(this)) + sc.persistRDD(this) + } storageLevel = newLevel this } + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet. Local checkpointing is an exception. + */ + def persist(newLevel: StorageLevel): this.type = { + if (isLocallyCheckpointed) { + // This means the user previously called localCheckpoint(), which should have already + // marked this RDD for persisting. 
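[Editor's note: the `persist` branch above changes what happens when a storage level is requested for an RDD already marked with `localCheckpoint()`: instead of throwing, the requested level is adapted to a disk-backed one. A sketch of the observable behavior, assuming an existing SparkContext `sc`:]

```
import org.apache.spark.storage.StorageLevel

val rdd = sc.parallelize(1 to 1000, 4).map(_ * 2)
rdd.localCheckpoint()                   // marks the RDD with a disk-backed default level
rdd.persist(StorageLevel.MEMORY_ONLY)   // adapted to MEMORY_AND_DISK rather than rejected
println(rdd.getStorageLevel)            // expected: a level with useDisk = true
```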
Here we should override the old storage level with + // one that is explicitly requested by the user (after adapting it to use disk). + persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true) + } else { + persist(newLevel, allowOverride = false) + } + } + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def persist(): this.type = persist(StorageLevel.MEMORY_ONLY) @@ -194,7 +214,7 @@ abstract class RDD[T: ClassTag]( @transient private var partitions_ : Array[Partition] = null /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) /** * Get the list of dependencies of this RDD, taking into account whether the @@ -449,50 +469,44 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - // TODO: rewrite this without return statements so we can wrap it in a scope def takeSample( withReplacement: Boolean, num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + seed: Long = Utils.random.nextLong): Array[T] = withScope { val numStDev = 10.0 - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } else if (num == 0) { - return new Array[T](0) - } + require(num >= 0, "Negative number of elements requested") + require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt), + "Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") - val initialCount = this.count() - if (initialCount == 0) { - return new Array[T](0) - } - - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSampleSize) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } - - val rand = new Random(seed) - if (!withReplacement && num >= initialCount) { - return Utils.randomizeInPlace(this.collect(), rand) - } - - val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, - withReplacement) - - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for the initial size - var numIters = 0 - while (samples.length < num) { - logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - numIters += 1 + if (num == 0) { + new Array[T](0) + } else { + val initialCount = this.count() + if (initialCount == 0) { + new Array[T](0) + } else { + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + Utils.randomizeInPlace(this.collect(), rand) + } else { + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for the initial size + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. 
Repeat #$numIters") + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 + } + Utils.randomizeInPlace(samples, rand).take(num) + } + } } - - Utils.randomizeInPlace(samples, rand).take(num) } /** @@ -890,10 +904,14 @@ abstract class RDD[T: ClassTag]( * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. + * + * Note: this results in multiple Spark jobs, and if the input RDD is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input RDD should be cached first. */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head + sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } @@ -1078,7 +1096,9 @@ abstract class RDD[T: ClassTag]( val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce // the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { + + // Don't trigger TreeAggregation when it doesn't save wall-clock time + while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) { numPartitions /= scale val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { @@ -1269,7 +1289,7 @@ abstract class RDD[T: ClassTag]( val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true) + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) partsScanned += numPartsToTry @@ -1442,29 +1462,99 @@ abstract class RDD[T: ClassTag]( /** * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint - * directory set with SparkContext.setCheckpointDir() and all references to its parent + * directory set with `SparkContext#setCheckpointDir` and all references to its parent * RDDs will be removed. This function must be called before any job has been * executed on this RDD. It is strongly recommended that this RDD is persisted in * memory, otherwise saving it on a file will require recomputation. */ - def checkpoint() { + def checkpoint(): Unit = RDDCheckpointData.synchronized { + // NOTE: we use a global lock here due to complexities downstream with ensuring + // children RDD partitions point to the correct parent partitions. In the future + // we should revisit this consideration. if (context.checkpointDir.isEmpty) { throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() + checkpointData = Some(new ReliableRDDCheckpointData(this)) + } + } + + /** + * Mark this RDD for local checkpointing using Spark's existing caching layer. + * + * This method is for users who wish to truncate RDD lineages while skipping the expensive + * step of replicating the materialized data in a reliable distributed file system. This is + * useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX). 
+ * + * Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed + * data is written to ephemeral local storage in the executors instead of to a reliable, + * fault-tolerant storage. The effect is that if an executor fails during the computation, + * the checkpointed data may no longer be accessible, causing an irrecoverable job failure. + * + * This is NOT safe to use with dynamic allocation, which removes executors along + * with their cached blocks. If you must use both features, you are advised to set + * `spark.dynamicAllocation.cachedExecutorIdleTimeout` to a high value. + * + * The checkpoint directory set through `SparkContext#setCheckpointDir` is not used. + */ + def localCheckpoint(): this.type = RDDCheckpointData.synchronized { + if (conf.getBoolean("spark.dynamicAllocation.enabled", false) && + conf.contains("spark.dynamicAllocation.cachedExecutorIdleTimeout")) { + logWarning("Local checkpointing is NOT safe to use with dynamic allocation, " + + "which removes executors along with their cached blocks. If you must use both " + + "features, you are advised to set `spark.dynamicAllocation.cachedExecutorIdleTimeout` " + + "to a high value. E.g. If you plan to use the RDD for 1 hour, set the timeout to " + + "at least 1 hour.") + } + + // Note: At this point we do not actually know whether the user will call persist() on + // this RDD later, so we must explicitly call it here ourselves to ensure the cached + // blocks are registered for cleanup later in the SparkContext. + // + // If, however, the user has already called persist() on this RDD, then we must adapt + // the storage level he/she specified to one that is appropriate for local checkpointing + // (i.e. uses disk) to guarantee correctness. + + if (storageLevel == StorageLevel.NONE) { + persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + } else { + persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true) } + + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning( + "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") + case _ => + } + checkpointData = Some(new LocalRDDCheckpointData(this)) + this } /** - * Return whether this RDD has been checkpointed or not + * Return whether this RDD is marked for checkpointing, either reliably or locally. */ def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** - * Gets the name of the file to which this RDD was checkpointed + * Return whether this RDD is marked for local checkpointing. + * Exposed for testing. */ - def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) + private[rdd] def isLocallyCheckpointed: Boolean = { + checkpointData match { + case Some(_: LocalRDDCheckpointData[T]) => true + case _ => false + } + } + + /** + * Gets the name of the directory to which this RDD was checkpointed. + * This is not defined if the RDD is checkpointed locally. 
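A driver-side sketch of how the `localCheckpoint` API introduced above is intended to be used; this is not part of the patch, and the app name, master and numbers are arbitrary:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

object LocalCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("local-checkpoint-sketch").setMaster("local[2]"))

    // A lineage we want to truncate without paying for a reliable (e.g. HDFS) checkpoint.
    val rdd = sc.parallelize(1 to 1000).map(_ * 2).filter(_ % 3 == 0)

    // Optional: an explicit storage level. Per the change above, persist() calls made
    // before or after localCheckpoint() are adapted to a disk-backed level.
    rdd.persist(StorageLevel.MEMORY_ONLY)

    // Mark for local checkpointing: data goes to ephemeral executor-local storage,
    // trading fault tolerance for speed (and unsafe with dynamic allocation).
    rdd.localCheckpoint()

    // The first action materializes the blocks and truncates the lineage.
    println(rdd.count())
    println(s"checkpointed = ${rdd.isCheckpointed}")

    sc.stop()
  }
}
```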
+ */ + def getCheckpointFile: Option[String] = { + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[T]) => reliable.getCheckpointDir + case _ => None + } + } // ======================================================================= // Other internal methods and fields @@ -1493,7 +1583,7 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassTag] = { + protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } @@ -1535,7 +1625,7 @@ abstract class RDD[T: ClassTag]( if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { - checkpointData.get.doCheckpoint() + checkpointData.get.checkpoint() } else { dependencies.foreach(_.rdd.doCheckpoint()) } @@ -1547,7 +1637,7 @@ abstract class RDD[T: ClassTag]( * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) * created from the checkpoint file, and forget its old dependencies and partitions. */ - private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { + private[spark] def markCheckpointed(): Unit = { clearDependencies() partitions_ = null deps = null // Forget the constructor argument for dependencies too @@ -1570,7 +1660,7 @@ abstract class RDD[T: ClassTag]( import Utils.bytesToString val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" - val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => + val storageInfo = rdd.context.getRDDStorageInfo(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), bytesToString(info.externalBlockStoreSize), bytesToString(info.diskSize))) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index acbd31aacdf5..429514b4f6be 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -19,19 +19,15 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.hadoop.fs.Path - -import org.apache.spark._ -import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.Partition /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + * [ Initialized --> checkpointing in progress --> checkpointed ]. */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value + val Initialized, CheckpointingInProgress, Checkpointed = Value } /** @@ -40,113 +36,76 @@ private[spark] object CheckpointState extends Enumeration { * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ -private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) - extends Logging with Serializable { +private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) + extends Serializable { import CheckpointState._ // The checkpoint state of the associated RDD. 
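For reference, a sketch of the driver-side flow that the simplified state machine above (`Initialized -> CheckpointingInProgress -> Checkpointed`) supports; the checkpoint directory is a placeholder and the snippet is not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ReliableCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("checkpoint-sketch").setMaster("local[2]"))

    // Any reliable directory (typically HDFS); a local path only makes sense in local mode.
    sc.setCheckpointDir("/tmp/spark-checkpoints")

    val rdd = sc.parallelize(1 to 100).map(_ + 1)
    rdd.cache()       // recommended: writing the checkpoint otherwise recomputes the RDD
    rdd.checkpoint()  // checkpoint data created in state Initialized; nothing written yet

    rdd.count()       // first action: CheckpointingInProgress, then Checkpointed

    println(rdd.isCheckpointed)     // true
    println(rdd.getCheckpointFile)  // Some(<checkpoint dir>/rdd-<id>) for reliable checkpoints

    sc.stop()
  }
}
```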
- var cpState = Initialized - - // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None + protected var cpState = Initialized - // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None - - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } + // The RDD that contains our checkpointed data + private var cpRDD: Option[CheckpointRDD[T]] = None - // Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } - } + // TODO: are we sure we need to use a global lock in the following methods? - // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } + /** + * Return whether the checkpoint data for this RDD is already persisted. + */ + def isCheckpointed: Boolean = RDDCheckpointData.synchronized { + cpState == Checkpointed } - // Do the checkpointing of the RDD. Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return + /** + * Materialize this RDD and persist its content. + * This is called immediately after the first action invoked on this RDD has completed. + */ + final def checkpoint(): Unit = { + // Guard against multiple threads checkpointing the same RDD by + // atomically flipping the state of this RDDCheckpointData RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { + if (cpState == Initialized) { cpState = CheckpointingInProgress } else { return } } - // Create the output path for the checkpoint - val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) - } + val newRDD = doCheckpoint() - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { - rdd.context.cleaner.foreach { cleaner => - cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) - } - } - rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) - if (newRDD.partitions.length != rdd.partitions.length) { - throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") - } - - // Change the dependencies and partitions of the RDD + // Update our state and truncate the RDD lineage RDDCheckpointData.synchronized { - cpFile = Some(path.toString) cpRDD = Some(newRDD) - rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed + rdd.markCheckpointed() } - logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - 
cpRDD.get.preferredLocations(split) - } + /** + * Materialize this RDD and persist its content. + * + * Subclasses should override this method to define custom checkpointing behavior. + * @return the checkpoint RDD created in the process. + */ + protected def doCheckpoint(): CheckpointRDD[T] + + /** + * Return the RDD that contains our checkpointed data. + * This is only defined if the checkpoint state is `Checkpointed`. + */ + def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD } + + /** + * Return the partitions of the resulting checkpoint RDD. + * For tests only. + */ + def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { + cpRDD.map(_.partitions).getOrElse { Array.empty } } - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } - } - - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } - } } -private[spark] object RDDCheckpointData { - def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { - sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } - } - - def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { - rddCheckpointDataPath(sc, rddId).foreach { path => - val fs = path.getFileSystem(sc.hadoopConfiguration) - if (fs.exists(path)) { - fs.delete(path, true) - } - } - } -} +/** + * Global lock for synchronizing checkpoint operations. + */ +private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 44667281c106..540cbd688b63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude, JsonPropertyOr import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.google.common.base.Objects import org.apache.spark.{Logging, SparkContext} @@ -67,6 +68,8 @@ private[spark] class RDDOperationScope( } } + override def hashCode(): Int = Objects.hashCode(id, name, parent) + override def toString: String = toJson } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala new file mode 100644 index 000000000000..1c3b5da19ceb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.io.IOException + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * An RDD that reads from checkpoint files previously written to reliable storage. + */ +private[spark] class ReliableCheckpointRDD[T: ClassTag]( + sc: SparkContext, + val checkpointPath: String) + extends CheckpointRDD[T](sc) { + + @transient private val hadoopConf = sc.hadoopConfiguration + @transient private val cpath = new Path(checkpointPath) + @transient private val fs = cpath.getFileSystem(hadoopConf) + private val broadcastedConf = sc.broadcast(new SerializableConfiguration(hadoopConf)) + + // Fail fast if checkpoint directory does not exist + require(fs.exists(cpath), s"Checkpoint directory does not exist: $checkpointPath") + + /** + * Return the path of the checkpoint directory this RDD reads data from. + */ + override def getCheckpointFile: Option[String] = Some(checkpointPath) + + /** + * Return partitions described by the files in the checkpoint directory. + * + * Since the original RDD may belong to a prior application, there is no way to know a + * priori the number of partitions to expect. This method assumes that the original set of + * checkpoint files are fully preserved in a reliable storage across application lifespans. + */ + protected override def getPartitions: Array[Partition] = { + // listStatus can throw exception if path does not exist. + val inputFiles = fs.listStatus(cpath) + .map(_.getPath) + .filter(_.getName.startsWith("part-")) + .sortBy(_.toString) + // Fail fast if input files are invalid + inputFiles.zipWithIndex.foreach { case (path, i) => + if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) { + throw new SparkException(s"Invalid checkpoint file: $path") + } + } + Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i)) + } + + /** + * Return the locations of the checkpoint file associated with the given partition. + */ + protected override def getPreferredLocations(split: Partition): Seq[String] = { + val status = fs.getFileStatus( + new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + /** + * Read the content of the checkpoint file associated with the given partition. + */ + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)) + ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context) + } + +} + +private[spark] object ReliableCheckpointRDD extends Logging { + + /** + * Return the checkpoint file name for the given partition. + */ + private def checkpointFileName(partitionIndex: Int): String = { + "part-%05d".format(partitionIndex) + } + + /** + * Write this partition's values to a checkpoint file. 
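The reader above reconstructs partitions purely from the `part-NNNNN` files it finds in the checkpoint directory. A self-contained sketch of that naming and validation convention (no Hadoop filesystem involved; the sample listing is invented):

```scala
object CheckpointNamingSketch {
  /** The same zero-padded convention the checkpoint writer uses. */
  def checkpointFileName(partitionIndex: Int): String = "part-%05d".format(partitionIndex)

  def main(args: Array[String]): Unit = {
    // Pretend this is a directory listing, already filtered to "part-" files and sorted.
    val inputFiles = Seq("part-00000", "part-00001", "part-00002")

    // Fail fast if any file name does not match its position, as getPartitions does above.
    inputFiles.zipWithIndex.foreach { case (name, i) =>
      require(name.endsWith(checkpointFileName(i)), s"Invalid checkpoint file: $name")
    }
    println(s"${inputFiles.length} partitions reconstructed")
  }
}
```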
+ */ + def writeCheckpointFile[T: ClassTag]( + path: String, + broadcastedConf: Broadcast[SerializableConfiguration], + blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { + val env = SparkEnv.get + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(broadcastedConf.value.value) + + val finalOutputName = ReliableCheckpointRDD.checkpointFileName(ctx.partitionId()) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = + new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}") + + if (fs.exists(tempOutputPath)) { + throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists") + } + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = env.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + logInfo(s"Deleting tempOutputPath $tempOutputPath") + fs.delete(tempOutputPath, false) + throw new IOException("Checkpoint failed: failed to save output of task: " + + s"${ctx.attemptNumber()} and final output path does not exist: $finalOutputPath") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo(s"Final output path $finalOutputPath already exists; not overwriting it") + fs.delete(tempOutputPath, false) + } + } + } + + /** + * Read the content of the specified checkpoint file. + */ + def readCheckpointFile[T]( + path: Path, + broadcastedConf: Broadcast[SerializableConfiguration], + context: TaskContext): Iterator[T] = { + val env = SparkEnv.get + val fs = path.getFileSystem(broadcastedConf.value.value) + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) + val fileInputStream = fs.open(path, bufferSize) + val serializer = env.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(context => deserializeStream.close()) + + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala new file mode 100644 index 000000000000..e9f6060301ba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.util.SerializableConfiguration + +/** + * An implementation of checkpointing that writes the RDD data to reliable storage. + * This allows drivers to be restarted on failure with previously computed state. + */ +private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) + extends RDDCheckpointData[T](rdd) with Logging { + + // The directory to which the associated RDD has been checkpointed to + // This is assumed to be a non-local path that points to some reliable storage + private val cpDir: String = + ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id) + .map(_.toString) + .getOrElse { throw new SparkException("Checkpoint dir must be specified.") } + + /** + * Return the directory to which this RDD was checkpointed. + * If the RDD is not checkpointed yet, return None. + */ + def getCheckpointDir: Option[String] = RDDCheckpointData.synchronized { + if (isCheckpointed) { + Some(cpDir.toString) + } else { + None + } + } + + /** + * Materialize this RDD and write its content to a reliable DFS. + * This is called immediately after the first action invoked on this RDD has completed. + */ + protected override def doCheckpoint(): CheckpointRDD[T] = { + + // Create the output path for the checkpoint + val path = new Path(cpDir) + val fs = path.getFileSystem(rdd.context.hadoopConfiguration) + if (!fs.mkdirs(path)) { + throw new SparkException(s"Failed to create checkpoint path $cpDir") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = rdd.context.broadcast( + new SerializableConfiguration(rdd.context.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) + val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) + if (newRDD.partitions.length != rdd.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $rdd(${rdd.partitions.length})") + } + + // Optionally clean our checkpoint files if the reference is out of scope + if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { + rdd.context.cleaner.foreach { cleaner => + cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) + } + } + + logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") + + newRDD + } + +} + +private[spark] object ReliableRDDCheckpointData { + + /** Return the path of the directory to which this RDD's checkpoint data is written. */ + def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = { + sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } + } + + /** Clean up the files associated with the checkpoint data for this RDD. 
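`doCheckpoint()` above registers the written files for cleanup only when `spark.cleaner.referenceTracking.cleanCheckpoints` is enabled (it defaults to off). A hedged sketch of opting in from an application; the paths are placeholders and the actual cleanup is asynchronous and best-effort:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CleanCheckpointsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("clean-checkpoints-sketch")
      .setMaster("local[2]")
      // Opt in to deleting rdd-<id> checkpoint directories once the RDD is garbage collected.
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")

    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/spark-checkpoints")

    var rdd = sc.parallelize(1 to 10)
    rdd.checkpoint()
    rdd.count()

    // Dropping the last reference makes the checkpoint data eligible for cleanup.
    rdd = null
    System.gc()

    sc.stop()
  }
}
```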
*/ + def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = { + checkpointPath(sc, rddId).foreach { path => + val fs = path.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(path)) { + fs.delete(path, true) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 2dc47f95937c..cb15d912bbfb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer @@ -37,7 +39,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { */ // TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C]( +class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) extends RDD[(K, C)](prev.context, Nil) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala similarity index 73% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 2bdc34102125..0228c54e0511 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -15,33 +15,31 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.broadcast.Broadcast - -import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.{RDD, HadoopRDD} -import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} -import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends SparkPartition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -60,18 +58,16 @@ private[spark] class SqlNewHadoopPartition( * and the executor side to the shared Hadoop Configuration. * * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be - * folded into core. + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. 
*/ -private[sql] class SqlNewHadoopRDD[K, V]( - @transient sc : SparkContext, +private[spark] class SqlNewHadoopRDD[V: ClassTag]( + sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], - @transient initDriverSideJobFuncOpt: Option[Job => Unit], + @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], + inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[(K, V)](sc, Nil) + extends RDD[V](sc, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -90,7 +86,7 @@ private[sql] class SqlNewHadoopRDD[K, V]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - job.getConfiguration + SparkHadoopUtil.get.getConfigurationFromJobContext(job) } private val jobTrackerId: String = { @@ -120,8 +116,8 @@ private[sql] class SqlNewHadoopRDD[K, V]( override def compute( theSplit: SparkPartition, - context: TaskContext): InterruptibleIterator[(K, V)] = { - val iter = new Iterator[(K, V)] { + context: TaskContext): Iterator[V] = { + val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf(isDriverSide = false) @@ -129,6 +125,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -148,25 +150,34 @@ private[sql] class SqlNewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private[this] var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + + private[this] var havePair = false + private[this] var finished = false override def hasNext: Boolean = { + if (context.isInterrupted) { + throw new TaskKilledException + } if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. 
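The `hasNext` above now honours task kills and releases the record reader as soon as the input is exhausted, instead of waiting for task completion. A standalone sketch of that iterator shape, with `java.io.BufferedReader` standing in for the Hadoop `RecordReader` (class and field names are illustrative):

```scala
import java.io.{BufferedReader, StringReader}

class EagerCloseLineIterator(private var reader: BufferedReader) extends Iterator[String] {

  private[this] var nextLine: String = null
  private[this] var finished = false

  override def hasNext: Boolean = {
    if (!finished && nextLine == null) {
      nextLine = reader.readLine()
      if (nextLine == null) {
        finished = true
        close() // release the resource as soon as we know we are done
      }
    }
    !finished
  }

  override def next(): String = {
    if (!hasNext) throw new java.util.NoSuchElementException("End of stream")
    val line = nextLine
    nextLine = null
    line
  }

  private def close(): Unit = {
    if (reader != null) {
      reader.close()
      reader = null // null it out so a later call (e.g. on task completion) is a no-op
    }
  }
}

object EagerCloseLineIteratorExample {
  def main(args: Array[String]): Unit = {
    val it = new EagerCloseLineIterator(new BufferedReader(new StringReader("a\nb\nc")))
    it.foreach(println) // the reader is closed as soon as the last line is consumed
  }
}
```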
+ close() + } havePair = !finished } !finished } - override def next(): (K, V) = { + override def next(): V = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -174,43 +185,40 @@ private[sql] class SqlNewHadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } - (reader.getCurrentKey, reader.getCurrentValue) + reader.getCurrentValue } private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + reader.close() + reader = null + + SqlNewHadoopRDD.unsetInputFileName() + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { - case e: Exception => { - if (!Utils.inShutdown()) { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } - } } } } - new InterruptibleIterator(context, iter) - } - - /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ - @DeveloperApi - def mapPartitionsWithInputSplit[U: ClassTag]( - f: (InputSplit, Iterator[(K, V)]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { - new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + iter } override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { @@ -241,6 +249,21 @@ private[sql] class SqlNewHadoopRDD[K, V]( } private[spark] object SqlNewHadoopRDD { + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. 
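The companion object above threads the current input file name through a `ThreadLocal` so that the `InputFileName` expression in Spark SQL can read it without extra plumbing. A minimal sketch of the same pattern using a plain `String` instead of `UTF8String` (the object name is made up):

```scala
object InputFileNameHolder {
  // Provide a default so readers never observe null, mirroring the object above.
  private val inputFileName: ThreadLocal[String] = new ThreadLocal[String] {
    override protected def initialValue(): String = ""
  }

  def getInputFileName(): String = inputFileName.get()
  def setInputFileName(file: String): Unit = inputFileName.set(file)
  def unsetInputFileName(): Unit = inputFileName.remove()
}

object InputFileNameHolderExample {
  def main(args: Array[String]): Unit = {
    println(s"before: '${InputFileNameHolder.getInputFileName()}'")  // empty default
    InputFileNameHolder.setInputFileName("hdfs://host/path/part-00000")
    println(s"during: '${InputFileNameHolder.getInputFileName()}'")
    InputFileNameHolder.unsetInputFileName()
    println(s"after:  '${InputFileNameHolder.getInputFileName()}'")  // back to the default
  }
}
```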
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index f7cb1791d4ac..25ec685eff5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -63,15 +63,17 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } override def getDependencies: Seq[Dependency[_]] = { - Seq(rdd1, rdd2).map { rdd => + def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]]) + : Dependency[_] = { if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializer) + new ShuffleDependency[T1, T2, Any](rdd, part, serializer) } } + Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2)) } override def getPartitions: Array[Partition] = { @@ -105,7 +107,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + def integrate(depNum: Int, op: Product2[K, V] => Unit): Unit = { dependencies(depNum) match { case oneToOneDependency: OneToOneDependency[_] => val dependencyPartition = partition.narrowDeps(depNum).get.split @@ -125,7 +127,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.asScala.iterator.map(t => t._2.iterator.map((t._1, _))).flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 3986645350a8..66cf4369da2e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class UnionPartition[T: ClassTag]( idx: Int, - @transient rdd: RDD[T], + @transient private val rdd: RDD[T], val parentRddIndex: Int, - @transient parentRddPartitionIndex: Int) + @transient private val parentRddPartitionIndex: Int) extends Partition { var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 81f40ad33aa5..70bf04de6400 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]], + @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) extends Partition { @@ -73,6 +73,16 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( super.clearDependencies() rdds = null } + + /** + * Call the prepare method of every parent that has one. + * This is needed for reserving execution memory in advance. 
+ */ + protected def tryPrepareParents(): Unit = { + rdds.collect { + case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare() + } + } } private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( @@ -84,6 +94,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag] extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) } @@ -107,6 +118,7 @@ private[spark] class ZippedPartitionsRDD3 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), @@ -134,6 +146,7 @@ private[spark] class ZippedPartitionsRDD4 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 523aaf2b860b..32931d59acb1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -37,7 +37,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) * @tparam T parent RDD item type */ private[spark] -class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { +class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) { /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { @@ -50,8 +50,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L prev.context.runJob( prev, Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - allowLocal = false + 0 until n - 1 // do not need to count the last partition ).scanLeft(0L)(_ + _) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index d2b2baef1d8c..dfcbc51cdf61 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -47,11 +47,11 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint * * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. * - * The lift-cycle will be: + * The life-cycle of an endpoint is: * - * constructor onStart receive* onStop + * constructor -> onStart -> receive* -> onStop * - * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * Note: `receive` can be called concurrently. 
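To make the life-cycle wording above concrete (`constructor -> onStart -> receive* -> onStop`), here is a tiny self-contained sketch; it deliberately does not use Spark's private `RpcEndpoint` trait, and all names are invented:

```scala
trait MiniEndpoint {
  def onStart(): Unit = {}
  def receive: PartialFunction[Any, Unit]
  def onStop(): Unit = {}
}

class EchoEndpoint extends MiniEndpoint {               // 1. constructor
  override def onStart(): Unit = println("started")     // 2. onStart
  override def receive: PartialFunction[Any, Unit] = {  // 3. receive, possibly many times
    case msg => println(s"received: $msg")
  }
  override def onStop(): Unit = println("stopped")      // 4. onStop
}

object MiniEndpointExample {
  /** Drives an endpoint through the guaranteed sequence. */
  def run(endpoint: MiniEndpoint, messages: Seq[Any]): Unit = {
    endpoint.onStart()
    messages.foreach { m =>
      if (endpoint.receive.isDefinedAt(m)) endpoint.receive(m)
    }
    endpoint.onStop()
  }

  def main(args: Array[String]): Unit = {
    run(new EchoEndpoint, Seq("ping", 42))
  }
}
```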
If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] * * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 69181edb9ad4..f25710bb5bd6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -17,8 +17,7 @@ package org.apache.spark.rpc -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.Future import scala.reflect.ClassTag import org.apache.spark.util.RpcUtils @@ -27,12 +26,12 @@ import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) +private[spark] abstract class RpcEndpointRef(conf: SparkConf) extends Serializable with Logging { private[this] val maxRetries = RpcUtils.numRetries(conf) private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) /** * return the address for the [[RpcEndpointRef]] @@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. */ - def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to @@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts var attempts = 0 var lastException: Exception = null @@ -99,9 +98,9 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) attempts += 1 try { val future = ask[T](message, timeout) - val result = Await.result(future, timeout) + val result = timeout.awaitResult(future) if (result == null) { - throw new SparkException("Actor returned null") + throw new SparkException("RpcEndpoint returned null") } return result } catch { @@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) lastException = e logWarning(s"Error sending message [message = $message] in $attempts attempts", e) } - Thread.sleep(retryWaitMs) + + if (attempts < maxRetries) { + Thread.sleep(retryWaitMs) + } } throw new SparkException( s"Error sending message [message = $message]", lastException) } + } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 12b6b28d4d7e..29debe808130 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -18,8 +18,10 @@ package org.apache.spark.rpc import java.net.URI +import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Future} +import scala.concurrent.{Awaitable, Await, Future} +import scala.concurrent.duration._ import 
scala.language.postfixOps import org.apache.spark.{SecurityManager, SparkConf} @@ -37,8 +39,7 @@ private[spark] object RpcEnv { val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") val rpcEnvName = conf.get("spark.rpc", "akka") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). - newInstance().asInstanceOf[RpcEnvFactory] + Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } def create( @@ -66,7 +67,7 @@ private[spark] object RpcEnv { */ private[spark] abstract class RpcEnv(conf: SparkConf) { - private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) + private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -94,7 +95,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. */ def setupEndpointRefByURI(uri: String): RpcEndpointRef = { - Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } /** @@ -138,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * creating it manually because different [[RpcEnv]] may have different formats. */ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String + + /** + * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object + * that contains [[RpcEndpointRef]]s, the deserialization code should be wrapped by this method. + */ + def deserialize[T](deserializationAction: () => T): T } @@ -158,6 +165,8 @@ private[spark] case class RpcAddress(host: String, port: Int) { val hostPort: String = host + ":" + port override val toString: String = hostPort + + def toSparkURL: String = "spark://" + hostPort } @@ -182,3 +191,107 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that when a TimeoutException occurs, additional + * context about the timeout can be appended to the exception message. + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + ". 
This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. 
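The point of `RpcTimeout` above is that a bare `TimeoutException` is re-thrown with the name of the configuration property that controls the timeout. A simplified standalone analogue (plain Scala, no Spark classes; the property name is only an example):

```scala
import java.util.concurrent.TimeoutException

import scala.concurrent.{Await, Awaitable, Future, Promise}
import scala.concurrent.duration._

class DescribedTimeout(val duration: FiniteDuration, val timeoutProp: String) {
  /** Like RpcTimeout.awaitResult: tag the TimeoutException with the controlling property. */
  def awaitResult[T](awaitable: Awaitable[T]): T = {
    try {
      Await.result(awaitable, duration)
    } catch {
      case te: TimeoutException =>
        throw new TimeoutException(s"${te.getMessage}. This timeout is controlled by $timeoutProp")
    }
  }
}

object DescribedTimeoutExample {
  def main(args: Array[String]): Unit = {
    val timeout = new DescribedTimeout(100.millis, "spark.rpc.askTimeout")
    val neverCompleted: Future[String] = Promise[String]().future
    try {
      timeout.awaitResult(neverCompleted)
    } catch {
      // e.g. "Futures timed out ... This timeout is controlled by spark.rpc.askTimeout"
      case te: TimeoutException => println(te.getMessage)
    }
  }
}
```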
+ * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 0161962cde07..ad67e1c5ad4d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future -import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -29,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import com.google.common.util.concurrent.MoreExecutors +import akka.serialization.JavaSerializer import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ @@ -88,9 +87,9 @@ private[spark] class AkkaRpcEnv private[akka] ( override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use lazy because the Actor needs to use `endpointRef`. + // Use defered function because the Actor needs to use `endpointRef`. // So `actorRef` should be created after assigning `endpointRef`. - lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { assert(endpointRef != null) @@ -180,10 +179,10 @@ private[spark] class AkkaRpcEnv private[akka] ( }) } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - _sender ! AkkaFailure(e) - } else { + _sender ! AkkaFailure(e) + if (!needReply) { + // If the sender does not require a reply, it may not handle the exception. So we rethrow + // "e" to make sure it will be processed. throw e } } @@ -214,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] ( override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). 
+ // this is just in case there is a timeout from creating the future in resolveOne, we want the + // exception to indicate the conf that determines the timeout + recover(defaultLookupTimeout.addMessageIfTimeout) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -237,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] ( } override def toString: String = s"${getClass.getSimpleName}($actorSystem)" + + override def deserialize[T](deserializationAction: () => T): T = { + JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { + deserializationAction() + } + } } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -264,13 +272,20 @@ private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging } private[akka] class AkkaRpcEndpointRef( - @transient defaultAddress: RpcAddress, - @transient _actorRef: => ActorRef, - @transient conf: SparkConf, - @transient initInConstructor: Boolean = true) + @transient private val defaultAddress: RpcAddress, + @transient private val _actorRef: () => ActorRef, + conf: SparkConf, + initInConstructor: Boolean) extends RpcEndpointRef(conf) with Logging { - lazy val actorRef = _actorRef + def this( + defaultAddress: RpcAddress, + _actorRef: ActorRef, + conf: SparkConf) = { + this(defaultAddress, () => _actorRef, conf, true) + } + + lazy val actorRef = _actorRef() override lazy val address: RpcAddress = { val akkaAddress = actorRef.path.address @@ -295,8 +310,8 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { // The function will run in the calling thread, so it should be short and never block. case msg @ AkkaMessage(message, reply) => if (reply) { @@ -307,11 +322,18 @@ private[akka] class AkkaRpcEndpointRef( } case AkkaFailure(e) => Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T] + }(ThreadUtils.sameThread).mapTo[T]. + recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } override def toString: String = s"${getClass.getSimpleName}($actorRef)" + final override def equals(that: Any): Boolean = that match { + case other: AkkaRpcEndpointRef => actorRef == other.actorRef + case _ => false + } + + final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index e0edd7d4ae96..b6bff64ee368 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,26 +24,33 @@ import org.apache.spark.annotation.DeveloperApi * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. 
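Ahead of the `AccumulableInfo` change just below: the primary constructor becomes private[spark] and gains an `internal` flag, while the companion `apply` overloads keep their old shape and default the flag to false. A minimal sketch with made-up values:

```scala
import org.apache.spark.scheduler.AccumulableInfo

// The three- and four-argument factory methods still work; both now pass
// internal = false through to the private constructor.
val partial = AccumulableInfo(1L, "records read", Some("10"), "100")
val total = AccumulableInfo(1L, "records read", "100")

// equals and the new hashCode take all five fields into account.
assert(partial != total) // the update field differs (Some("10") vs None)
assert(!partial.internal)
```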
*/ @DeveloperApi -class AccumulableInfo ( +class AccumulableInfo private[spark] ( val id: Long, val name: String, val update: Option[String], // represents a partial update within a task - val value: String) { + val value: String, + val internal: Boolean) { override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => this.id == acc.id && this.name == acc.name && - this.update == acc.update && this.value == acc.value + this.update == acc.update && this.value == acc.value && + this.internal == acc.internal case _ => false } + + override def hashCode(): Int = { + val state = Seq(id, name, update, value, internal) + state.map(_.hashCode).reduceLeft(31 * _ + _) + } } object AccumulableInfo { def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value) + new AccumulableInfo(id, name, update, value, internal = false) } def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value) + new AccumulableInfo(id, name, None, value, internal = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 50a69379412d..a3d2db31301b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -23,18 +23,42 @@ import org.apache.spark.TaskContext import org.apache.spark.util.CallSite /** - * Tracks information about an active job in the DAGScheduler. + * A running job in the DAGScheduler. Jobs can be of two types: a result job, which computes a + * ResultStage to execute an action, or a map-stage job, which computes the map outputs for a + * ShuffleMapStage before any downstream stages are submitted. The latter is used for adaptive + * query planning, to look at map output statistics before submitting later stages. We distinguish + * between these two types of jobs using the finalStage field of this class. + * + * Jobs are only tracked for "leaf" stages that clients directly submitted, through DAGScheduler's + * submitJob or submitMapStage methods. However, either type of job may cause the execution of + * other earlier stages (for RDDs in the DAG it depends on), and multiple jobs may share some of + * these previous stages. These dependencies are managed inside DAGScheduler. + * + * @param jobId A unique ID for this job. + * @param finalStage The stage that this job computes (either a ResultStage for an action or a + * ShuffleMapStage for submitMapStage). + * @param callSite Where this job was initiated in the user's program (shown on UI). + * @param listener A listener to notify if tasks in this job finish or the job fails. + * @param properties Scheduling properties attached to the job, such as fair scheduler pool name. */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: ResultStage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], + val finalStage: Stage, val callSite: CallSite, val listener: JobListener, val properties: Properties) { - val numPartitions = partitions.length + /** + * Number of partitions we need to compute for this job. Note that result stages may not need + * to compute all partitions in their target RDD, for actions like first() and lookup(). 
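As a concrete illustration of that note, assuming an existing SparkContext named `sc`: an action such as first() initially submits a job over a single partition, so a ResultStage's partitions array, and hence numPartitions, can be much smaller than the RDD's partition count, whereas a map-stage job always covers every partition of its RDD.

```scala
val rdd = sc.parallelize(1 to 1000, numSlices = 8)

rdd.first() // initially runs a job on partition 0 only -> numPartitions == 1
rdd.count() // runs a job on all 8 partitions           -> numPartitions == 8
```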
+ */ + val numPartitions = finalStage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.rdd.partitions.length + } + + /** Which partitions of the stage have finished */ val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index aea6674ed20b..3c9a66e50440 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -22,7 +22,8 @@ import java.util.Properties import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} +import scala.collection.Map +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack} import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -35,8 +36,8 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -44,17 +45,65 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. * - * In addition to coming up with a DAG of stages, this class also determines the preferred + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs (MappedRDD, FilteredRDD, etc). + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. 
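To make the stage-splitting description above concrete, a small sketch, assuming an existing SparkContext named `sc` and a placeholder input path:

```scala
// flatMap and map are narrow dependencies, so they pipeline into one ShuffleMapStage;
// reduceByKey introduces a shuffle boundary, and collect() runs a ResultStage that
// reads the shuffled map output.
val counts = sc.textFile("hdfs://example/input.txt") // path is a placeholder
  .flatMap(_.split(" "))
  .map(word => (word, 1))
  .reduceByKey(_ + _)

counts.collect() // one job, two stages: ShuffleMapStage -> ResultStage
```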
* + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. + * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * * Here's a checklist to use when making or reviewing changes to this class: * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. */ @@ -81,6 +130,8 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -108,7 +159,7 @@ class DAGScheduler( * * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). 
*/ - private val cacheLocs = new HashMap[Int, Seq[Seq[TaskLocation]]] + private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with // every task. When we detect a node failing, we note the current epoch number and failed @@ -124,10 +175,6 @@ class DAGScheduler( // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() - - /** If enabled, we may run certain actions like take() and first() locally. */ - private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) - /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) @@ -153,17 +200,24 @@ class DAGScheduler( // may lead to more delay in scheduling if those locations are busy. private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 - // Called by TaskScheduler to report task's starting. + /** + * Called by the TaskSetManager to report task's starting. + */ def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) } - // Called to report that a task has completed and results are being fetched remotely. + /** + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ def taskGettingResult(taskInfo: TaskInfo) { eventProcessLoop.post(GettingResultEvent(taskInfo)) } - // Called by TaskScheduler to report task completions or failures. + /** + * Called by the TaskSetManager to report task completions or failures. + */ def taskEnded( task: Task[_], reason: TaskEndReason, @@ -186,32 +240,38 @@ class DAGScheduler( blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( - BlockManagerHeartbeat(blockManagerId), 600 seconds) + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } - // Called by TaskScheduler when an executor fails. + /** + * Called by TaskScheduler implementation when an executor fails. + */ def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } - // Called by TaskScheduler when a host is added + /** + * Called by TaskScheduler implementation when a host is added. + */ def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or - // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { - eventProcessLoop.post(TaskSetFailed(taskSet, reason)) + /** + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. 
+ */ + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } private[scheduler] - def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { // Note: if the storage level is NONE, we don't need to get locations from block manager. - val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { - Seq.fill(rdd.partitions.size)(Nil) + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) } else { val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] @@ -238,11 +298,12 @@ class DAGScheduler( case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, firstJobId) + getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep => + shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId) + } // Then register current shuffleDep val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - stage } } @@ -282,12 +343,12 @@ class DAGScheduler( */ private def newResultStage( rdd: RDD[_], - numTasks: Int, + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) - val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) - + val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -303,12 +364,12 @@ class DAGScheduler( shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd - val numTasks = rdd.partitions.size + val numTasks = rdd.partitions.length val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - for (i <- 0 until locs.size) { + for (i <- 0 until locs.length) { stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) @@ -316,7 +377,7 @@ class DAGScheduler( // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) } stage } @@ -353,16 +414,6 @@ class DAGScheduler( parents.toList } - /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { - val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (parentsWithNoMapStage.nonEmpty) { - val currentShufDep = 
parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) - shuffleToMapStage(currentShufDep.shuffleId) = stage - } - } - /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] @@ -379,11 +430,9 @@ class DAGScheduler( if (!shuffleToMapStage.contains(shufDep.shuffleId)) { parents.push(shufDep) } - - waitingForVisit.push(shufDep.rdd) case _ => - waitingForVisit.push(dep.rdd) } + waitingForVisit.push(dep.rdd) } } } @@ -499,19 +548,31 @@ class DAGScheduler( jobIdToStageIds -= job.jobId jobIdToActiveJob -= job.jobId activeJobs -= job - job.finalStage.resultOfJob = None + job.finalStage match { + case r: ResultStage => + r.resultOfJob = None + case m: ShuffleMapStage => + m.mapStageJobs = m.mapStageJobs.filter(_ != job) + } } /** - * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object can be used to block until the the job finishes executing or can be used to cancel the job. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name */ def submitJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. @@ -524,6 +585,7 @@ class DAGScheduler( val jobId = nextJobId.getAndIncrement() if (partitions.size == 0) { + // Return immediately if the job is running 0 tasks return new JobWaiter[U](this, jobId, 0, resultHandler) } @@ -531,21 +593,32 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, + jobId, rdd, func2, partitions.toArray, callSite, waiter, SerializationUtils.clone(properties))) waiter } + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. Throws an exception if the job fails, or returns normally if successful. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g.
fair scheduler pool name + */ def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): Unit = { val start = System.nanoTime - val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) + val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format @@ -553,10 +626,24 @@ class DAGScheduler( case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. + val callerStackTrace = Thread.currentThread().getStackTrace.tail + exception.setStackTrace(exception.getStackTrace ++ callerStackTrace) throw exception } } + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runApproximateJob[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -566,14 +653,48 @@ class DAGScheduler( properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.partitions.size).toArray + val partitions = (0 until rdd.partitions.length).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, - SerializationUtils.clone(properties))) + jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g.
fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw new SparkException("Can't run submitMapStage on RDD with 0 partitions") + } + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. + // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queried + // the map output tracker and some node failures had caused the output statistics to be lost. + val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, SerializationUtils.clone(properties))) + waiter + } + /** * Cancel a job that is running or waiting in the queue. */ @@ -582,6 +703,9 @@ class DAGScheduler( eventProcessLoop.post(JobCancelled(jobId)) } + /** + * Cancel all jobs in the given job group ID. + */ def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) @@ -647,73 +771,6 @@ class DAGScheduler( } } - /** - * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. - * We run the operation in a separate thread just in case it takes a bunch of time, so that we - * don't block the DAGScheduler event loop or other concurrent jobs. - */ - protected def runLocally(job: ActiveJob) { - logInfo("Computing the requested partition locally") - new Thread("Local computation of job " + job.jobId) { - override def run() { - runLocallyWithinThread(job) - } - }.start() - } - - // Broken out for easier testing in DAGSchedulerSuite. - protected def runLocallyWithinThread(job: ActiveJob) { - var jobResult: JobResult = JobSucceeded - try { - val rdd = job.finalStage.rdd - val split = rdd.partitions(job.partitions(0)) - val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) - val taskContext = - new TaskContextImpl( - job.finalStage.id, - job.partitions(0), - taskAttemptId = 0, - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - runningLocally = true) - TaskContext.setTaskContext(taskContext) - try { - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - job.listener.taskSucceeded(0, result) - } finally { - taskContext.markTaskCompleted() - TaskContext.unset() - // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, - // make sure to update both copies.
- val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() - if (freedMemory > 0) { - if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { - throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") - } else { - logError(s"Managed memory leak detected; size = $freedMemory bytes") - } - } - } - } catch { - case e: Exception => - val exception = new SparkDriverExecutionException(e) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - case oom: OutOfMemoryError => - val exception = new SparkException("Local job aborted due to out of memory error", oom) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - } finally { - val s = job.finalStage - // clean up data structures that were populated for a local job, - // but that won't get cleaned up via the normal paths through - // completion events or stage abort - stageIdToStage -= s.id - jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult)) - } - } - /** Finds the earliest-created active job that needs the stage */ // TODO: Probably should actually find among the active jobs that need this // stage the one with the highest priority (highest-priority pool, earliest created). @@ -745,8 +802,11 @@ class DAGScheduler( submitWaitingStages() } - private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) { - stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) } + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } submitWaitingStages() } @@ -776,7 +836,6 @@ class DAGScheduler( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties) { @@ -784,40 +843,77 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) + finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } - if (finalStage != null) { - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite.shortForm, partitions.length, allowLocal)) - logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val shouldRunLocally = - localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 - val jobSubmissionTime = clock.getTimeMillis() - if (shouldRunLocally) { - // Compute very short actions like first() or take() with no parent stages locally. 
- listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties)) - runLocally(job) - } else { - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) - } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + submitWaitingStages() + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties) { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. + finalStage = getShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.size)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.mapStageJobs = job :: finalStage.mapStageJobs + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (!finalStage.outputLocs.contains(Nil)) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) } + submitWaitingStages() } @@ -840,7 +936,7 @@ class DAGScheduler( } } } else { - abortStage(stage, "No active job for stage " + stage.id) + abortStage(stage, "No active job for stage " + stage.id, None) } } @@ -850,18 +946,28 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() - // First figure out the indexes of partition ids to compute. 
- val partitionsToCompute: Seq[Int] = { + val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { stage match { case stage: ShuffleMapStage => - (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + val allPartitions = 0 until stage.numPartitions + val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty } + (allPartitions, filteredPartitions) case stage: ResultStage => val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + val allPartitions = 0 until job.numPartitions + val filteredPartitions = allPartitions.filter { id => !job.finished(id) } + (allPartitions, filteredPartitions) } } + // Create internal accumulators if the stage has no accumulators initialized. + // Reset internal accumulators only if this stage is not partially submitted + // Otherwise, we may override existing accumulator values from some tasks + if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) { + stage.resetInternalAccumulators() + } + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage @@ -869,8 +975,28 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) outputCommitCoordinator.stageStart(stage.id) + val taskIdToLocations = try { + stage match { + case s: ShuffleMapStage => + partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap + case s: ResultStage => + val job = s.resultOfJob.get + partitionsToCompute.map { id => + val p = s.partitions(id) + (id, getPreferredLocs(stage.rdd, p)) + }.toMap + } + } catch { + case NonFatal(e) => + stage.makeNewStageAttempt(partitionsToCompute.size) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + runningStages -= stage + return + } + + stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -887,48 +1013,57 @@ class DAGScheduler( case stage: ShuffleMapStage => closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() } taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. 
case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } - val tasks: Seq[Task[_]] = stage match { - case stage: ShuffleMapStage => - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } + val tasks: Seq[Task[_]] = try { + stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = taskIdToLocations(id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, stage.internalAccumulators) + } - case stage: ResultStage => - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = stage.partitions(id) + val part = stage.rdd.partitions(p) + val locs = taskIdToLocations(id) + new ResultTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, id, stage.internalAccumulators) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + runningStages -= stage + return } if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) - taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties)) + taskScheduler.submitTasks(new TaskSet( + tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -968,11 +1103,11 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + val value = s"${acc.value}" + stage.latestInfo.accumulables(id) = + new AccumulableInfo(id, name, None, value, acc.isInternal) event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal) } } } catch { @@ -993,13 +1128,16 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - outputCommitCoordinator.taskCompleted(stageId, task.partitionId, - event.taskInfo.attempt, event.reason) + outputCommitCoordinator.taskCompleted( + stageId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. 
if (event.reason != Success) { - val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val attemptId = task.stageAttemptId listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) } @@ -1055,10 +1193,11 @@ class DAGScheduler( val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { shuffleStage.addOutputLoc(smt.partitionId, status) } + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") @@ -1074,41 +1213,32 @@ class DAGScheduler( // we registered these map outputs. mapOutputTracker.registerMapOutputs( shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + shuffleStage.outputLocs.map(_.headOption.orNull), changeEpoch = true) clearCacheLocs() + if (shuffleStage.outputLocs.contains(Nil)) { // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) - .map(_._2).mkString(", ")) + .map(_._2).mkString(", ")) submitStage(shuffleStage) } else { - val newlyRunnable = new ArrayBuffer[Stage] - for (shuffleStage <- waitingStages) { - logInfo("Missing parents for " + shuffleStage + ": " + - getMissingParentStages(shuffleStage)) - } - for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) - { - newlyRunnable += shuffleStage - } - waitingStages --= newlyRunnable - runningStages ++= newlyRunnable - for { - shuffleStage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(shuffleStage) - } { - logInfo("Submitting " + shuffleStage + " (" + - shuffleStage.rdd + "), which is now runnable") - submitMissingTasks(shuffleStage, jobId) + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } } } + + // Note: newly runnable stages will be submitted below when we submit waiting stages } - } + } case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") @@ -1118,44 +1248,59 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is possible - // the fetch failure has already been handled by the scheduler. 
- if (runningStages.contains(failedStage)) { - logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + - s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) - } + if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + } else { + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is + // possible the fetch failure has already been handled by the scheduler. + if (runningStages.contains(failedStage)) { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + s"due to a fetch failure from $mapStage (${mapStage.name})") + markStageAsFinished(failedStage, Some(failureMessage)) + } else { + logDebug(s"Received fetch failure from $task, but it's from $failedStage which is no " + + s"longer running") + } - if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) - } - failedStages += failedStage - failedStages += mapStage - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config", + None) + } else if (failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId)) { + abortStage(failedStage, s"$failedStage (${failedStage.name}) " + + s"has failed the maximum allowable number of " + + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + + s"Most recent failure reason: ${failureMessage}", None) + } else if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled.
+ // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage + // Mark the map whose fetch failed as broken in the map stage + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } - // TODO: mark the executor as failed only if there were lots of fetch failures on it - if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + } } case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => + case exceptionFailure: ExceptionFailure => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1193,7 +1338,7 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + val locs = stage.outputLocs.map(_.headOption.orNull) mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) } if (shuffleToMapStage.isEmpty) { @@ -1251,10 +1396,17 @@ class DAGScheduler( if (errorMessage.isEmpty) { logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + + // Clear failure count for this stage, now that it's succeeded. + // We only limit consecutive failures of stage attempts, so that if a stage is + // re-used many times in a long-running job, unrelated failures don't eventually cause the + // stage to be aborted. + stage.clearFailures() } else { stage.latestInfo.stageFailed(errorMessage.get) logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) } + outputCommitCoordinator.stageEnd(stage.id) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage @@ -1264,7 +1416,10 @@ class DAGScheduler( * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - private[scheduler] def abortStage(failedStage: Stage, reason: String) { + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed.
return @@ -1273,7 +1428,7 @@ class DAGScheduler( activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1281,8 +1436,11 @@ class DAGScheduler( } /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { - val error = new SparkException(failureReason) + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true val shouldInterruptThread = @@ -1401,36 +1559,50 @@ class DAGScheduler( return rddPrefs.map(TaskLocation(_)) } + // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. rdd.dependencies.foreach { case n: NarrowDependency[_] => - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } - case s: ShuffleDependency[_, _, _] => - // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION - // of data as preferred locations - if (shuffleLocalityEnabled && - rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && - s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { - // Get the preferred map output locations for this reducer - val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, - partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) - if (topLocsForReducer.nonEmpty) { - return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) - } - } - case _ => } + + // If the RDD has shuffle dependencies and shuffle locality is enabled, pick locations that + // have at least REDUCER_PREF_LOCS_FRACTION of data as preferred locations + if (shuffleLocalityEnabled && rdd.partitions.length < SHUFFLE_PREF_REDUCE_THRESHOLD) { + rdd.dependencies.foreach { + case s: ShuffleDependency[_, _, _] => + if (s.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.length, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => + } + } Nil } + /** Mark a map stage job as finished with the given output stats, and report to its listener. 
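For context on how the map-stage jobs finished here are created in the first place, a rough sketch of driving the submitMapStage path added above. Everything in it is spark-internal (private[spark]) API, the RDD and names are made up, and it is only meant to make the control flow concrete:

```scala
import java.util.Properties
import org.apache.spark.{MapOutputStatistics, ShuffleDependency}
import org.apache.spark.util.Utils

// Assumes an existing SparkContext named sc.
val pairs = sc.parallelize(1 to 100, 4).map(i => (i % 10, i))
val shuffled = pairs.reduceByKey(_ + _)
val dep = shuffled.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int, Int]]

// The callback receives the MapOutputStatistics once the whole map stage is done.
val waiter = sc.dagScheduler.submitMapStage(
  dep,
  (stats: MapOutputStatistics) => println(stats.bytesByPartitionId.mkString(", ")),
  Utils.getCallSite(),
  new Properties())

waiter.awaitResult() // JobSucceeded once the single map-stage "task" completes
```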
*/ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + def stop() { logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() @@ -1438,20 +1610,34 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread at the end of the constructor + // Start the event thread and register the metrics source at the end of the constructor + env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + /** * The main event loop of the DAG scheduler. */ - override def onReceive(event: DAGSchedulerEvent): Unit = event match { - case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, - listener, properties) + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { + case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => + dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) @@ -1480,8 +1666,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => dagScheduler.handleTaskCompletion(completion) - case TaskSetFailed(taskSet, reason) => - dagScheduler.handleTaskSetFailed(taskSet, reason) + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) case ResubmitFailedStages => dagScheduler.resubmitFailedStages() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 2b6f7e4205c3..dda3b6cc7f96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.mutable.Map +import scala.collection.Map import scala.language.existentials import org.apache.spark._ @@ -35,17 +35,26 @@ import org.apache.spark.util.CallSite */ private[scheduler] sealed trait DAGSchedulerEvent +/** A result-yielding job was submitted on a target RDD */ private[scheduler] case class JobSubmitted( jobId: Int, finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, 
properties: Properties = null) extends DAGSchedulerEvent +/** A map stage as submitted to run as a separate job */ +private[scheduler] case class MapStageSubmitted( + jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties = null) + extends DAGSchedulerEvent + private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent @@ -74,6 +83,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] -case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) + extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 02c67073af6a..6b667d5d7645 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,11 +17,11 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge, MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry, Timer} import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) +private[scheduler] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "DAGScheduler" @@ -45,4 +45,8 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) + + /** Timer that tracks the time to process messages in the DAGScheduler's event loop */ + val messageProcessingTimer: Timer = + metricRegistry.timer(MetricRegistry.name("messageProcessingTime")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 529a5b2bf1a0..5a06ef02f5c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -140,7 +140,9 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. 
*/ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) + // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) + // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) @@ -197,6 +199,9 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + // No-op because logging every update would be overkill + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 2bc43a918644..0a98c69b89ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -23,16 +23,20 @@ import org.apache.spark.executor.ExecutorExitCode * Represents an explanation for a executor or whole slave failing or exiting. */ private[spark] -class ExecutorLossReason(val message: String) { +class ExecutorLossReason(val message: String) extends Serializable { override def toString: String = message } private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +case class ExecutorExited(exitCode: Int, isNormalExit: Boolean, reason: String) + extends ExecutorLossReason(reason) + +private[spark] object ExecutorExited { + def apply(exitCode: Int, isNormalExit: Boolean): ExecutorExited = { + ExecutorExited(exitCode, isNormalExit, ExecutorExitCode.explainExitCode(exitCode)) + } } private[spark] case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} + extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index bac37bfdaa23..0e438ab4366d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -107,7 +107,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) - for (split <- list) { + for (split <- list.asScala) { retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index e55b76c36cc5..f96eb8ca0ae0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener val date = new Date(System.currentTimeMillis()) writeInfo = dateFormat.get.format(date) + ": " + info } + // scalastyle:off println jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) + // scalastyle:on println } /** diff 
--git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 8321037cdc02..add0dedc03f4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) +private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -44,8 +44,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int - private type PartitionId = Long - private type TaskAttemptId = Long + private type PartitionId = Int + private type TaskAttemptNumber = Int /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing @@ -57,7 +57,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + private type CommittersByStageMap = + mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]] /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. 
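As a rough sketch (not part of this patch) of how the commit protocol above is used once attempts are keyed by (stage, partition, attemptNumber): a task-side writer asks the driver-side coordinator for permission before committing under the "first committer wins" policy. The coordinator is private[spark], so the real call site lives inside Spark itself; `doCommit` and `abortCommit` are hypothetical stand-ins for an actual output committer.

```scala
// Sketch only: a task-side writer asking the coordinator for permission to commit,
// keyed by (stage, partition, attemptNumber) as in the hunks above. `doCommit` and
// `abortCommit` are hypothetical stand-ins for a real output committer.
import org.apache.spark.{SparkEnv, TaskContext}

def commitIfAllowed(ctx: TaskContext, doCommit: () => Unit, abortCommit: () => Unit): Unit = {
  val coordinator = SparkEnv.get.outputCommitCoordinator
  // "First committer wins": only one attempt per (stage, partition) is authorized.
  if (coordinator.canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber())) {
    doCommit()
  } else {
    abortCommit() // another attempt already holds the commit lock for this partition
  }
}
```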
@@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * @param stage the stage number * @param partition the partition number - * @param attempt a unique identifier for this task attempt + * @param attemptNumber how many times this task has been attempted + * (see [[TaskContext.attemptNumber()]]) * @return true if this task is authorized to commit, false otherwise */ def canCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attempt) + attemptNumber: TaskAttemptNumber): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => endpointRef.askWithRetry[Boolean](msg) @@ -95,7 +97,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Called by DAGScheduler private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]() } // Called by DAGScheduler @@ -107,7 +109,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def taskCompleted( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId, + attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -117,12 +119,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) case Success => // The task output has been committed successfully case denied: TaskCommitDenied => - logInfo( - s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attempt)) { - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") + if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) { + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") authorizedCommitters.remove(partition) } } @@ -140,21 +142,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def handleAskPermissionToCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = synchronized { + attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => authorizedCommitters.get(partition) match { case Some(existingCommitter) => - logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + - s"existingCommitter = $existingCommitter") + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") false case None => - logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") - authorizedCommitters(partition) = attempt + logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition") + authorizedCommitters(partition) = attemptNumber true } case None => - 
logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + + s"partition $partition to commit") false } } @@ -162,7 +166,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[spark] object OutputCommitCoordinator { - // This actor is used only for RPC + // This endpoint is used only for RPC private[spark] class OutputCommitCoordinatorEndpoint( override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) extends RpcEndpoint with Logging { @@ -174,9 +178,9 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + case AskPermissionToCommitOutput(stage, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 174b73221afc..551e39a81b69 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -74,7 +74,7 @@ private[spark] class Pool( if (schedulableNameToSchedulable.containsKey(schedulableName)) { return schedulableNameToSchedulable.get(schedulableName) } - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { val sched = schedulable.getSchedulableByName(schedulableName) if (sched != null) { return sched @@ -83,13 +83,13 @@ private[spark] class Pool( null } - override def executorLost(executorId: String, host: String) { - schedulableQueue.foreach(_.executorLost(executorId, host)) + override def executorLost(executorId: String, host: String, reason: ExecutorLossReason) { + schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason)) } override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { shouldRevive |= schedulable.checkSpeculatableTasks() } shouldRevive @@ -98,7 +98,7 @@ private[spark] class Pool( override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] val sortedSchedulableQueue = - schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) + schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index bf81b9aca481..c0451da1f024 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -17,23 +17,30 @@ package org.apache.spark.scheduler +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD 
import org.apache.spark.util.CallSite /** - * The ResultStage represents the final stage in a job. + * ResultStages apply a function on some partitions of an RDD to compute the result of an action. + * The ResultStage object captures the function to execute, `func`, which will be applied to each + * partition, and the set of partition IDs, `partitions`. Some stages may not run on all partitions + * of the RDD, for actions like first() and lookup(). */ private[spark] class ResultStage( id: Int, rdd: RDD[_], - numTasks: Int, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], parents: List[Stage], firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { - // The active job for this result stage. Will be empty if the job has already finished - // (e.g., because the job was cancelled). + /** + * The active job for this result stage. Will be empty if the job has already finished + * (e.g., because the job was cancelled). + */ var resultOfJob: Option[ActiveJob] = None override def toString: String = "ResultStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c9a124113961..fb693721a9cb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -41,11 +41,14 @@ import org.apache.spark.rdd.RDD */ private[spark] class ResultTask[T, U]( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId, partition.index) with Serializable { + locs: Seq[TaskLocation], + val outputId: Int, + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index a87ef030e69c..ab00bc8f0bf4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,7 +42,7 @@ private[spark] trait Schedulable { def addSchedulable(schedulable: Schedulable): Unit def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable - def executorLost(executorId: String, host: String): Unit + def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 66c75f325fcd..7d9296087640 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -23,7 +23,15 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** - * The ShuffleMapStage represents the intermediate stages in a job. + * ShuffleMapStages are intermediate stages in the execution DAG that produce data for a shuffle. 
+ * They occur right before each shuffle operation, and might contain multiple pipelined operations + * before that (e.g. map and filter). When executed, they save map output files that can later be + * fetched by reduce tasks. The `shuffleDep` field describes the shuffle each stage is part of, + * and variables like `outputLocs` and `numAvailableOutputs` track how many map outputs are ready. + * + * ShuffleMapStages can also be submitted independently as jobs with DAGScheduler.submitMapStage. + * For such stages, the ActiveJobs that submitted them are tracked in `mapStageJobs`. Note that + * there can be multiple ActiveJobs trying to compute the same shuffle map stage. */ private[spark] class ShuffleMapStage( id: Int, @@ -37,7 +45,10 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id - var numAvailableOutputs: Long = 0 + /** Running map-stage jobs that were submitted to execute this stage independently (if any) */ + var mapStageJobs: List[ActiveJob] = Nil + + var numAvailableOutputs: Int = 0 def isAvailable: Boolean = numAvailableOutputs == numPartitions diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index bd3dd23dfe1a..f478f9982afe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -40,14 +40,17 @@ import org.apache.spark.shuffle.ShuffleWriter */ private[spark] class ShuffleMapTask( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, partition.index) with Logging { + @transient private var locs: Seq[TaskLocation], + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ def this(partitionId: Int) { - this(0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -68,7 +71,7 @@ private[spark] class ShuffleMapTask( val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - return writer.stop(success = true).get + writer.stop(success = true).get } catch { case e: Exception => try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 9620915f495a..896f1743332f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -215,6 +218,11 @@ trait SparkListener { * Called when the driver removes an executor. */ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + + /** + * Called when the driver receives a block update info. + */ + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 61e69ecc0838..04afde33f5aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -58,6 +58,8 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case blockUpdated: SparkListenerBlockUpdated => + listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index c59d6e4f5bc0..b37eccbd0f7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -24,29 +24,35 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * A stage is a set of independent tasks all computing the same function that need to run as part + * A stage is a set of parallel tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. 
Each DAG of tasks run * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the * DAGScheduler runs these stages in topological order. * * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes - * that each output partition is on. + * other stage(s), or a result stage, in which case its tasks directly compute a Spark action + * (e.g. count(), save(), etc) by running a function on an RDD. For shuffle map stages, we also + * track the nodes that each output partition is on. * * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * - * The callSite provides a location in user code which relates to the stage. For a shuffle map - * stage, the callSite gives the user code that created the RDD being shuffled. For a result - * stage, the callSite gives the user code that executes the associated action (e.g. count()). - * - * A single stage can consist of multiple attempts. In that case, the latestInfo field will - * be updated for each attempt. + * Finally, a single stage can be re-executed in multiple attempts due to fault recovery. In that + * case, the Stage object will track multiple StageInfo objects to pass to listeners or the web UI. + * The latest one will be accessible through latestInfo. * + * @param id Unique stage ID + * @param rdd RDD that this stage runs on: for a shuffle map stage, it's the RDD we run map tasks + * on, while for a result stage, it's the target RDD that we ran an action on + * @param numTasks Total number of tasks in stage; result stages in particular may not need to + * compute all partitions, e.g. for first(), lookup(), and take(). + * @param parents List of stages that this stage depends on (through shuffle dependencies). + * @param firstJobId ID of the first job this stage was part of, for FIFO scheduling. + * @param callSite Location in the user program associated with this stage: either where the target + * RDD was created, for a shuffle map stage, or where the action for a result stage was called. */ -private[spark] abstract class Stage( +private[scheduler] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, @@ -62,22 +68,70 @@ private[spark] abstract class Stage( var pendingTasks = new HashSet[Task[_]] + /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm - /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ - var latestInfo: StageInfo = StageInfo.fromStage(this) + private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty + + /** Internal accumulators shared across all tasks in this stage. */ + def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators + + /** + * Re-initialize the internal accumulators associated with this stage. + * + * This is called every time the stage is submitted, *except* when a subset of tasks + * belonging to this stage has already finished. Otherwise, reinitializing the internal + * accumulators here again will override partial values from the finished tasks. 
+ */ + def resetInternalAccumulators(): Unit = { + _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) + } + + /** + * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized + * here, before any attempts have actually been created, because the DAGScheduler uses this + * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts + * have been created). + */ + private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) + + /** + * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these + * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * We keep track of each attempt ID that has failed to avoid recording duplicate failures if + * multiple tasks from the same stage attempt fail (SPARK-5945). + */ + private val fetchFailedAttemptIds = new HashSet[Int] - /** Return a new attempt id, starting with 0. */ - def newAttemptId(): Int = { - val id = nextAttemptId + private[scheduler] def clearFailures() : Unit = { + fetchFailedAttemptIds.clear() + } + + /** + * Check whether we should abort the failedStage due to multiple consecutive fetch failures. + * + * This method updates the running set of failed stage attempts and returns + * true if the number of failures exceeds the allowable number of failures. + */ + private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { + fetchFailedAttemptIds.add(stageAttemptId) + fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES + } + + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ + def makeNewStageAttempt( + numPartitionsToCompute: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { + _latestInfo = StageInfo.fromStage( + this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) nextAttemptId += 1 - id } - def attemptId: Int = nextAttemptId + /** Returns the StageInfo for the most recent attempt for this stage. */ + def latestInfo: StageInfo = _latestInfo override final def hashCode(): Int = id override final def equals(other: Any): Boolean = other match { @@ -85,3 +139,8 @@ private[spark] abstract class Stage( case _ => false } } + +private[scheduler] object Stage { + // The number of consecutive failures allowed before a stage is aborted + val MAX_CONSECUTIVE_FETCH_FAILURES = 4 +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index e439d2a7e122..24796c14300b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -34,7 +34,8 @@ class StageInfo( val numTasks: Int, val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], - val details: String) { + val details: String, + private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -70,16 +71,22 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. 
*/ - def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = { + def fromStage( + stage: Stage, + attemptId: Int, + numTasks: Option[Int] = None, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos new StageInfo( stage.id, - stage.attemptId, + attemptId, stage.name, numTasks.getOrElse(stage.numTasks), rddInfos, stage.parents.map(_.id), - stage.details) + stage.details, + taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 15101c64f050..9edf9f048f9f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,8 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextImpl, TaskContext} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -43,34 +44,62 @@ import org.apache.spark.util.Utils * @param stageId id of the stage this task belongs to * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { +private[spark] abstract class Task[T]( + val stageId: Int, + val stageAttemptId: Int, + val partitionId: Int, + internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { + + /** + * The key of the Map is the accumulator id and the value of the Map is the latest accumulator + * local value. + */ + type AccumulatorUpdates = Map[Long, Any] /** * Called by [[Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) - * @return the result of the task + * @return the result of the task along with updates of Accumulators. 
*/ - final def run(taskAttemptId: Long, attemptNumber: Int): T = { + final def run( + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem) + : (T, AccumulatorUpdates) = { context = new TaskContextImpl( - stageId = stageId, - partitionId = partitionId, - taskAttemptId = taskAttemptId, - attemptNumber = attemptNumber, - taskMemoryManager = taskMemoryManager, + stageId, + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + metricsSystem, + internalAccumulators, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) + context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - runTask(context) + (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for shuffles + SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() + } + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + } + } finally { + TaskContext.unset() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 132a9ced7770..f113c2b1b843 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, - val attempt: Int, + val attemptNumber: Int, val launchTime: Long, val executorId: String, val host: String, @@ -95,7 +95,10 @@ class TaskInfo( } } - def id: String = s"$index.$attempt" + @deprecated("Use attemptNumber", "1.6.0") + def attempt: Int = attemptNumber + + def id: String = s"$index.$attemptNumber" def duration: Long = { if (!finished) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 8b2a742b9698..b82c7f3fa54f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,7 +20,8 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.executor.TaskMetrics @@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - accumUpdates = Map() + val _accumUpdates = mutable.Map[Long, Any]() for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() + _accumUpdates(in.readLong()) = in.readObject() } + accumUpdates = _accumUpdates } metrics = in.readObject().asInstanceOf[TaskMetrics] valueObjectDeserialized = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ed3dde0fc305..1c7bfe89c02a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl( // 
TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. - val activeTaskSets = new HashMap[String, TaskSetManager] + private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] - val taskIdToTaskSetId = new HashMap[Long, String] + private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl( logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) - activeTaskSets(taskSet.id) = manager + val stage = taskSet.stageId + val stageTaskSets = + taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) + stageTaskSets(taskSet.stageAttemptId) = manager + val conflictingTaskSet = stageTaskSets.exists { case (_, ts) => + ts.taskSet != taskSet && !ts.isZombie + } + if (conflictingTaskSet) { + throw new IllegalStateException(s"more than one active taskSet for stage $stage:" + + s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}") + } schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) if (!isLocal && !hasReceivedTask) { @@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId, interruptThread) + } + tsm.abort("Stage %s cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) } - tsm.abort("Stage %s cancelled".format(stageId)) - logInfo("Stage %d was cancelled".format(stageId)) } } @@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl( * cleaned up. 
*/ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - activeTaskSets -= manager.taskSet.id + taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage => + taskSetsForStage -= manager.taskSet.stageAttemptId + if (taskSetsForStage.isEmpty) { + taskSetsByStageIdAndAttempt -= manager.taskSet.stageId + } + } manager.parent.removeSchedulable(manager) logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" .format(manager.taskSet.id, manager.parent.name)) @@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl( for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK @@ -315,30 +332,29 @@ private[spark] class TaskSchedulerImpl( // We lost this entire executor, so remember that it's gone val execId = taskIdToExecutorId(tid) if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) + removeExecutor(execId, + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) failedExecutor = Some(execId) } } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => + taskIdToTaskSetManager.get(tid) match { + case Some(taskSet) => if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) + taskIdToTaskSetManager.remove(tid) taskIdToExecutorId.remove(tid) } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") - .format(state, tid)) + "likely the result of receiving duplicate task finished status updates)") + .format(state, tid)) } } catch { case e: Exception => logError("Exception in statusUpdate", e) @@ -363,9 +379,9 @@ private[spark] class TaskSchedulerImpl( val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { taskMetrics.flatMap { case (id, metrics) => - taskIdToTaskSetId.get(id) - .flatMap(activeTaskSets.get) - .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) + taskIdToTaskSetManager.get(id).map { taskSetMgr => + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + } } } dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) @@ -397,9 +413,12 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.nonEmpty) { + if (taskSetsByStageIdAndAttempt.nonEmpty) { // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { + for { + attempts <- taskSetsByStageIdAndAttempt.values + manager <- attempts.values + } { try { 
manager.abort(message) } catch { @@ -446,7 +465,7 @@ private[spark] class TaskSchedulerImpl( if (activeExecutorIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) + removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { // We may get multiple executorLost() calls with different loss reasons. For example, one @@ -464,7 +483,7 @@ private[spark] class TaskSchedulerImpl( } /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { + private def removeExecutor(executorId: String, reason: ExecutorLossReason) { activeExecutorIds -= executorId val host = executorIdToHost(executorId) val execs = executorsByHost.getOrElse(host, new HashSet) @@ -479,7 +498,7 @@ private[spark] class TaskSchedulerImpl( } } executorIdToHost -= executorId - rootPool.executorLost(executorId, host) + rootPool.executorLost(executorId, host, reason) } def executorAdded(execId: String, host: String) { @@ -520,6 +539,17 @@ private[spark] class TaskSchedulerImpl( override def applicationAttemptId(): Option[String] = backend.applicationAttemptId() + private[scheduler] def taskSetManagerForAttempt( + stageId: Int, + stageAttemptId: Int): Option[TaskSetManager] = { + for { + attempts <- taskSetsByStageIdAndAttempt.get(stageId) + manager <- attempts.get(stageAttemptId) + } yield { + manager + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156f5..be8526ba9b94 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -26,10 +26,10 @@ import java.util.Properties private[spark] class TaskSet( val tasks: Array[Task[_]], val stageId: Int, - val attempt: Int, + val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + attempt + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82455b0426a5..62af9031b9f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -662,7 +662,7 @@ private[spark] class TaskSetManager( val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString - reason match { + val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) if (!successful(index)) { @@ -671,6 +671,7 @@ private[spark] class TaskSetManager( } // Not adding to failed executors for FetchFailed. isZombie = true + None case ef: ExceptionFailure => taskMetrics = ef.metrics.orNull @@ -706,38 +707,45 @@ private[spark] class TaskSetManager( s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + ef.exception + + case e: ExecutorLostFailure if e.isNormalExit => + logInfo(s"Task $tid failed because while it was being computed, its executor" + + s" exited normally. 
Not marking the task as failed.") + None case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + None case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) + None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { - // If a task failed because its attempt to commit was denied, do not count this failure - // towards failing the stage. This is intended to prevent spurious stage failures in cases - // where many speculative tasks are launched and denied to commit. + if (!isZombie && state != TaskState.KILLED + && reason.isInstanceOf[TaskFailedReason] + && reason.asInstanceOf[TaskFailedReason].shouldEventuallyFailJob) { assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { logError("Task %d in stage %s failed %d times; aborting job".format( index, taskSet.id, maxTaskFailures)) abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) + .format(index, taskSet.id, maxTaskFailures, failureReason), failureException) return } } maybeFinishTaskSet() } - def abort(message: String): Unit = sched.synchronized { + def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message, exception) isZombie = true maybeFinishTaskSet() } @@ -774,7 +782,7 @@ private[spark] class TaskSetManager( } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { + override def executorLost(execId: String, host: String, reason: ExecutorLossReason) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) // Re-enqueue pending tasks for this host based on the status of the cluster. 
Note @@ -805,9 +813,12 @@ private[spark] class TaskSetManager( } } } - // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) + val isNormalExit: Boolean = reason match { + case exited: ExecutorExited => exited.isNormalExit + case _ => false + } + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, isNormalExit)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 4be1eda2e929..d94743677783 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.ExecutorLossReason import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -70,7 +71,8 @@ private[spark] object CoarseGrainedClusterMessages { case object StopExecutors extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) + extends CoarseGrainedClusterMessage case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage @@ -86,7 +88,15 @@ private[spark] object CoarseGrainedClusterMessages { // Request executors by specifying the new total number of executors desired // This includes executors already pending or running - case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage + case class RequestExecutors( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]) + extends CoarseGrainedClusterMessage + + // Check if an executor was force-killed but for a normal reason. + // This could be the case if the executor is preempted, for instance. 
+ case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7c7f70d8a193..18771f79b44b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** @@ -66,6 +67,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] + // A map from each host to the number of tasks from active stages that would like to run on it + protected var hostToLocalTaskCount: Map[String, Int] = Map.empty + + // The number of pending tasks that have locality preferences + protected var localityAwareTasks = 0 + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -76,7 +83,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[RpcAddress, String] + protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") @@ -122,13 +129,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { logInfo("Registered executor: " + executorRef + " with ID " + executorId) - context.reply(RegisteredExecutor) addressToExecutorId(executorRef.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) @@ -143,6 +150,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } + // Note: some tests expect the reply to come after we put the executor in the map + context.reply(RegisteredExecutor) listenerBus.post( SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() @@ -169,21 +178,29 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on all executors private def makeOffers() { - launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => + // Filter out executors under killing + val activeExecutors =
executorDataMap.filterKeys(!executorsPendingToRemove.contains(_)) + val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toSeq)) + }.toSeq + launchTasks(scheduler.resourceOffers(workOffers)) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, - "remote Rpc client disassociated")) + addressToExecutorId + .get(remoteAddress) + .foreach(removeExecutor(_, SlaveLost("remote Rpc client disassociated"))) } // Make fake resource offers on just one executor private def makeOffers(executorId: String) { - val executorData = executorDataMap(executorId) - launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) + // Filter out executors under killing + if (!executorsPendingToRemove.contains(executorId)) { + val executorData = executorDataMap(executorId) + val workOffers = Seq( + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + launchTasks(scheduler.resourceOffers(workOffers)) + } } // Launch tasks returned by a set of resource offers @@ -191,15 +208,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) - scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => + scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + "spark.akka.frameSize or using broadcast variables for large values." 
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, AkkaUtils.reservedSizeBytes) - taskSet.abort(msg) + taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) } @@ -214,7 +230,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String): Unit = { + def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated @@ -226,10 +242,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, SlaveLost(reason)) + scheduler.executorLost(executorId, reason) listenerBus.post( - SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) - case None => logError(s"Asked to remove non-existent executor $executorId") + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) + case None => logInfo(s"Asked to remove non-existent executor $executorId") } } @@ -250,8 +266,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint( - CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) + driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + } + + protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new DriverEndpoint(rpcEnv, properties) } def stopExecutors() { @@ -291,7 +310,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: ExecutorLossReason) { try { driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { @@ -333,6 +352,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") + numPendingExecutors += numAdditionalExecutors // Account for executors pending to be added or removed val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size @@ -340,16 +360,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } /** - * Express a preference to the cluster manager for a given total number of executors. This can - * result in canceling pending requests or filing additional requests. - * @return whether the request is acknowledged. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. 
+ * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. */ - final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized { + final override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int] + ): Boolean = synchronized { if (numExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of executor(s) " + s"$numExecutors from the cluster manager. Please specify a positive number!") } + + this.localityAwareTasks = localityAwareTasks + this.hostToLocalTaskCount = hostToLocalTaskCount + numPendingExecutors = math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) doRequestTotalExecutors(numExecutors) @@ -371,31 +408,44 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. - * Return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { + killExecutors(executorIds, replace = false) + } + + /** + * Request that the cluster manager kill the specified executors. + * + * @param executorIds identifiers of executors to kill + * @param replace whether to replace the killed executors with new ones + * @return whether the kill request is acknowledged. + */ + final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") - val filteredExecutorIds = new ArrayBuffer[String] - executorIds.foreach { id => - if (executorDataMap.contains(id)) { - filteredExecutorIds += id - } else { - logWarning(s"Executor to kill $id does not exist!") - } + val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) + unknownExecutors.foreach { id => + logWarning(s"Executor to kill $id does not exist!") + } + + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } + executorsPendingToRemove ++= executorsToKill + + // If we do not wish to replace the executors we kill, sync the target number of executors + // with the cluster manager to avoid allocating new ones. When computing the new target, + // take into account executors that are pending to be added or removed. + if (!replace) { + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) } - // Killing executors means effectively that we want less executors than before, so also update - // the target number of executors to avoid having the backend allocate new ones. - val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size - - filteredExecutorIds.size) - doRequestTotalExecutors(newTotal) - executorsPendingToRemove ++= filteredExecutorIds - doKillExecutors(filteredExecutorIds) + doKillExecutors(executorsToKill) } /** * Kill the given list of executors through the cluster manager. - * Return whether the kill request is acknowledged.
*/ protected def doKillExecutors(executorIds: Seq[String]): Boolean = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 26e72c0bff38..626a2b7d69ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -22,7 +22,7 @@ import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorEndpoint The ActorRef representing this executor + * @param executorEndpoint The RpcEndpointRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ccf1dc5af612..27491ecf8b97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,7 +23,7 @@ import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( @@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend( val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() waitForRegistration() } @@ -135,11 +135,11 @@ private[spark] class SparkDeploySchedulerBackend( override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code) + case Some(code) => ExecutorExited(code, isNormalExit = true, message) case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(fullId, message)) - removeExecutor(fullId.split("/")(1), reason.toString) + removeExecutor(fullId.split("/")(1), reason) } override def sufficientResourcesRegistered(): Boolean = { @@ -152,6 +152,34 @@ private[spark] class SparkDeploySchedulerBackend( super.applicationId } + /** + * Request executors from the Master by specifying the total number desired, + * including existing pending and running executors. + * + * @return whether the request is acknowledged. + */ + protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + Option(client) match { + case Some(c) => c.requestTotalExecutors(requestedTotal) + case None => + logWarning("Attempted to request executors before driver fully initialized.") + false + } + } + + /** + * Kill the given list of executors through the Master. 
+ * @return whether the kill request is acknowledged. + */ + protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { + Option(client) match { + case Some(c) => c.killExecutors(executorIds) + case None => + logWarning("Attempted to kill executors before driver fully initialized.") + false + } + } + private def waitForRegistration() = { registrationBarrier.acquire() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 190ff61d689d..6a4b536dee19 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,12 +17,13 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Future, ExecutionContext} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler._ import org.apache.spark.ui.JettyUtils import org.apache.spark.util.{ThreadUtils, RpcUtils} @@ -43,24 +44,27 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( - YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) + private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) - private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) + private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) + + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) + yarnSchedulerEndpointRef.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) + yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -89,6 +93,41 @@ private[spark] abstract class YarnSchedulerBackend( } } + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new YarnDriverEndpoint(rpcEnv, properties) + } + + /** + * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. + * This endpoint communicates with the executors and queries the AM for an executor's exit + * status when the executor is disconnected. + */ + private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + /** + * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint + * handles it by assuming the Executor was lost for a bad reason and removes the executor + * immediately. 
+ * + * In YARN's case however it is crucial to talk to the application master and ask why the + * executor had exited. In particular, the executor may have exited due to the executor + * having been preempted. If the executor "exited normally" according to the application + * master then we pass that information down to the TaskSetManager to inform the + * TaskSetManager that tasks on that lost executor should not count towards a job failure. + * + * TODO there's a race condition where while we are querying the ApplicationMaster for + * the executor loss reason, there is the potential that tasks will be scheduled on + * the executor that failed. We should fix this by having this onDisconnected event + * also "blacklist" executors so that tasks are not assigned to them. + */ + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } + } + } + /** * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ @@ -100,6 +139,33 @@ private[spark] abstract class YarnSchedulerBackend( ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) + private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( + executorId: String, + executorRpcAddress: RpcAddress): Unit = { + amEndpoint match { + case Some(am) => + val lossReasonRequest = GetExecutorLossReason(executorId) + val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) + future onSuccess { + case reason: ExecutorLossReason => { + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) + } + } + future onFailure { + case NonFatal(e) => { + logWarning(s"Attempted to get executor loss reason" + + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + + s" but got no response. 
Marking as slave lost.", e) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) + } + case t => throw t + } + case None => + logWarning("Attempted to check for an executor loss reason" + + " before the AM has registered!") + } + } + override def receive: PartialFunction[Any, Unit] = { case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") @@ -108,8 +174,11 @@ private[spark] abstract class YarnSchedulerBackend( case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => amEndpoint match { @@ -140,7 +209,6 @@ private[spark] abstract class YarnSchedulerBackend( logWarning("Attempted to kill executors before the AM has registered!") context.reply(false) } - } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 6b8edca5aa48..65cb5016cfcc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,18 +18,23 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File +import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.collect.HashBiMap import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} + +import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcAddress -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -44,7 +49,8 @@ import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, - master: String) + master: String, + securityManager: SecurityManager) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with MesosSchedulerUtils { @@ -54,18 +60,60 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + // If shuffle service is enabled, the Spark driver will register with the shuffle service. + // This is for cleaning up shuffle files reliably. 
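Illustrative aside (not part of the patch): the YARN endpoint above asks the ApplicationMaster for an executor's loss reason and falls back to `SlaveLost` when the query fails. The sketch below models only that success/failure shape with plain Scala `Future`s; the loss-reason types and the `askAmForLossReason` stub are simplified stand-ins, not Spark's RPC API.

```scala
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
import scala.util.control.NonFatal

object LossReasonSketch {
  // Simplified stand-ins for Spark's ExecutorLossReason hierarchy.
  sealed trait LossReason
  case class ExecutorExited(code: Int, isNormalExit: Boolean) extends LossReason
  case class SlaveLost(msg: String = "Slave lost") extends LossReason

  implicit val ec: ExecutionContext = ExecutionContext.global

  // Stand-in for asking the ApplicationMaster; a failed future here plays the role of
  // "the AM did not answer", in which case we fall back to SlaveLost, as the endpoint above does.
  def askAmForLossReason(executorId: String): Future[LossReason] =
    Future.successful(ExecutorExited(code = 0, isNormalExit = true))

  def removeExecutor(executorId: String, reason: LossReason): Unit =
    println(s"Removing $executorId: $reason")

  def handleDisconnected(executorId: String): Unit = {
    askAmForLossReason(executorId).onComplete {
      case Success(reason) => removeExecutor(executorId, reason)
      case Failure(NonFatal(e)) =>
        println(s"No loss reason for $executorId (${e.getMessage}); marking as slave lost")
        removeExecutor(executorId, SlaveLost())
      case Failure(t) => throw t
    }
  }

  def main(args: Array[String]): Unit = {
    handleDisconnected("exec-7")
    Thread.sleep(500) // let the async callback run before the JVM exits
  }
}
```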
+ private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] var totalCoresAcquired = 0 val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + // Mapping from slave Id to hostname + private val slaveIdToHost = new HashMap[String, String] + + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocationManager calls [[doRequestTotalExecutors]]. + */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + + private val pendingRemovedSlaveIds = new HashSet[String] + + // private lock object protecting mutable state above. Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + + // A client for talking to the external shuffle service, if it is enabled + private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { + if (shuffleServiceEnabled) { + Some(new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled())) + } else { + None + } + } + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -78,11 +126,12 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) + val driver = createSchedulerDriver( + master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -116,10 +165,6 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) @@ -129,7 +174,7 @@ private[spark] class CoarseMesosSchedulerBackend( command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + - s" --driver-url $driverUrl" + + s" --driver-url $driverURL" + s" --executor-id ${offer.getSlaveId.getValue}" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + @@ -138,27 +183,49 @@
private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + - s" --driver-url $driverUrl" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } + + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue + mesosExternalShuffleClient.foreach(_.init(appId)) logInfo("Registered as framework ID " + appId) markRegistered() } + override def sufficientResourcesRegistered(): Boolean = { + totalCoresAcquired >= maxCores * minRegisteredRatio + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -168,15 +235,19 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { + stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - - for (offer <- offers) { - val slaveId = offer.getSlaveId.toString + for (offer <- offers.asScala) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -184,52 +255,72 @@ private[spark] class CoarseMesosSchedulerBackend( val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse val taskId = newMesosTaskId() - taskIdToSlaveId(taskId) = slaveId + taskIdToSlaveId.put(taskId, slaveId) slaveIdsWithExecutors += slaveId coresByTaskId(taskId) = cpusToUse - val task = MesosTaskInfo.newBuilder() + // Gather cpu resources from the available resources and use them in the task. 
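Illustrative aside (not part of the patch): the comment above refers to `partitionResources` (defined later in `MesosSchedulerUtils`), which carves a requested amount of a named resource out of an offer, possibly spread across several roles, and returns the used and remaining portions. Below is a simplified, self-contained model of that behaviour, using a plain case class instead of the Mesos protobuf `Resource`.

```scala
object ResourcePartitionSketch {
  // Simplified stand-in for a Mesos scalar Resource (name, amount, role).
  case class Resource(name: String, amount: Double, role: String = "*")

  // Same idea as partitionResources: walk the offered resources, take up to
  // `amountToUse` of `resourceName` (possibly from several roles), and return
  // (remaining, used).
  def partition(
      resources: Seq[Resource],
      resourceName: String,
      amountToUse: Double): (Seq[Resource], Seq[Resource]) = {
    var remain = amountToUse
    val used = Seq.newBuilder[Resource]
    val remaining = resources.map { r =>
      if (remain > 0 && r.name == resourceName && r.amount > 0) {
        val usage = math.min(remain, r.amount)
        remain -= usage
        used += Resource(resourceName, usage, r.role)
        r.copy(amount = r.amount - usage)
      } else {
        r
      }
    }.filter(_.amount > 0) // drop depleted resources
    (remaining, used.result())
  }

  def main(args: Array[String]): Unit = {
    // An offer with cpus under two roles plus some memory.
    val offer = Seq(
      Resource("cpus", 2.0, "spark"),
      Resource("cpus", 6.0, "*"),
      Resource("mem", 8192.0, "*"))

    // Carve out what one task needs, cpus first and then mem from what is left,
    // as the coarse-grained backend above does.
    val (afterCpus, cpusToUse) = partition(offer, "cpus", 4.0)
    val (afterMem, memToUse) = partition(afterCpus, "mem", 4096.0)
    println(s"used cpus: $cpusToUse") // 2.0 from "spark" + 2.0 from "*"
    println(s"used mem:  $memToUse")  // 4096.0 from "*"
    println(s"left over: $afterMem")  // 4.0 cpus and 4096.0 mem
  }
}
```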
+ val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) - .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) } + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) + Collections.singleton(offer.getId), + Collections.singleton(taskBuilder.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState - logInfo("Mesos task " + taskId + " is now " + state) - synchronized { + logInfo(s"Mesos task $taskId is now $state") + val slaveId: String = status.getSlaveId.getValue + stateLock.synchronized { + // If the shuffle service is enabled, have the driver register with each one of the + // shuffle services. This allows the shuffle services to clean up state associated with + // this application when the driver exits. There is currently not a great way to detect + // this through Mesos, since the shuffle services are set up independently. + if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && + slaveIdToHost.contains(slaveId) && + shuffleServiceEnabled) { + assume(mesosExternalShuffleClient.isDefined, + "External shuffle client was not instantiated even though shuffle service is enabled.") + // TODO: Remove this and allow the MesosExternalShuffleService to detect + // framework termination when new Mesos Framework HTTP API is available. 
+ val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) + val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + + s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get + .registerDriverWithShuffleService(hostname, externalShufflePort) + } + if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId(taskId) + val slaveId = taskIdToSlaveId.get(taskId) slaveIdsWithExecutors -= slaveId - taskIdToSlaveId -= taskId + taskIdToSlaveId.remove(taskId) // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores @@ -239,18 +330,19 @@ private[spark] class CoarseMesosSchedulerBackend( if (TaskState.isFailed(TaskState.fromMesos(state))) { failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { - logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + + logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } + executorTerminated(d, slaveId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node - mesosDriver.reviveOffers() + d.reviveOffers() } } } override def error(d: SchedulerDriver, message: String) { - logError("Mesos error: " + message) + logError(s"Mesos error: $message") scheduler.error(message) } @@ -263,18 +355,39 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. 
+ */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.containsKey(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo(s"Mesos slave lost: ${slaveId.getValue}") + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -285,4 +398,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. + logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.containsKey(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. 
+ // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index 3efc536f1456..e0c547dce6d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode @@ -129,6 +129,6 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( } override def fetchAll[T](): Iterable[T] = { - zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala.flatMap(fetch[T]) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1067a7f1caf4..a6d9374eb9e8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, Date, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -294,20 +294,24 @@ private[spark] class MesosClusterScheduler( def start(): Unit = { // TODO: Implement leader election to make sure only one framework running in the cluster. val fwId = schedulerState.fetch[String]("frameworkId") - val builder = FrameworkInfo.newBuilder() - .setUser(Utils.getCurrentUserName()) - .setName(appName) - .setWebuiUrl(frameworkUrl) - .setCheckpoint(true) - .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash fwId.foreach { id => - builder.setId(FrameworkID.newBuilder().setValue(id).build()) frameworkId = id } recoverState() metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) metricsSystem.start() - startScheduler(master, MesosClusterScheduler.this, builder.build()) + val driver = createSchedulerDriver( + master, + MesosClusterScheduler.this, + Utils.getCurrentUserName(), + appName, + conf, + Some(frameworkUrl), + Some(true), + Some(Integer.MAX_VALUE), + fwId) + + startScheduler(driver) ready = true } @@ -345,7 +349,7 @@ private[spark] class MesosClusterScheduler( } // TODO: Page the status updates to avoid trying to reconcile // a large amount of tasks at once. 
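Illustrative aside (not part of the patch): the coarse-grained Mesos backend above identifies a Spark executor as "slaveId/taskId" (see `sparkExecutorId`), so killing by executor id means splitting the id and resolving the Mesos task through the inverse of the task-to-slave `HashBiMap`. A sketch of that lookup; it needs Guava on the classpath, and the `println` stands in for the real `mesosDriver.killTask` call.

```scala
import com.google.common.collect.HashBiMap

object ExecutorIdSketch {
  // Mesos task id <-> slave id, as in the coarse-grained backend above.
  val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String]()

  // "slaveId/taskId", mirroring sparkExecutorId above.
  def sparkExecutorId(slaveId: String, taskId: String): String = s"$slaveId/$taskId"

  def killExecutor(executorId: String): Unit = {
    val slaveId = executorId.split("/")(0)
    val slaveIdToTaskId = taskIdToSlaveId.inverse()
    if (slaveIdToTaskId.containsKey(slaveId)) {
      // A real backend would call mesosDriver.killTask(...) here.
      println(s"Would kill Mesos task ${slaveIdToTaskId.get(slaveId)} on slave $slaveId")
    } else {
      println(s"Unable to find executor Id '$executorId' in Mesos scheduler")
    }
  }

  def main(args: Array[String]): Unit = {
    taskIdToSlaveId.put(3, "slave-42")
    val execId = sparkExecutorId("slave-42", "3") // "slave-42/3"
    killExecutor(execId)
  }
}
```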
- driver.reconcileTasks(statuses) + driver.reconcileTasks(statuses.toSeq.asJava) } } } @@ -370,21 +374,20 @@ private[spark] class MesosClusterScheduler( val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") envBuilder.addVariables( Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) - val cmdOptions = generateCmdOption(desc).mkString(" ") val dockerDefined = desc.schedulerProperties.contains("spark.mesos.executor.docker.image") val executorUri = desc.schedulerProperties.get("spark.executor.uri") .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) - val appArguments = desc.command.arguments.mkString(" ") - val (executable, jar) = if (dockerDefined) { + // Gets the path to run spark-submit, and the path to the Mesos sandbox. + val (executable, sandboxPath) = if (dockerDefined) { // Application jar is automatically downloaded in the mounted sandbox by Mesos, // and the path to the mounted volume is stored in $MESOS_SANDBOX env variable. - ("./bin/spark-submit", s"$$MESOS_SANDBOX/${desc.jarUrl.split("/").last}") + ("./bin/spark-submit", "$MESOS_SANDBOX") } else if (executorUri.isDefined) { builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) val folderBasename = executorUri.get.split('/').last.split('.').head val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" - val cmdJar = s"../${desc.jarUrl.split("/").last}" - (cmdExecutable, cmdJar) + // Sandbox path points to the parent folder as we chdir into the folderBasename. + (cmdExecutable, "..") } else { val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") .orElse(conf.getOption("spark.home")) @@ -393,27 +396,50 @@ private[spark] class MesosClusterScheduler( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath - val cmdJar = desc.jarUrl.split("/").last - (cmdExecutable, cmdJar) + // Sandbox points to the current directory by default with Mesos. 
+ (cmdExecutable, ".") } - builder.setValue(s"$executable $cmdOptions $jar $appArguments") + val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") + val appArguments = desc.command.arguments.mkString(" ") + builder.setValue(s"$executable $cmdOptions $primaryResource $appArguments") builder.setEnvironment(envBuilder.build()) + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } + desc.schedulerProperties.get("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + setupUris(pyFiles, builder) + } builder.build() } - private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + private def generateCmdOption(desc: MesosDriverDescription, sandboxPath: String): Seq[String] = { var options = Seq( "--name", desc.schedulerProperties("spark.app.name"), - "--class", desc.command.mainClass, "--master", s"mesos://${conf.get("spark.master")}", "--driver-cores", desc.cores.toString, "--driver-memory", s"${desc.mem}M") + + // Assume empty main class means we're running python + if (!desc.command.mainClass.equals("")) { + options ++= Seq("--class", desc.command.mainClass) + } + desc.schedulerProperties.get("spark.executor.memory").map { v => options ++= Seq("--executor-memory", v) } desc.schedulerProperties.get("spark.cores.max").map { v => options ++= Seq("--total-executor-cores", v) } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + val formattedFiles = pyFiles.split(",") + .map { path => new File(sandboxPath, path.split("/").last).toString() } + .mkString(",") + options ++= Seq("--py-files", formattedFiles) + } options } @@ -448,12 +474,8 @@ private[spark] class MesosClusterScheduler( offer.cpu -= driverCpu offer.mem -= driverMem val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() - val cpuResource = Resource.newBuilder() - .setName("cpus").setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() - val memResource = Resource.newBuilder() - .setName("mem").setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val cpuResource = createResource("cpus", driverCpu) + val memResource = createResource("mem", driverMem) val commandInfo = buildDriverCommand(submission) val appName = submission.schedulerProperties("spark.app.name") val taskInfo = TaskInfo.newBuilder() @@ -489,10 +511,10 @@ private[spark] class MesosClusterScheduler( } override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { - val currentOffers = offers.map { o => + val currentOffers = offers.asScala.map(o => new ResourceOffer( o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) - }.toList + ).toList logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() val currentTime = new Date() @@ -503,33 +525,36 @@ private[spark] class MesosClusterScheduler( val driversToRetry = pendingRetryDrivers.filter { d => d.retryState.get.nextRetry.before(currentTime) } + scheduleTasks( - driversToRetry, + copyBuffer(driversToRetry), removeFromPendingRetryDrivers, currentOffers, tasks) + // Then we walk through the queued drivers and try to schedule them. 
scheduleTasks( - queuedDrivers, + copyBuffer(queuedDrivers), removeFromQueuedDrivers, currentOffers, tasks) } - tasks.foreach { case (offerId, tasks) => - driver.launchTasks(Collections.singleton(offerId), tasks) + tasks.foreach { case (offerId, taskInfos) => + driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - offers + offers.asScala .filter(o => !tasks.keySet.contains(o.getId)) .foreach(o => driver.declineOffer(o.getId)) } + private def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + def getSchedulerState(): MesosClusterSchedulerState = { - def copyBuffer( - buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { - val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) - buffer.copyToBuffer(newBuffer) - newBuffer - } stateLock.synchronized { new MesosClusterSchedulerState( frameworkId, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 49de85ef48ad..8edf7007a5da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -20,17 +20,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{ArrayList => JArrayList, Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a @@ -45,8 +45,8 @@ private[spark] class MesosSchedulerBackend( with MScheduler with MesosSchedulerUtils { - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] + // Stores the slave ids that has launched a Mesos executor. 
+ val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo] val taskIdToSlaveId = new HashMap[Long, String] // An ExecutorInfo for our tasks @@ -59,20 +59,33 @@ private[spark] class MesosSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() classLoader = Thread.currentThread.getContextClassLoader - startScheduler(master, MesosSchedulerBackend.this, fwInfo) + val driver = createSchedulerDriver( + master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } - def createExecutorInfo(execId: String): MesosExecutorInfo = { + /** + * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * @param availableResources Available resources that is offered by Mesos + * @param execId The executor id to assign to this new executor. + * @return A tuple of the new mesos executor info and the remaining available resources. + */ + def createExecutorInfo( + availableResources: JList[Resource], + execId: String): (MesosExecutorInfo, JList[Resource]) = { val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -111,32 +124,28 @@ private[spark] class MesosSchedulerBackend( command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } - val cpus = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) - .build() - val memory = Resource.newBuilder() - .setName("mem") - .setType(Value.Type.SCALAR) - .setScalar( - Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) - .build() - val executorInfo = MesosExecutorInfo.newBuilder() + val builder = MesosExecutorInfo.newBuilder() + val (resourcesAfterCpu, usedCpuResources) = + partitionResources(availableResources, "cpus", mesosExecutorCores) + val (resourcesAfterMem, usedMemResources) = + partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) + + builder.addAllResources(usedCpuResources.asJava) + builder.addAllResources(usedMemResources.asJava) + + sc.conf.getOption("spark.mesos.uris").foreach(setupUris(_, command)) + + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - .addResources(cpus) - .addResources(memory) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) } - executorInfo.build() + (executorInfo.build(), resourcesAfterMem.asJava) } /** @@ -179,6 +188,18 @@ private[spark] class MesosSchedulerBackend( override def reregistered(d: SchedulerDriver, 
masterInfo: MasterInfo) {} + private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { + val builder = new StringBuilder + tasks.asScala.foreach { t => + builder.append("Task id: ").append(t.getTaskId.getValue).append("\n") + .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") + .append("Task resources: ").append(t.getResourcesList).append("\n") + .append("Executor resources: ").append(t.getExecutor.getResourcesList) + .append("---------------------------------------------\n") + } + builder.toString() + } + /** * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that @@ -187,19 +208,37 @@ private[spark] class MesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.partition { o => + val (usableOffers, unUsableOffers) = offers.asScala.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfied + // 1. Attribute constraints + // 2. Memory requirements + // 3. CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => - val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { // If the Mesos executor has not been started on this slave yet, set aside a few @@ -214,6 +253,10 @@ private[spark] class MesosSchedulerBackend( val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap + val slaveIdToResources = new HashMap[String, JList[Resource]]() + usableOffers.foreach { o => + slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList + } val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] @@ -225,11 +268,15 @@ private[spark] class MesosSchedulerBackend( .foreach { offer => offer.foreach { taskDesc => val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId slavesIdsOfAcceptedOffers += slaveId taskIdToSlaveId(taskDesc.taskId) = slaveId + val (mesosTask, remainingResources) = createMesosTask( +
taskDesc, + slaveIdToResources(slaveId), + slaveId) mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) + .add(mesosTask) + slaveIdToResources(slaveId) = remainingResources } } @@ -242,6 +289,7 @@ private[spark] class MesosSchedulerBackend( // TODO: Add support for log urls for Mesos new ExecutorInfo(o.host, o.cores, Map.empty))) ) + logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}") d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -250,28 +298,32 @@ private[spark] class MesosSchedulerBackend( for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { d.declineOffer(o.getId) } - - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } - /** Turn a Spark TaskDescription into a Mesos task */ - def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { + /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ + def createMesosTask( + task: TaskDescription, + resources: JList[Resource], + slaveId: String): (MesosTaskInfo, JList[Resource]) = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() - val cpuResource = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build()) - .build() - MesosTaskInfo.newBuilder() + val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) { + (slaveIdToExecutorInfo(slaveId), resources) + } else { + createExecutorInfo(resources, slaveId) + } + slaveIdToExecutorInfo(slaveId) = executorInfo + val (finalResources, cpuResources) = + partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK) + val taskInfo = MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(createExecutorInfo(slaveId)) + .setExecutor(executorInfo) .setName(task.name) - .addResources(cpuResource) + .addAllResources(cpuResources.asJava) .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() + (taskInfo, finalResources.asJava) } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { @@ -317,7 +369,7 @@ private[spark] class MesosSchedulerBackend( private def removeExecutor(slaveId: String, reason: String) = { synchronized { listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) - slaveIdsWithExecutors -= slaveId + slaveIdToExecutorInfo -= slaveId } } @@ -337,7 +389,7 @@ private[spark] class MesosSchedulerBackend( slaveId: SlaveID, status: Int) { logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status)) + recordSlaveLost(d, slaveId, ExecutorExited(status, isNormalExit = false)) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d11228f3d016..860c8e097b3b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,16 +17,21 @@ package 
org.apache.spark.scheduler.cluster.mesos -import java.util.List +import java.util.{List => JList} import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal -import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} -import org.apache.mesos.{MesosSchedulerDriver, Scheduler} -import org.apache.spark.Logging +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} +import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} import org.apache.spark.util.Utils + /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper * methods and Mesos scheduler will use. @@ -36,16 +41,66 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private final val registerLatch = new CountDownLatch(1) // Driver for talking to Mesos - protected var mesosDriver: MesosSchedulerDriver = null + protected var mesosDriver: SchedulerDriver = null + + /** + * Creates a new MesosSchedulerDriver that communicates to the Mesos master. + * @param masterUrl The url to connect to Mesos master + * @param scheduler the scheduler class to receive scheduler callbacks + * @param sparkUser User to impersonate with when running tasks + * @param appName The framework name to display on the Mesos UI + * @param conf Spark configuration + * @param webuiUrl The WebUI url to link from Mesos UI + * @param checkpoint Option to checkpoint tasks for failover + * @param failoverTimeout Duration Mesos master expect scheduler to reconnect on disconnect + * @param frameworkId The id of the new framework + */ + protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { + val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) + val credBuilder = Credential.newBuilder() + webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } + checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } + failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } + frameworkId.foreach { id => + fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) + } + conf.getOption("spark.mesos.principal").foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret").foreach { secret => + credBuilder.setSecret(ByteString.copyFromUtf8(secret)) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + conf.getOption("spark.mesos.role").foreach { role => + fwInfoBuilder.setRole(role) + } + if (credBuilder.hasPrincipal) { + new MesosSchedulerDriver( + scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) + } else { + new MesosSchedulerDriver(scheduler, fwInfoBuilder.build(), masterUrl) + } + } /** - * Starts the MesosSchedulerDriver with the provided information. This method returns - * only after the scheduler has registered with Mesos. 
- * @param masterUrl Mesos master connection URL - * @param scheduler Scheduler object - * @param fwInfo FrameworkInfo to pass to the Mesos master + * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. + * This driver is expected to not be running. + * This method returns only after the scheduler has registered with Mesos. */ - def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + def startScheduler(newDriver: SchedulerDriver): Unit = { synchronized { if (mesosDriver != null) { registerLatch.await() @@ -56,11 +111,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging { setDaemon(true) override def run() { - mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + mesosDriver = newDriver try { val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) - if (ret.equals(Status.DRIVER_ABORTED)) { + if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { System.exit(1) } } catch { @@ -79,17 +134,206 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Signal that the scheduler has registered with Mesos. */ + protected def getResource(res: JList[Resource], name: String): Double = { + // A resource can have multiple values in the offer since it can either be from + // a specific role or wildcard. + res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum + } + protected def markRegistered(): Unit = { registerLatch.countDown() } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { + val builder = Resource.newBuilder() + .setName(name) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(amount).build()) + + role.foreach { r => builder.setRole(r) } + + builder.build() + } + + /** + * Partition the existing set of resources into two groups, those remaining to be + * scheduled and those requested to be used for a new task. + * @param resources The full list of available resources + * @param resourceName The name of the resource to take from the available resources + * @param amountToUse The amount of resources to take from the available resources + * @return The remaining resources list and the used resources list. + */ + def partitionResources( + resources: JList[Resource], + resourceName: String, + amountToUse: Double): (List[Resource], List[Resource]) = { + var remain = amountToUse + var requestedResources = new ArrayBuffer[Resource] + val remainingResources = resources.asScala.map { + case r => { + if (remain > 0 && + r.getType == Value.Type.SCALAR && + r.getScalar.getValue > 0.0 && + r.getName == resourceName) { + val usage = Math.min(remain, r.getScalar.getValue) + requestedResources += createResource(resourceName, usage, Some(r.getRole)) + remain -= usage + createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole)) + } else { + r + } + } + } + + // Filter any resource that has depleted. 
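To make the role-aware accounting in `partitionResources` concrete, a small sketch (the `scalarResource` helper and the concrete numbers are illustrative only, not part of this patch):

```scala
import java.util.{Arrays => JArrays}
import org.apache.mesos.Protos.{Resource, Value}

def scalarResource(name: String, amount: Double, role: String): Resource =
  Resource.newBuilder()
    .setName(name)
    .setType(Value.Type.SCALAR)
    .setRole(role)
    .setScalar(Value.Scalar.newBuilder().setValue(amount).build())
    .build()

// An offer advertising 2 CPUs reserved for the "spark" role plus 6 unreserved CPUs.
val offered = JArrays.asList(
  scalarResource("cpus", 2.0, "spark"),
  scalarResource("cpus", 6.0, "*"))

// partitionResources(offered, "cpus", 4.0) consumes resources in offer order, so it
// takes the 2 role-reserved CPUs first and 2 of the unreserved ones:
//   used resources      -> cpus 2.0 (role "spark"), cpus 2.0 (role "*")
//   remaining resources -> cpus 4.0 (role "*")   // the depleted "spark" entry is filtered out
```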
+ val filteredResources = + remainingResources.filter(r => r.getType != Value.Type.SCALAR || r.getScalar.getValue > 0.0) + + (filteredResources.toList, requestedResources.toList) + } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * The attribute values are the mesos attribute types and they are + * @param offerAttributes + * @return + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.asScala.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + /** - * Get the amount of resources for the specified type from the resource list + * Match the requirements (if any) to the offer attributes. + * if attribute requirements are not specified - return true + * else if attribute is defined and no values are given, simple attribute presence is performed + * else if attribute name and value is specified, subset match is performed on slave attributes */ - protected def getResource(res: List[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if provided values is less than equal to the offered values + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that is between the ranges specified + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values is a subset of offered set + requiredValues.subsetOf(offeredValue.getItemList.asScala.toSet) + case Some(textValue: Value.Text) => + // check if the specified value is equal, if multiple values are specified + // we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } } - 0.0 } + + /** + * Parses the attributes constraints provided to spark and build a matching data struct: + * Map[, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'. 
The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone": -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match resources offers. + */ + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + splitter.split(constraintsVal).asScala.toMap.mapValues(v => + if (v == null || v.isEmpty) { + Set[String]() + } else { + v.split(',').toSet + } + ) + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } + } + } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + + def setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + uris.split(",").foreach { uri => + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3078a1b10be8..4d48fcfea44e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -17,13 +17,16 @@ package org.apache.spark.scheduler.local +import java.io.File +import java.net.URL import java.nio.ByteBuffer import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() @@ -40,6 +43,7 @@ private case class StopExecutor() */ private[spark] class LocalEndpoint( override val rpcEnv: RpcEnv, + userClassPath: Seq[URL], scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) @@ -47,11 +51,11 @@ private[spark] class LocalEndpoint( private 
var freeCores = totalCores - private val localExecutorId = SparkContext.DRIVER_IDENTIFIER - private val localExecutorHostname = "localhost" + val localExecutorId = SparkContext.DRIVER_IDENTIFIER + val localExecutorHostname = "localhost" private val executor = new Executor( - localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) + localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => @@ -96,11 +100,28 @@ private[spark] class LocalBackend( extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localEndpoint: RpcEndpointRef = null + private var localEndpoint: RpcEndpointRef = null + private val userClassPath = getUserClasspath(conf) + private val listenerBus = scheduler.sc.listenerBus + + /** + * Returns a list of URLs representing the user classpath. + * + * @param conf Spark configuration. + */ + def getUserClasspath(conf: SparkConf): Seq[URL] = { + val userClassPathStr = conf.getOption("spark.executor.extraClassPath") + userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) + } override def start() { - localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( - "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) + val rpcEnv = SparkEnv.get.rpcEnv + val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) + localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint) + listenerBus.post(SparkListenerExecutorAdded( + System.currentTimeMillis, + executorEndpoint.localExecutorId, + new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala new file mode 100644 index 000000000000..62f8aae7f212 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import scala.collection.mutable + +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.io._ +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.io.CompressionCodec + +/** + * Custom serializer used for generic Avro records. If the user registers the schemas + * ahead of time, then the schema's fingerprint will be sent with each message instead of the actual + * schema, as to reduce network IO. + * Actions like parsing or compressing schemas are computationally expensive so the serializer + * caches all previously seen values as to reduce the amount of work needed to do. + * @param schemas a map where the keys are unique IDs for Avro schemas and the values are the + * string representation of the Avro schema, used to decrease the amount of data + * that needs to be serialized. + */ +private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) + extends KSerializer[GenericRecord] { + + /** Used to reduce the amount of effort to compress the schema */ + private val compressCache = new mutable.HashMap[Schema, Array[Byte]]() + private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]() + + /** Reuses the same datum reader/writer since the same schema will be used many times */ + private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]() + private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]() + + /** Fingerprinting is very expensive so this alleviates most of the work */ + private val fingerprintCache = new mutable.HashMap[Schema, Long]() + private val schemaCache = new mutable.HashMap[Long, Schema]() + + // GenericAvroSerializer can't take a SparkConf in the constructor b/c then it would become + // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make + // the codec lazy here just b/c in some unit tests, we use a KryoSerializer w/out having + // the SparkEnv set (note those tests would fail if they tried to serialize avro data). + private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + + /** + * Used to compress Schemas when they are being sent over the wire. + * The compression results are memoized to reduce the compression time since the + * same schema is compressed many times over + */ + def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { + val bos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(bos) + out.write(schema.toString.getBytes("UTF-8")) + out.close() + bos.toByteArray + }) + + /** + * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already + * seen values so to limit the number of times that decompression has to be done. + */ + def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { + val bis = new ByteArrayInputStream(schemaBytes.array()) + val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) + new Schema.Parser().parse(new String(bytes, "UTF-8")) + }) + + /** + * Serializes a record to the given output stream. 
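The up-front schema registration the class comment above refers to happens on the `SparkConf` side; once a schema is registered, only its 64-bit fingerprint travels with each record, and the compressed schema text is the fallback for unregistered schemas. A sketch of registering a schema ahead of time (it assumes the companion `SparkConf.registerAvroSchemas` helper added alongside this serializer):

```scala
import org.apache.avro.SchemaBuilder
import org.apache.spark.SparkConf

// Build a simple record schema and register it so only its fingerprint is serialized.
val userSchema = SchemaBuilder.record("User").fields()
  .requiredString("name")
  .requiredInt("age")
  .endRecord()

val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  // Assumed helper that populates the schema map later read back via conf.getAvroSchema.
  .registerAvroSchemas(userSchema)
```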
It caches a lot of the internal data as + * to not redo work + */ + def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = { + val encoder = EncoderFactory.get.binaryEncoder(output, null) + val schema = datum.getSchema + val fingerprint = fingerprintCache.getOrElseUpdate(schema, { + SchemaNormalization.parsingFingerprint64(schema) + }) + schemas.get(fingerprint) match { + case Some(_) => + output.writeBoolean(true) + output.writeLong(fingerprint) + case None => + output.writeBoolean(false) + val compressedSchema = compress(schema) + output.writeInt(compressedSchema.length) + output.writeBytes(compressedSchema) + } + + writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) + .asInstanceOf[DatumWriter[R]] + .write(datum, encoder) + encoder.flush() + } + + /** + * Deserializes generic records into their in-memory form. There is internal + * state to keep a cache of already seen schemas and datum readers. + */ + def deserializeDatum(input: KryoInput): GenericRecord = { + val schema = { + if (input.readBoolean()) { + val fingerprint = input.readLong() + schemaCache.getOrElseUpdate(fingerprint, { + schemas.get(fingerprint) match { + case Some(s) => new Schema.Parser().parse(s) + case None => + throw new SparkException( + "Error reading attempting to read avro data -- encountered an unknown " + + s"fingerprint: $fingerprint, not sure what schema to use. This could happen " + + "if you registered additional schemas after starting your spark context.") + } + }) + } else { + val length = input.readInt() + decompress(ByteBuffer.wrap(input.readBytes(length))) + } + } + val decoder = DecoderFactory.get.directBinaryDecoder(input, null) + readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema)) + .asInstanceOf[DatumReader[GenericRecord]] + .read(null, decoder) + } + + override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit = + serializeDatum(datum, output) + + override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord = + deserializeDatum(input) +} diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 698d1384d580..b463a71d5bd7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,13 +63,33 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa private val objIn = new ObjectInputStream(in) { override def resolveClass(desc: ObjectStreamClass): Class[_] = - Class.forName(desc.getName, false, loader) + try { + // scalastyle:off classforname + Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } catch { + case e: ClassNotFoundException => + JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e) + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] def close() { objIn.close() } } +private object JavaDeserializationStream { + val primitiveMappings = Map[String, Class[_]]( + "boolean" -> classOf[Boolean], + "byte" -> classOf[Byte], + "char" -> classOf[Char], + "short" -> classOf[Short], + "int" -> classOf[Int], + "long" -> classOf[Long], + "float" -> classOf[Float], + "double" -> classOf[Double], + "void" -> classOf[Void] + ) +} private[spark] class JavaSerializerInstance( counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) diff --git 
a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index cd8a82347a1e..c5195c1143a8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -21,22 +21,24 @@ import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer import javax.annotation.Nullable +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -73,6 +75,8 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) + private val avroSchemas = conf.getAvroSchema + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { @@ -94,12 +98,18 @@ class KryoSerializer(conf: SparkConf) // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) - // Allow sending SerializableWritable + // Allow sending classes with custom Java serializers kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) + kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) + kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) + try { + // scalastyle:off classforname // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. @@ -109,6 +119,7 @@ class KryoSerializer(conf: SparkConf) userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } + // scalastyle:on classforname } catch { case e: Exception => throw new SparkException(s"Failed to register classes with Kryo", e) @@ -120,6 +131,38 @@ class KryoSerializer(conf: SparkConf) // our code override the generic serializers in Chill for things like Seq new AllScalaRegistrar().apply(kryo) + // Register types missed by Chill. 
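The `spark.kryo.classesToRegister` list and the user registrator wired in just above are the two hooks applications use to feed extra classes into `newKryo()`; a brief sketch of both (class names here are placeholders):

```scala
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// A custom registrator invoked by newKryo() via spark.kryo.registrator.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.register(classOf[Array[String]])
  }
}

val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryo.classesToRegister", "java.util.HashMap")
  .set("spark.kryo.registrator", classOf[MyRegistrator].getName)
```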
+ // scalastyle:off + kryo.register(classOf[Array[Tuple1[Any]]]) + kryo.register(classOf[Array[Tuple2[Any, Any]]]) + kryo.register(classOf[Array[Tuple3[Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple4[Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple5[Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple6[Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple7[Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple8[Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple9[Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + + // scalastyle:on + + kryo.register(None.getClass) + kryo.register(Nil.getClass) + kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon")) + kryo.register(classOf[ArrayBuffer[Any]]) + kryo.setClassLoader(classLoader) kryo } @@ -318,9 +361,6 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[PutBlock], - classOf[GotBlock], - classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], classOf[RoaringBitmap], @@ -363,16 +403,15 @@ private class JavaIterableWrapperSerializer override def read(kryo: Kryo, in: KryoInput, clz: Class[java.lang.Iterable[_]]) : java.lang.Iterable[_] = { kryo.readClassAndObject(in) match { - case scalaIterable: Iterable[_] => - scala.collection.JavaConversions.asJavaIterable(scalaIterable) - case javaIterable: java.lang.Iterable[_] => - javaIterable + case scalaIterable: Iterable[_] => scalaIterable.asJava + case javaIterable: java.lang.Iterable[_] => javaIterable } } } private object JavaIterableWrapperSerializer extends Logging { - // The class returned by asJavaIterable (scala.collection.convert.Wrappers$IterableWrapper). + // The class returned by JavaConverters.asJava + // (scala.collection.convert.Wrappers$IterableWrapper). 
val wrapperClass = scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cc2f0506817d..a1b1e1631eaf 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging { /** ObjectStreamClass$ClassDataSlot.desc field */ val DescField: Field = { + // scalastyle:off classforname val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + // scalastyle:on classforname f.setAccessible(true) f } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6c3b3080d260..c057de9b3f4d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics @@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { - val writers: Array[BlockObjectWriter] + val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ def releaseWriters(success: Boolean) @@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { + val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, writeMetrics) } } else { - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. 
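This patch repeatedly replaces the implicit `scala.collection.JavaConversions` with the explicit `scala.collection.JavaConverters` (as in the `asScala` loops in the next hunk). The idiom, for reference (a generic sketch, not code from this change):

```scala
import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConverters._

// Conversions are now explicit at the call site instead of kicking in implicitly.
val javaList = new JArrayList[Int]()
javaList.add(1)
javaList.add(2)

val scalaSum = javaList.asScala.sum   // Java -> Scala wrapper view, no copy
val asJavaAgain = Seq(3, 4).asJava    // Scala -> Java wrapper view
```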
@@ -210,11 +210,13 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) shuffleStates.get(shuffleId) match { case Some(state) => if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + for (fileGroup <- state.allFileGroups.asScala; + file <- fileGroup.files) { file.delete() } } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + for (mapId <- state.completedMapTasks.asScala; + reduceId <- 0 until state.numBuckets) { val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) blockManager.diskBlockManager.getFile(blockId).delete() } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d9c63b6e7bbb..d0163d326dba 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -71,7 +71,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB /** * Write an index file with the offsets of each block, plus a final offset at the end for the - * end of the output file. This will be used by getBlockLocation to figure out where each block + * end of the output file. This will be used by getBlockData to figure out where each block * begins and ends. * */ def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { @@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 3bcc7178a3d8..a0d8abc2eecb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,108 +19,167 @@ package org.apache.spark.shuffle import scala.collection.mutable -import org.apache.spark.{Logging, SparkException, SparkConf} +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** - * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory * from this pool and release it as it spills data out. When a task ends, all its memory will be * released by the Executor. * - * This class tries to ensure that each thread gets a reasonable share of memory, instead of some - * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. 
- * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * This class tries to ensure that each task gets a reasonable share of memory, instead of some + * task ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. + * + * Use `ShuffleMemoryManager.create()` factory method to create a new instance. + * + * @param maxMemory total amount of memory available for execution, in bytes. + * @param pageSizeBytes number of bytes for each page, by default. */ -private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes +private[spark] +class ShuffleMemoryManager protected ( + val maxMemory: Long, + val pageSizeBytes: Long) + extends Logging { - def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes + + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } /** - * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * Try to acquire up to numBytes memory for the current task, and return the number of bytes * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active threads) before it is forced to spill. This can - * happen if the number of threads increases but an older thread had a lot of memory already. + * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active tasks) before it is forced to spill. This can + * happen if the number of tasks increases but an older task had a lot of memory already. 
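A quick arithmetic sketch of the 1 / 2N and 1 / N bounds described in the rewritten comment (the numbers are illustrative):

```scala
// With 1 GB of shuffle memory and 4 concurrently active tasks:
val maxMemory = 1L << 30          // 1 GB
val numActiveTasks = 4

val upperBound = maxMemory / numActiveTasks        // 256 MB: a task is never granted past 1/N
val lowerBound = maxMemory / (2 * numActiveTasks)  // 128 MB: guaranteed before a task must spill

// tryToAcquire blocks a task that cannot yet reach the 128 MB floor until other tasks
// release memory; once past the floor it grants at most the free memory, possibly 0.
```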
*/ def tryToAcquire(numBytes: Long): Long = synchronized { - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - // Add this thread to the threadMemory map just so we can keep an accurate count of the number - // of active threads, to let other threads ramp down their memory in calls to tryToAcquire - if (!threadMemory.contains(threadId)) { - threadMemory(threadId) = 0L - notifyAll() // Will later cause waiting threads to wake up and check numThreads again + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) = 0L + notifyAll() // Will later cause waiting tasks to wake up and check numThreads again } // Keep looping until we're either sure that we don't want to grant this request (because this - // thread would have more than 1 / numActiveThreads of the memory) or we have enough free - // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). while (true) { - val numActiveThreads = threadMemory.keys.size - val curMem = threadMemory(threadId) - val freeMemory = maxMemory - threadMemory.values.sum + val numActiveTasks = taskMemory.keys.size + val curMem = taskMemory(taskAttemptId) + val freeMemory = maxMemory - taskMemory.values.sum - // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads; + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem)) + val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - if (curMem < maxMemory / (2 * numActiveThreads)) { - // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; - // if we can't give it this much now, wait for other threads to free up memory - // (this happens if older threads allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + if (curMem < maxMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo( + s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } } 0L // Never reached } - /** Release numBytes bytes for the current thread. */ + /** Release numBytes bytes for the current task. 
*/ def release(numBytes: Long): Unit = synchronized { - val threadId = Thread.currentThread().getId - val curMem = threadMemory.getOrElse(threadId, 0L) + val taskAttemptId = currentTaskAttemptId() + val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") } - threadMemory(threadId) -= numBytes + taskMemory(taskAttemptId) -= numBytes notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } - /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisThread(): Unit = synchronized { - val threadId = Thread.currentThread().getId - threadMemory.remove(threadId) + /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisTask(): Unit = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } + + /** Returns the memory consumption, in bytes, for the current task */ + def getMemoryConsumptionForThisTask(): Long = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.getOrElse(taskAttemptId, 0L) + } } -private object ShuffleMemoryManager { + +private[spark] object ShuffleMemoryManager { + + def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = { + val maxMemory = ShuffleMemoryManager.getMaxMemory(conf) + val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) + new ShuffleMemoryManager(maxMemory, pageSize) + } + + def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, pageSizeBytes) + } + + @VisibleForTesting + def createForTesting(maxMemory: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024) + } + /** * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction * of the memory pool and a safety factor since collections can sometimes grow bigger than * the size we target before we estimate their sizes again. */ - def getMaxMemory(conf: SparkConf): Long = { + private def getMaxMemory(conf: SparkConf): Long = { val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } + + /** + * Sets the page size, in bytes. + * + * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value + * by looking at the number of cores available to the process, and the total amount of memory, + * and then divide it by a factor of safety. 
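A worked example of the default computed by `getPageSize` just below (illustrative numbers; the real code rounds up to the next power of two via `ByteArrayMethods.nextPowerOf2`, which is a no-op here):

```scala
// 512 MB of shuffle memory, 8 cores, safety factor 16:
val maxMemory = 512L * 1024 * 1024
val cores = 8
val safetyFactor = 16

val size = maxMemory / cores / safetyFactor   // 4 MB, already a power of two
val default = math.min(64L * 1024 * 1024,     // clamp to the [1 MB, 64 MB] range
  math.max(1L * 1024 * 1024, size))
// => 4 MB unless spark.buffer.pageSize is set explicitly
```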
+ */ + private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = { + val minPageSize = 1L * 1024 * 1024 // 1MB + val maxPageSize = 64L * minPageSize // 64MB + val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + // TODO(davies): don't round to next power of 2 + val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) + val default = math.min(maxPageSize, math.max(minPageSize, size)) + conf.getSizeAsBytes("spark.buffer.pageSize", default) + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala deleted file mode 100644 index 597d46a3d222..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.hash - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator - -private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - serializer: Serializer) - : Iterator[T] = - { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager - - val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) - } - - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { - val blockId = blockPair._1 - val blockOption = blockPair._2 - blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] - } - case Failure(e) => { - blockId match { - case ShuffleBlockId(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block", e) - } - } - } - } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b..0c8f08f0f3b1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,18 +17,22 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator import 
org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) - extends ShuffleReader[K, C] -{ + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + extends ShuffleReader[K, C] with Logging { + require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") @@ -36,20 +40,57 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. 
+ val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. @@ -59,8 +100,10 @@ private[spark] class HashShuffleReader[K, C]( // the ExternalSorter won't spill to disk. val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) - context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) sorter.iterator case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index eb87cee15903..41df70c602c3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter private[spark] class HashShuffleWriter[K, V]( shuffleBlockResolver: FileShuffleBlockResolver, @@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). 
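The two aggregation branches above differ only in which `Aggregator` method they feed: `combineCombinersByKey` when map-side combine already produced partial combiners, `combineValuesByKey` when raw values arrive. A toy aggregator computing per-key sums (illustrative, not from this patch):

```scala
import org.apache.spark.Aggregator

// C = Int partial sums; both shuffle-read paths end up merging through such an object.
val sumAggregator = new Aggregator[String, Int, Int](
  createCombiner = (v: Int) => v,                  // first value for a key starts the sum
  mergeValue = (sum: Int, v: Int) => sum + v,      // fold further values into the running sum
  mergeCombiners = (s1: Int, s2: Int) => s1 + s2)  // merge partial sums across partitions
```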
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => writer.commitAndClose() writer.fileSegment().length } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index df7bbd64247d..75f22f642b9d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -159,7 +159,7 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage mapId: Int, context: TaskContext): ShuffleWriter[K, V] = { handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) val env = SparkEnv.get new UnsafeShuffleWriter( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 390c136df79b..24a0b5220695 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -127,7 +127,7 @@ private[v1] object AllStagesResource { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attempt, + attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala similarity index 79% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala rename to core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala index 568b7ac2c598..f6e46ae9a481 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.storage -/** - * Package containing expressions that are specific to Spark runtime. - */ -package object expressions +import org.apache.spark.SparkException + +private[spark] +case class BlockFetchException(messages: String, throwable: Throwable) + extends SparkException(messages, throwable) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1beafa177144..d31aa68eb695 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -23,6 +23,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import scala.util.Random import sun.nio.ch.DirectBuffer @@ -93,8 +94,17 @@ private[spark] class BlockManager( // Port used by the external shuffle service. 
In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. - private val externalShuffleServicePort = - Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + private val externalShuffleServicePort = { + val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + if (tmpPort == 0) { + // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds + // an open port. But we still need to tell our spark apps the right port to use. So + // only if the yarn config has the port set to 0, we prefer the value in the spark config + conf.get("spark.shuffle.service.port").toInt + } else { + tmpPort + } + } // Check that we're not using external shuffle service with consolidated shuffle files. if (externalShuffleServiceEnabled @@ -191,6 +201,7 @@ private[spark] class BlockManager( executorId, blockTransferService.hostName, blockTransferService.port) shuffleServerId = if (externalShuffleServiceEnabled) { + logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) } else { blockManagerId @@ -222,7 +233,7 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) } @@ -590,10 +601,26 @@ private[spark] class BlockManager( private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { require(blockId != null, "BlockId is null") val locations = Random.shuffle(master.getLocations(blockId)) + var numFetchFailures = 0 for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + val data = try { + blockTransferService.fetchBlockSync( + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + } catch { + case NonFatal(e) => + numFetchFailures += 1 + if (numFetchFailures == locations.size) { + // An exception is thrown while fetching this block from all locations + throw new BlockFetchException(s"Failed to fetch block from" + + s" ${locations.size} locations. 
Most recent failure cause:", e) + } else { + // This location failed, so we retry fetch from a different one by returning null here + logWarning(s"Failed to fetch remote block $blockId " + + s"from $loc (failed attempt $numFetchFailures)", e) + null + } + } if (data != null) { if (asBlockResult) { @@ -648,7 +675,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7cdae22b0e25..f45bff34d4db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -33,7 +33,7 @@ class BlockManagerMaster( isDriver: Boolean) extends Logging { - val timeout = RpcUtils.askTimeout(conf) + val timeout = RpcUtils.askRpcTimeout(conf) /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { @@ -69,8 +69,9 @@ class BlockManagerMaster( } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - driverEndpoint.askWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { + driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]]( + GetLocationsMultipleBlockIds(blockIds)) } /** @@ -103,10 +104,10 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -115,10 +116,10 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -129,10 +130,10 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -176,8 +177,8 @@ class BlockManagerMaster( CanBuildFrom[Iterable[Future[Option[BlockStatus]]], Option[BlockStatus], Iterable[Option[BlockStatus]]]] - val blockStatus = Await.result( - Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread), timeout) + val blockStatus = timeout.awaitResult( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread)) if 
(blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } @@ -199,7 +200,7 @@ class BlockManagerMaster( askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) - Await.result(future, timeout) + timeout.awaitResult(future) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 68ed9096731c..7db6035553ae 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} @@ -60,10 +60,11 @@ class BlockManagerMasterEndpoint( register(blockManagerId, maxMemSize, slaveEndpoint) context.reply(true) - case UpdateBlockInfo( + case _updateBlockInfo @ UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => context.reply(updateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) case GetLocations(blockId) => context.reply(getLocations(blockId)) @@ -132,7 +133,7 @@ class BlockManagerMasterEndpoint( // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. 
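
Aside: with the master endpoint above now forwarding every `UpdateBlockInfo` message to the listener bus as a `SparkListenerBlockUpdated` event, any `SparkListener` can observe block-level changes. A minimal sketch of such a consumer, mirroring what the new `BlockStatusListener` below does; the class name and counter are illustrative, not part of this patch:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}
import org.apache.spark.storage.StreamBlockId

// Counts updates to stream blocks; register with SparkContext#addSparkListener.
class StreamBlockUpdateCounter extends SparkListener {
  @volatile private var updates = 0L

  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    if (blockUpdated.blockUpdatedInfo.blockId.isInstanceOf[StreamBlockId]) {
      updates += 1
    }
  }

  def totalUpdates: Long = updates
}
```
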
- val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -241,7 +242,7 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks) + new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) }.toArray } @@ -291,7 +292,7 @@ class BlockManagerMasterEndpoint( if (askSlaves) { info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { - Future { info.blocks.keys.filter(filter).toSeq } + Future { info.blocks.asScala.keys.filter(filter).toSeq } } future } @@ -371,7 +372,8 @@ class BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + private def getLocationsMultipleBlockIds( + blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala new file mode 100644 index 000000000000..2789e25b8d3a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +private[spark] case class BlockUIData( + blockId: BlockId, + location: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +/** + * The aggregated status of stream blocks in an executor + */ +private[spark] case class ExecutorStreamBlockStatus( + executorId: String, + location: String, + blocks: Seq[BlockUIData]) { + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum + + def numStreamBlocks: Int = blocks.size + +} + +private[spark] class BlockStatusListener extends SparkListener { + + private val blockManagers = + new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + val blockId = blockUpdated.blockUpdatedInfo.blockId + if (!blockId.isInstanceOf[StreamBlockId]) { + // Now we only monitor StreamBlocks + return + } + val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize + + synchronized { + // Drop the update info if the block manager is not registered + blockManagers.get(blockManagerId).foreach { blocksInBlockManager => + if (storageLevel.isValid) { + blocksInBlockManager.put(blockId, + BlockUIData( + blockId, + blockManagerId.hostPort, + storageLevel, + memSize, + diskSize, + externalBlockStoreSize) + ) + } else { + // If isValid is not true, it means we should drop the block. + blocksInBlockManager -= blockId + } + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + synchronized { + blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { + blockManagers -= blockManagerRemoved.blockManagerId + } + + def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { + blockManagers.map { case (blockManagerId, blocks) => + ExecutorStreamBlockStatus( + blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) + }.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala new file mode 100644 index 000000000000..a5790e4454a8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo + +/** + * :: DeveloperApi :: + * Stores information about a block status in a block manager. + */ +@DeveloperApi +case class BlockUpdatedInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +private[spark] object BlockUpdatedInfo { + + private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = { + BlockUpdatedInfo( + updateBlockInfo.blockManagerId, + updateBlockInfo.blockId, + updateBlockInfo.storageLevel, + updateBlockInfo.memSize, + updateBlockInfo.diskSize, + updateBlockInfo.externalBlockStoreSize) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 91ef86389a0c..f7e84a2c2e14 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -22,7 +22,7 @@ import java.io.{IOException, File} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -124,8 +124,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. + */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") logInfo(s"Created local directory at $localDir") @@ -139,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } @@ -149,7 +154,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: Exception => logError(s"Exception while removing shutdown hook.", e) @@ -159,11 +164,15 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. 
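
Aside: the cleanup guard below only matters when the external shuffle service owns the executor's block files. A hedged sketch of the two settings this code path keys off (the values are illustrative):

```scala
import org.apache.spark.SparkConf

// With the external shuffle service enabled, executors skip deleting their local
// "blockmgr-*" directories on shutdown so the service can keep serving shuffle files.
// In the YARN test setup described earlier, the YARN-side config sets the port to 0 so
// YARN picks a free one, and executors then fall back to the value in the Spark config.
val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.shuffle.service.port", "7337")
```
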
- if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { + // Also blockManagerId could be null if block manager is not initialized properly. + if (!blockManager.externalShuffleServiceEnabled || + (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) { + Utils.deleteRecursively(localDir) + } } catch { case e: Exception => logError(s"Exception while deleting local spark dir: $localDir", e) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala rename to core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 7eeabd1e0489..49d9154f95a5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils /** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to + * revert partial writes. * - * This interface does not support concurrent writes. Also, once the writer has - * been opened, it cannot be reopened again. - */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { - - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - - /** - * Flush the partial writes and commit them as a single atomic block. - */ - def commitAndClose(): Unit - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. This method will not throw, though it may be - * unsuccessful in truncating written data. - */ - def revertPartialWritesAndClose() - - /** - * Writes a key-value pair. - */ - def write(key: Any, value: Any) - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. - */ - def recordWritten() - - /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment -} - -/** - * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. */ private[spark] class DiskBlockObjectWriter( - blockId: BlockId, + val blockId: BlockId, file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, - // These write metrics concurrently shared with other active BlockObjectWriter's who + // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. 
All updates must be relative. writeMetrics: ShuffleWriteMetrics) - extends BlockObjectWriter(blockId) - with Logging -{ + extends OutputStream + with Logging { /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 - override def open(): BlockObjectWriter = { + def open(): DiskBlockObjectWriter = { if (hasBeenClosed) { throw new IllegalStateException("Writer already closed. Cannot be reopened.") } @@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter( } } - override def isOpen: Boolean = objOut != null + def isOpen: Boolean = objOut != null - override def commitAndClose(): Unit = { + /** + * Flush the partial writes and commit them as a single atomic block. + */ + def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. @@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter( commitAndCloseHasBeenCalled = true } - // Discard current writes. We do this by flushing the outstanding writes and then - // truncating the file to its initial position. - override def revertPartialWritesAndClose() { + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + */ + def revertPartialWritesAndClose() { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. try { if (initialized) { writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) @@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(key: Any, value: Any) { + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { if (!initialized) { open() } @@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter( bs.write(kvBytes, offs, len) } - override def recordWritten(): Unit = { + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + def recordWritten(): Unit = { numRecordsWritten += 1 writeMetrics.incShuffleRecordsWritten(1) @@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter( } } - override def fileSegment(): FileSegment = { + /** + * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. 
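
Aside: with the abstract `BlockObjectWriter` gone, callers program directly against `DiskBlockObjectWriter`. A hedged sketch of the intended lifecycle, assuming `blockManager`, `blockId`, `file`, `serializerInstance` and a `records` iterator of key-value pairs are already in scope:

```scala
import org.apache.spark.executor.ShuffleWriteMetrics

// The writer opens lazily on first write, commits atomically, or reverts on failure.
val writeMetrics = new ShuffleWriteMetrics()
val writer = blockManager.getDiskWriter(blockId, file, serializerInstance, 4096, writeMetrics)
try {
  records.foreach { case (k, v) => writer.write(k, v) }
  writer.commitAndClose()
  val segment = writer.fileSegment() // only valid after commitAndClose()
} catch {
  case t: Throwable =>
    // Truncates the file back to the position the writer started at.
    writer.revertPartialWritesAndClose()
    throw t
}
```
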
+ */ + def fileSegment(): FileSegment = { if (!commitAndCloseHasBeenCalled) { throw new IllegalStateException( "fileSegment() is only valid after commitAndClose() has been called") diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index 291394ed3481..db965d54bafd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) try { - val instance = Class.forName(clsName) + val instance = Utils.classForName(clsName) .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index ed609772e697..6f27f00307f8 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.TaskContext import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -43,11 +44,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object - // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a thread + // Pending unroll memory refers to the intermediate memory occupied by a task // after the unroll but before the actual putting of the block in the cache. // This chunk of memory is expected to be released *as soon as* we finish // caching the corresponding block as opposed to until after the task finishes. @@ -250,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var elementsUnrolled = 0 // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true - // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. + // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. 
val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 - // Memory currently reserved by this thread for this particular unrolling operation + // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this thread, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisThread + // Previous unroll memory held by this task, for releasing later (only at the very end) + val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { - if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + if (!reserveUnrollMemoryForThisTask(amountToRequest)) { // If the first request is not granted, try again after ensuring free space // If there is still not enough space, give up and drop the partition val spaceToEnsure = maxUnrollMemory - currentUnrollMemory @@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val result = ensureFreeSpace(blockId, spaceToEnsure) droppedBlocks ++= result.droppedBlocks } - keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) } } // New threshold is currentSize * memoryGrowthFactor @@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // later when the task finishes. if (keepUnrolling) { accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) - reservePendingUnrollMemoryForThisThread(amountToRelease) + val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved + releaseUnrollMemoryForThisTask(amountToRelease) + reservePendingUnrollMemoryForThisTask(amountToRelease) } } } @@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisThread() + releasePendingUnrollMemoryForThisTask() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -427,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Take into account the amount of memory currently occupied by unrolling blocks // and minus the pending unroll memory for that block on current thread. 
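
Aside: the rename from `*ForThisThread` to `*ForThisTask` reflects that unroll accounting is now keyed by `TaskContext.get().taskAttemptId()` rather than the JVM thread id. A hedged usage sketch, assuming `memoryStore` is the executor's `MemoryStore`:

```scala
// Reserve unroll memory up front and always release it, so the per-task bookkeeping
// (unrollMemoryMap) cannot leak if unrolling fails part way through.
val bytesNeeded = 1024L * 1024L
if (memoryStore.reserveUnrollMemoryForThisTask(bytesNeeded)) {
  try {
    // ... unroll the block into a SizeTrackingVector ...
  } finally {
    memoryStore.releaseUnrollMemoryForThisTask(bytesNeeded)
  }
} else {
  // Not enough free memory; fall back to dropping the block to disk or recomputing it.
}
```
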
- val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(threadId, 0L) + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -455,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping + // This should never be null as only one task should be dropping // blocks and removing entries. However the check is still here for // future safety. if (entry != null) { @@ -482,79 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) entries.synchronized { entries.containsKey(blockId) } } + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Reserve additional memory for unrolling blocks used by this thread. + * Reserve additional memory for unrolling blocks used by this task. * Return whether the request is granted. */ - def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { - val threadId = Thread.currentThread().getId - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + val taskAttemptId = currentTaskAttemptId() + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } granted } } /** - * Release memory used by this thread for unrolling blocks. - * If the amount is not specified, remove the current thread's allocation altogether. + * Release memory used by this task for unrolling blocks. + * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { - val threadId = Thread.currentThread().getId + def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { if (memory < 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap.remove(taskAttemptId) } else { - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory - // If this thread claims no more unroll memory, release it completely - if (unrollMemoryMap(threadId) <= 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory + // If this task claims no more unroll memory, release it completely + if (unrollMemoryMap(taskAttemptId) <= 0) { + unrollMemoryMap.remove(taskAttemptId) } } } } /** - * Reserve the unroll memory of current unroll successful block used by this thread + * Reserve the unroll memory of current unroll successful block used by this task * until actually put the block into memory entry. 
*/ - def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = { - val threadId = Thread.currentThread().getId + def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } } /** - * Release pending unroll memory of current unroll successful block used by this thread + * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisThread(): Unit = { - val threadId = Thread.currentThread().getId + def releasePendingUnrollMemoryForThisTask(): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap.remove(threadId) + pendingUnrollMemoryMap.remove(taskAttemptId) } } /** - * Return the amount of memory currently occupied for unrolling blocks across all threads. + * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** - * Return the amount of memory currently occupied for unrolling blocks by this thread. + * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { - unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** - * Return the number of threads currently unrolling blocks. + * Return the number of tasks currently unrolling blocks. */ - def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. @@ -566,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo( s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." 
) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9..0d0448feb5b0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,24 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.{Failure, Try} +import scala.util.control.NonFatal -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.{Logging, SparkException, TaskContext} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -53,9 +53,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -83,7 +82,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -102,9 +101,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. 
If isZombie is true, the callback interface will no @@ -114,23 +111,29 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } } @@ -155,7 +158,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } @@ -164,7 +167,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), e)) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } } ) @@ -236,12 +239,12 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, e)) + results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) return } } @@ -257,10 +260,7 @@ final class ShuffleBlockFetcherIterator( fetchRequests ++= Utils.randomize(remoteRequests) // Send out initial requests for blocks, up to our maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } + fetchUpToMaxBytes() val numFetches = remoteRequests.size - fetchRequests.size logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) @@ -272,7 +272,15 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. 
However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -281,38 +289,78 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size case _ => } + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + + result match { + case FailureFetchResult(blockId, address, e) => + throwFetchFailedException(blockId, address, e) + + case SuccessFetchResult(blockId, address, _, buf) => + try { + (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) + } catch { + case NonFatal(t) => + throwFetchFailedException(blockId, address, t) + } + } + } + + private def fetchUpToMaxBytes(): Unit = { // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } + } - val iteratorTry: Try[Iterator[Any]] = result match { - case FailureFetchResult(_, e) => - Failure(e) - case SuccessFetchResult(blockId, _, buf) => - // There is a chance that createInputStream can fail (e.g. fetching a local file that does - // not exist, SPARK-4085). In that case, we should propagate the right exception so - // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. 
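
Aside: because the iterator now yields `(BlockId, InputStream)` pairs instead of deserialized records, the read path wraps and deserializes each stream itself. A hedged consumer sketch, assuming `fetcherIterator`, `blockManager` and `serializerInstance` are in scope:

```scala
// Each InputStream is a BufferReleasingInputStream: closing it (or the task completing)
// releases the underlying ManagedBuffer, so memory is freed as early as possible.
val records = fetcherIterator.flatMap { case (blockId, inputStream) =>
  val decompressed = blockManager.wrapForCompression(blockId, inputStream)
  serializerInstance.deserializeStream(decompressed).asKeyValueIterator
}
```
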
- currentResult = null - buf.release() - }) - } + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", e) } - - (result.blockId, iteratorTry) } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { @@ -332,16 +380,22 @@ object ShuffleBlockFetcherIterator { */ private[storage] sealed trait FetchResult { val blockId: BlockId + val address: BlockManagerId } /** * Result of a fetch from a remote block successfully. * @param blockId block id + * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. * @param buf [[ManagedBuffer]] for the content. */ - private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + private[storage] case class SuccessFetchResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer) extends FetchResult { require(buf != null) require(size >= 0) @@ -350,8 +404,12 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block unsuccessfully. 
* @param blockId block id + * @param address BlockManager that the block was attempted to be fetched from * @param e the failure exception */ - private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + private[storage] case class FailureFetchResult( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index b53c86e89a27..22878783fca6 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -27,11 +27,12 @@ import scala.util.control.NonFatal import com.google.common.io.ByteStreams import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} +import tachyon.conf.TachyonConf import tachyon.TachyonURI -import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -60,7 +61,11 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log rootDirs = s"$storeDir/$appFolderName/$executorId" master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998") - client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null + client = if (master != null && master != "") { + TachyonFS.get(new TachyonURI(master), new TachyonConf()) + } else { + null + } // original implementation call System.exit, we change it to run without extblkstore support if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -75,7 +80,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -235,7 +240,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { Utils.deleteRecursively(tachyonDir, client) } } catch { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 06e616220c70..b796a44fe01a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -59,7 +59,17 @@ private[spark] object JettyUtils extends Logging { def createServlet[T <% AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager): HttpServlet = { + securityMgr: SecurityManager, + conf: SparkConf): HttpServlet = { + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the + // same origin, but allow framing for a specific named URI. 
+ // Example: spark.ui.allowFramingFrom = https://example.com/ + val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") + val xFrameOptionsValue = + allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") + new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { @@ -68,7 +78,10 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.setHeader("X-Frame-Options", xFrameOptionsValue) + // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) + // scalastyle:on println } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") @@ -95,8 +108,9 @@ private[spark] object JettyUtils extends Logging { path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, + conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr), basePath) + createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ @@ -104,7 +118,11 @@ private[spark] object JettyUtils extends Logging { path: String, servlet: HttpServlet, basePath: String): ServletContextHandler = { - val prefixedPath = attachPrefix(basePath, path) + val prefixedPath = if (basePath == "" && path == "/") { + path + } else { + (basePath + path).stripSuffix("/") + } val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) contextHandler.setContextPath(prefixedPath) @@ -119,7 +137,7 @@ private[spark] object JettyUtils extends Logging { beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", httpMethods: Set[String] = Set("GET")): ServletContextHandler = { - val prefixedDestPath = attachPrefix(basePath, destPath) + val prefixedDestPath = basePath + destPath val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { if (httpMethods.contains("GET")) { @@ -210,10 +228,16 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) @@ -238,11 +262,6 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } - - /** Attach a prefix to the given path, but avoid returning an empty path */ - private def attachPrefix(basePath: String, relativePath: String): String = { - if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/") - } } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala 
b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala new file mode 100644 index 000000000000..6e2375477a68 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.{Node, Unparsed} + +/** + * A data source that provides data for a page. + * + * @param pageSize the number of rows in a page + */ +private[ui] abstract class PagedDataSource[T](val pageSize: Int) { + + if (pageSize <= 0) { + throw new IllegalArgumentException("Page size must be positive") + } + + /** + * Return the size of all data. + */ + protected def dataSize: Int + + /** + * Slice a range of data. + */ + protected def sliceData(from: Int, to: Int): Seq[T] + + /** + * Slice the data for this page + */ + def pageData(page: Int): PageData[T] = { + val totalPages = (dataSize + pageSize - 1) / pageSize + if (page <= 0 || page > totalPages) { + throw new IndexOutOfBoundsException( + s"Page $page is out of range. Please select a page number between 1 and $totalPages.") + } + val from = (page - 1) * pageSize + val to = dataSize.min(page * pageSize) + PageData(totalPages, sliceData(from, to)) + } + +} + +/** + * The data returned by `PagedDataSource.pageData`, including the page number, the number of total + * pages and the data in this page. + */ +private[ui] case class PageData[T](totalPage: Int, data: Seq[T]) + +/** + * A paged table that will generate a HTML table for a specified page and also the page navigation. + */ +private[ui] trait PagedTable[T] { + + def tableId: String + + def tableCssClass: String + + def dataSource: PagedDataSource[T] + + def headers: Seq[Node] + + def row(t: T): Seq[Node] + + def table(page: Int): Seq[Node] = { + val _dataSource = dataSource + try { + val PageData(totalPages, data) = _dataSource.pageData(page) +

+      <div>
+        {pageNavigation(page, _dataSource.pageSize, totalPages)}
+        <table class={tableCssClass} id={tableId}>
+          {headers}
+          <tbody>
+            {data.map(row)}
+          </tbody>
+        </table>
+      </div>
+    } catch {
+      case e: IndexOutOfBoundsException =>
+        val PageData(totalPages, _) = _dataSource.pageData(1)
+        <div>
+          {pageNavigation(1, _dataSource.pageSize, totalPages)}
+          <div class="alert alert-error">{e.getMessage}</div>
+        </div>
+    }
+  }
+
+  /**
+   * Return a page navigation.
+   * <ul>
+   *   <li>If the totalPages is 1, the page navigation will be empty</li>
+   *   <li>
+   *     If the totalPages is more than 1, it will create a page navigation including a group of
+   *     page numbers and a form to submit the page number.
+   *   </li>
+   * </ul>
+ * + * Here are some examples of the page navigation: + * {{{ + * << < 11 12 13* 14 15 16 17 18 19 20 > >> + * + * This is the first group, so "<<" is hidden. + * < 1 2* 3 4 5 6 7 8 9 10 > >> + * + * This is the first group and the first page, so "<<" and "<" are hidden. + * 1* 2 3 4 5 6 7 8 9 10 > >> + * + * Assume totalPages is 19. This is the last group, so ">>" is hidden. + * << < 11 12 13* 14 15 16 17 18 19 > + * + * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden. + * << < 11 12 13 14 15 16 17 18 19* + * + * * means the current page number + * << means jumping to the first page of the previous group. + * < means jumping to the previous page. + * >> means jumping to the first page of the next group. + * > means jumping to the next page. + * }}} + */ + private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = { + if (totalPages == 1) { + Nil + } else { + // A group includes all page numbers will be shown in the page navigation. + // The size of group is 10 means there are 10 page numbers will be shown. + // The first group is 1 to 10, the second is 2 to 20, and so on + val groupSize = 10 + val firstGroup = 0 + val lastGroup = (totalPages - 1) / groupSize + val currentGroup = (page - 1) / groupSize + val startPage = currentGroup * groupSize + 1 + val endPage = totalPages.min(startPage + groupSize - 1) + val pageTags = (startPage to endPage).map { p => + if (p == page) { + // The current page should be disabled so that it cannot be clicked. +
+          <li class="disabled"><a href="#">{p}</a></li>
+        } else {
+          <li><a href={Unparsed(pageLink(p))}>{p}</a></li>
+        }
+      }
+      val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction
+      // When clicking the "Go" button, it will call this javascript method and then call
+      // "goButtonJsFuncName"
+      val formJs =
+        s"""$$(function(){
+          |  $$( "#form-$tableId-page" ).submit(function(event) {
+          |    var page = $$("#form-$tableId-page-no").val()
+          |    var pageSize = $$("#form-$tableId-page-size").val()
+          |    pageSize = pageSize ? pageSize: 100;
+          |    if (page != "") {
+          |      ${goButtonJsFuncName}(page, pageSize);
+          |    }
+          |    event.preventDefault();
+          |  });
+          |});
+        """.stripMargin
+
+      <div>
+        <!-- {goButtonJsFunc} script, the "<< < [page numbers] > >>" link list built from
+             {pageTags}, and the page-number/page-size "Go" form (original markup lost in
+             this rendering of the diff) -->
+      </div>
    + } + } + + /** + * Return a link to jump to a page. + */ + def pageLink(page: Int): String + + /** + * Only the implementation knows how to create the url with a page number and the page size, so we + * leave this one to the implementation. The implementation should create a JavaScript method that + * accepts a page number along with the page size and jumps to the page. The return value is this + * method name and its JavaScript codes. + */ + def goButtonJavascriptFunction: (String, String) +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 3788916cf39b..d8b90568b7b9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -64,11 +64,11 @@ private[spark] class SparkUI private ( attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 063e2a1f8b18..cb122eaed83d 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -35,6 +35,10 @@ private[spark] object ToolTips { val OUTPUT = "Bytes and records written to Hadoop." + val STORAGE_MEMORY = + "Memory used / total available memory for storage of data " + + "like RDD partitions cached in memory. " + val SHUFFLE_WRITE = "Bytes and records written to disk in order to be read by a shuffle in a future stage." @@ -58,6 +62,13 @@ private[spark] object ToolTips { """Time that the executor spent paused for Java garbage collection while the task was running.""" + val PEAK_EXECUTION_MEMORY = + """Execution memory refers to the memory used by internal data structures created during + shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator + should be approximately the sum of the peak sizes across all such data structures created + in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and + external sort.""" + val JOB_TIMELINE = """Shows when jobs started and ended and when executors joined or left. Drag to scroll. Click Enable Zooming and use mouse wheel to zoom in/out.""" diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 789803951920..f2da41772410 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. 
*/ private[spark] object UIUtils extends Logging { - val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. @@ -267,9 +267,17 @@ private[spark] object UIUtils extends Logging { fixedWidth: Boolean = false, id: Option[String] = None, headerClasses: Seq[String] = Seq.empty, - stripeRowsWithCss: Boolean = true): Seq[Node] = { + stripeRowsWithCss: Boolean = true, + sortable: Boolean = true): Seq[Node] = { - val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + val listingTableClass = { + val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + if (sortable) { + _tableClass + " sortable" + } else { + _tableClass + } + } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" @@ -344,7 +352,8 @@ private[spark] object UIUtils extends Logging { */ private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = {
    - + diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index ba03acdb38cc..5a8c2914314c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -38,9 +38,11 @@ private[spark] object UIWorkloadGenerator { def main(args: Array[String]) { if (args.length < 3) { + // scalastyle:off println println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") + // scalastyle:on println System.exit(1) } @@ -96,6 +98,7 @@ private[spark] object UIWorkloadGenerator { for ((desc, job) <- jobs) { new Thread { override def run() { + // scalastyle:off println try { setProperties(desc) job() @@ -106,6 +109,7 @@ private[spark] object UIWorkloadGenerator { } finally { barrier.release() } + // scalastyle:on println } }.start Thread.sleep(INTER_JOB_WAIT_MS) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 2c84e4485996..81a121fd441b 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -76,9 +76,9 @@ private[spark] abstract class WebUI( def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath) + (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + (request: HttpServletRequest) => page.renderJson(request), securityManager, conf, basePath) attachHandler(renderHandler) attachHandler(renderJsonHandler) pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) @@ -107,6 +107,25 @@ private[spark] abstract class WebUI( } } + /** + * Add a handler for static content. + * + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. + */ + def addStaticHandler(resourceBase: String, path: String): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + } + + /** + * Remove a static content handler. + * + * @param path Path in UI to unmount. + */ + def removeStaticHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) + } + /** Initialize all components of the server. 
*/ def initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f0ae95bb8c81..b0a2cb4aa4d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -49,11 +49,29 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.map { thread => + val dumpRows = threadDump.sortWith { + case (threadTrace1, threadTrace2) => { + val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + } else { + v1 > v2 + } + } + }.map { thread => + val threadName = thread.threadName + val className = "accordion-heading " + { + if (threadName.contains("Executor task launch")) { + "executor-thread" + } else { + "non-executor-thread" + } + }
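The `sortWith` above pins "Executor task launch" threads to the top of the dump and orders the remaining threads case-insensitively by name. The same ordering can be written with a composite sort key; `ThreadTrace` below is a simplified stand-in for the real thread-dump type, so this is a sketch rather than the patch's code:

```scala
// Simplified stand-in for the thread dump entries rendered by this page.
case class ThreadTrace(threadName: String, stackTrace: String)

object ThreadDumpOrdering {
  // Executor task threads first, ties broken by case-insensitive thread name,
  // matching the sortWith above.
  def sortForDisplay(threads: Seq[ThreadTrace]): Seq[ThreadTrace] =
    threads.sortBy(t =>
      (!t.threadName.contains("Executor task launch"), t.threadName.toLowerCase))
}
```

Sorting on a `(Boolean, String)` key works because `false` sorts before `true`, so negating the "is an executor task thread" test floats those threads to the front.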
    -
    + @@ -229,54 +253,50 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - accumulables.values.toSeq) - - val taskHeadersAndCssClasses: Seq[(String, String)] = - Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (stageData.hasShuffleRead) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) - } else { - Nil - }} ++ - {if (stageData.hasShuffleWrite) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) - } else { - Nil - }} ++ - {if (stageData.hasBytesSpilled) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) - } else { - Nil - }} ++ - Seq(("Errors", "")) - - val unzipped = taskHeadersAndCssClasses.unzip + externalAccumulables.toSeq) val currentTime = System.currentTimeMillis() - val taskTable = UIUtils.listingTable( - unzipped._1, - taskRow( + val (taskTable, taskTableHTML) = try { + val _taskTable = new TaskPagedTable( + parent.conf, + UIUtils.prependBaseUri(parent.basePath) + + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + tasks, hasAccumulators, stageData.hasInput, stageData.hasOutput, stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled, - currentTime), - tasks, - headerClasses = unzipped._2) + currentTime, + pageSize = taskPageSize, + sortColumn = taskSortColumn, + desc = taskSortDesc + ) + (_taskTable, _taskTable.table(taskPage)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + (null,
<div class="alert alert-error">{e.getMessage}</div>
    ) + } + + val jsForScrollingDownToTaskTable = + + + val taskIdsInPage = if (taskTable == null) Set.empty[Long] + else taskTable.dataSource.slicedTaskIds + // Excludes tasks which failed and have incomplete metrics val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) @@ -287,12 +307,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else { def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = Distribution(data).get.getQuantiles() - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { getDistributionQuantiles(times).map { millis => {UIUtils.formatDuration(millis.toLong)} } } + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { + getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + } val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorDeserializeTime.toDouble @@ -332,7 +354,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - getGettingResultTime(info).toDouble + getGettingResultTime(info, currentTime).toDouble } val gettingResultQuantiles = @@ -342,20 +364,33 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) + + val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => + info.accumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.update.getOrElse("0").toLong } + .getOrElse(0L) + .toDouble + } + val peakExecutionMemoryQuantiles = { + + + Peak Execution Memory + + +: getFormattedSizeQuantiles(peakExecutionMemory) + } + // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - getSchedulerDelay(info, metrics.get).toDouble + getSchedulerDelay(info, metrics.get, currentTime).toDouble } val schedulerDelayTitle = Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator @@ -459,6 +494,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, + if (displayPeakExecutionMemory) { + + {peakExecutionMemoryQuantiles} + + } else { + Nil + }, if (stageData.hasInput) {inputQuantiles} else Nil, if (stageData.hasOutput) {outputQuantiles} else Nil, if (stageData.hasShuffleRead) { @@ -492,19 +534,22 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = - if (accumulables.size > 0) {

<h4>Accumulators</h4> ++ accumulableTable } else Seq()
+      if (hasAccumulators) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()

       val content = summary ++ dagViz ++ maybeExpandDagViz ++ showAdditionalMetrics ++
-      makeTimeline(stageData.taskData.values.toSeq, currentTime) ++
+      makeTimeline(
+        // Only show the tasks in the table
+        stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)),
+        currentTime) ++
       <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
       <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
       <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++ maybeAccumulableTable ++
-      <h4>Tasks</h4> ++ taskTable
+      <h4 id="tasks-section">Tasks</h4>
    ++ taskTableHTML ++ jsForScrollingDownToTaskTable UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } } @@ -537,20 +582,27 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { (metricsOpt.flatMap(_.shuffleWriteMetrics .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) - val executorComputingTime = metricsOpt.map(_.executorRunTime).getOrElse(0L) - - shuffleReadTime - shuffleWriteTime - val executorComputingTimeProportion = toProportion(executorComputingTime) + val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskUIData.taskInfo) + val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = totalExecutionTime - - (executorComputingTime + shuffleReadTime + shuffleWriteTime + - serializationTime + deserializationTime + gettingResultTime) - val schedulerDelayProportion = - (100 - executorComputingTimeProportion - shuffleReadTimeProportion - + val schedulerDelay = + metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelayProportion = toProportion(schedulerDelay) + + val executorOverhead = serializationTime + deserializationTime + val executorRunTime = if (taskInfo.running) { + totalExecutionTime - executorOverhead - gettingResultTime + } else { + metricsOpt.map(_.executorRunTime).getOrElse( + totalExecutionTime - executorOverhead - gettingResultTime) + } + val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime + val executorComputingTimeProportion = + (100 - schedulerDelayProportion - shuffleReadTimeProportion - shuffleWriteTimeProportion - serializationTimeProportion - deserializationTimeProportion - gettingResultTimeProportion) @@ -569,58 +621,66 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attempt + val attempt = taskInfo.attemptNumber + + val svgTag = + if (totalExecutionTime == 0) { + // SPARK-8705: Avoid invalid attribute error in JavaScript if execution time is 0 + """""" + } else { + s""" + | + | + | + | + | + | + |""".stripMargin + } val timelineObject = s""" - { - 'className': 'task task-assignment-timeline-object', - 'group': '$executorId', - 'content': '
    ' + - 'Status: ${taskInfo.status}
    ' + - 'Launch Time: ${UIUtils.formatDate(new Date(launchTime))}' + - '${ + |{ + |'className': 'task task-assignment-timeline-object', + |'group': '$executorId', + |'content': '
    + |Status: ${taskInfo.status}
    + |Launch Time: ${UIUtils.formatDate(new Date(launchTime))} + |${ if (!taskInfo.running) { s"""
    Finish Time: ${UIUtils.formatDate(new Date(finishTime))}""" } else { "" } - }' + - '
    Scheduler Delay: $schedulerDelay ms' + - '
    Task Deserialization Time: ${UIUtils.formatDuration(deserializationTime)}' + - '
    Shuffle Read Time: ${UIUtils.formatDuration(shuffleReadTime)}' + - '
    Executor Computing Time: ${UIUtils.formatDuration(executorComputingTime)}' + - '
    Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)}' + - '
    Result Serialization Time: ${UIUtils.formatDuration(serializationTime)}' + - '
    Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}">' + - '' + - '' + - '' + - '' + - '' + - '' + - '' + - '', - 'start': new Date($launchTime), - 'end': new Date($finishTime) - } - """ + } + |
    Scheduler Delay: $schedulerDelay ms + |
    Task Deserialization Time: ${UIUtils.formatDuration(deserializationTime)} + |
    Shuffle Read Time: ${UIUtils.formatDuration(shuffleReadTime)} + |
    Executor Computing Time: ${UIUtils.formatDuration(executorComputingTime)} + |
    Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)} + |
    Result Serialization Time: ${UIUtils.formatDuration(serializationTime)} + |
    Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}"> + |$svgTag', + |'start': new Date($launchTime), + |'end': new Date($finishTime) + |} + |""".stripMargin.replaceAll("""[\r\n]+""", " ") timelineObject }.mkString("[", ",", "]") @@ -664,162 +724,644 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } - def taskRow( - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, - currentTime: Long)(taskData: TaskUIData): Seq[Node] = { - taskData match { case TaskUIData(info, metrics, errorMessage) => - val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) - else metrics.map(_.executorRunTime).getOrElse(1L) - val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) - else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info) - - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} - - val maybeInput = metrics.flatMap(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.flatMap(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead - .map(_.fetchWaitTime.toString).getOrElse("") - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.shuffleRecordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) - val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 
1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.map(_.toString).getOrElse("") - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("") - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - - {info.index} - {info.taskId} - { - if (info.speculative) s"${info.attempt} (speculative)" else info.attempt.toString - } - {info.status} - {info.taskLocality} - {info.executorId} / {info.host} - {UIUtils.formatDate(new Date(info.launchTime))} - - {formatDuration} - - - {UIUtils.formatDuration(schedulerDelay.toLong)} - - - {UIUtils.formatDuration(taskDeserializationTime.toLong)} - - - {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} - - - {UIUtils.formatDuration(serializationTime)} - - - {UIUtils.formatDuration(gettingResultTime)} - - {if (hasAccumulators) { - - {Unparsed(accumulatorsReadable.mkString("
    "))} - - }} - {if (hasInput) { - - {s"$inputReadable / $inputRecords"} - - }} - {if (hasOutput) { - - {s"$outputReadable / $outputRecords"} - - }} +} + +private[ui] object StagePage { + private[ui] def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { + if (info.gettingResult) { + if (info.finished) { + info.finishTime - info.gettingResultTime + } else { + // The task is still fetching the result. + currentTime - info.gettingResultTime + } + } else { + 0L + } + } + + private[ui] def getSchedulerDelay( + info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = { + if (info.finished) { + val totalExecutionTime = info.finishTime - info.launchTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - + getGettingResultTime(info, currentTime)) + } else { + // The task is still running and the metrics like executorRunTime are not available. + 0L + } + } +} + +private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) + +private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) + +private[ui] case class TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable: Long, + shuffleReadBlockedTimeReadable: String, + shuffleReadSortable: Long, + shuffleReadReadable: String, + shuffleReadRemoteSortable: Long, + shuffleReadRemoteReadable: String) + +private[ui] case class TaskTableRowShuffleWriteData( + writeTimeSortable: Long, + writeTimeReadable: String, + shuffleWriteSortable: Long, + shuffleWriteReadable: String) + +private[ui] case class TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable: Long, + memoryBytesSpilledReadable: String, + diskBytesSpilledSortable: Long, + diskBytesSpilledReadable: String) + +/** + * Contains all data that needs for sorting and generating HTML. Using this one rather than + * TaskUIData to avoid creating duplicate contents during sorting the data. 
+ */ +private[ui] class TaskTableRowData( + val index: Int, + val taskId: Long, + val attempt: Int, + val speculative: Boolean, + val status: String, + val taskLocality: String, + val executorIdAndHost: String, + val launchTime: Long, + val duration: Long, + val formatDuration: String, + val schedulerDelay: Long, + val taskDeserializationTime: Long, + val gcTime: Long, + val serializationTime: Long, + val gettingResultTime: Long, + val peakExecutionMemoryUsed: Long, + val accumulators: Option[String], // HTML + val input: Option[TaskTableRowInputData], + val output: Option[TaskTableRowOutputData], + val shuffleRead: Option[TaskTableRowShuffleReadData], + val shuffleWrite: Option[TaskTableRowShuffleWriteData], + val bytesSpilled: Option[TaskTableRowBytesSpilledData], + val error: String) + +private[ui] class TaskDataSource( + tasks: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedDataSource[TaskTableRowData](pageSize) { + import StagePage._ + + // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table + // so that we can avoid creating duplicate contents during sorting the data + private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) + + private var _slicedTaskIds: Set[Long] = null + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { + val r = data.slice(from, to) + _slicedTaskIds = r.map(_.taskId).toSet + r + } + + def slicedTaskIds: Set[Long] = _slicedTaskIds + + private def taskRow(taskData: TaskUIData): TaskTableRowData = { + val TaskUIData(info, metrics, errorMessage) = taskData + val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) + else metrics.map(_.executorRunTime).getOrElse(1L) + val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) + else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) + val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) + val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) + val gettingResultTime = getGettingResultTime(info, currentTime) + + val (taskInternalAccumulables, taskExternalAccumulables) = + info.accumulables.partition(_.internal) + val externalAccumulableReadable = taskExternalAccumulables.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") + } + val peakExecutionMemoryUsed = taskInternalAccumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.update.getOrElse("0").toLong } + .getOrElse(0L) + + val maybeInput = metrics.flatMap(_.inputMetrics) + val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) + val inputReadable = maybeInput + .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") + .getOrElse("") + val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") + + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val outputRecords = 
maybeOutput.map(_.recordsWritten.toString).getOrElse("") + + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) + val shuffleReadBlockedTimeReadable = + maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") + + val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") + val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") + + val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L) + val shuffleWriteReadable = maybeShuffleWrite + .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + val shuffleWriteRecords = maybeShuffleWrite + .map(_.shuffleRecordsWritten.toString).getOrElse("") + + val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) + val writeTimeSortable = maybeWriteTime.getOrElse(0L) + val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => + if (ms == 0) "" else UIUtils.formatDuration(ms) + }.getOrElse("") + + val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) + val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) + val memoryBytesSpilledReadable = + maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) + val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) + val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val input = + if (hasInput) { + Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) + } else { + None + } + + val output = + if (hasOutput) { + Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) + } else { + None + } + + val shuffleRead = + if (hasShuffleRead) { + Some(TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable, + shuffleReadBlockedTimeReadable, + shuffleReadSortable, + s"$shuffleReadReadable / $shuffleReadRecords", + shuffleReadRemoteSortable, + shuffleReadRemoteReadable + )) + } else { + None + } + + val shuffleWrite = + if (hasShuffleWrite) { + Some(TaskTableRowShuffleWriteData( + writeTimeSortable, + writeTimeReadable, + shuffleWriteSortable, + s"$shuffleWriteReadable / $shuffleWriteRecords" + )) + } else { + None + } + + val bytesSpilled = + if (hasBytesSpilled) { + Some(TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable, + memoryBytesSpilledReadable, + diskBytesSpilledSortable, + diskBytesSpilledReadable + )) + } else { + None + } + + new TaskTableRowData( + info.index, + info.taskId, + info.attemptNumber, + info.speculative, + info.status, + info.taskLocality.toString, + s"${info.executorId} / ${info.host}", + info.launchTime, + duration, + formatDuration, + schedulerDelay, + taskDeserializationTime, + gcTime, + serializationTime, + gettingResultTime, + peakExecutionMemoryUsed, + if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, + input, + output, + shuffleRead, + shuffleWrite, + bytesSpilled, + errorMessage.getOrElse("")) + } + + /** + * Return Ordering according to sortColumn and desc + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { + val ordering = sortColumn match { + case "Index" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.index, y.index) + } + case "ID" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskId, y.taskId) + } + case "Attempt" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.attempt, y.attempt) + } + case "Status" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.status, y.status) + } + case "Locality Level" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.taskLocality, y.taskLocality) + } + case "Executor ID / Host" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost) + } + case "Launch Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.launchTime, y.launchTime) + } + case "Duration" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.duration, y.duration) + } + case "Scheduler Delay" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay) + } + case "Task Deserialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime) + } + case "GC Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gcTime, y.gcTime) + } + case "Result Serialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.serializationTime, y.serializationTime) + } + case "Getting Result Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) + } + case "Peak Execution Memory" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed) + } + case "Accumulators" => + if (hasAccumulators) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.accumulators.get, y.accumulators.get) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Accumulators because of no accumulators") + } + case "Input Size / Records" => + if (hasInput) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.input.get.inputSortable, 
y.input.get.inputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Input Size / Records because of no inputs") + } + case "Output Size / Records" => + if (hasOutput) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Output Size / Records because of no outputs") + } + // ShuffleRead + case "Shuffle Read Blocked Time" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable, + y.shuffleRead.get.shuffleReadBlockedTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") + } + case "Shuffle Read Size / Records" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable, + y.shuffleRead.get.shuffleReadSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") + } + case "Shuffle Remote Reads" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable, + y.shuffleRead.get.shuffleReadRemoteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Remote Reads because of no shuffle reads") + } + // ShuffleWrite + case "Write Time" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable, + y.shuffleWrite.get.writeTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Write Time because of no shuffle writes") + } + case "Shuffle Write Size / Records" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable, + y.shuffleWrite.get.shuffleWriteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") + } + // BytesSpilled + case "Shuffle Spill (Memory)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable, + y.bytesSpilled.get.memoryBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Memory) because of no spills") + } + case "Shuffle Spill (Disk)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable, + y.bytesSpilled.get.diskBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Disk) because of no spills") + } + case "Errors" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + 
Ordering.String.compare(x.error, y.error) + } + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } + +} + +private[ui] class TaskPagedTable( + conf: SparkConf, + basePath: String, + data: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedTable[TaskTableRowData] { + + // We only track peak memory used for unsafe operators + private val displayPeakExecutionMemory = conf.getBoolean("spark.sql.unsafe.enabled", true) + + override def tableId: String = "task-table" + + override def tableCssClass: String = "table table-bordered table-condensed table-striped" + + override val dataSource: TaskDataSource = new TaskDataSource( + data, + hasAccumulators, + hasInput, + hasOutput, + hasShuffleRead, + hasShuffleWrite, + hasBytesSpilled, + currentTime, + pageSize, + sortColumn, + desc) + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" + + s"&task.pageSize=${pageSize}" + } + + override def goButtonJavascriptFunction: (String, String) = { + val jsFuncName = "goToTaskPage" + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + val jsFunc = s""" + |currentTaskPageSize = ${pageSize} + |function goToTaskPage(page, pageSize) { + | // Set page to 1 if the page size changes + | page = pageSize == currentTaskPageSize ? page : 1; + | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" + + | "&task.page=" + page + "&task.pageSize=" + pageSize; + | window.location.href = url; + |} + """.stripMargin + (jsFuncName, jsFunc) + } + + def headers: Seq[Node] = { + val taskHeadersAndCssClasses: Seq[(String, String)] = + Seq( + ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), + ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + ("GC Time", ""), + ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + { + if (displayPeakExecutionMemory) { + Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) + } else { + Nil + } + } ++ + {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ {if (hasShuffleRead) { - - {shuffleReadBlockedTimeReadable} - - - {s"$shuffleReadReadable / $shuffleReadRecords"} - - - {shuffleReadRemoteReadable} - - }} + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read Size / Records", ""), + ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + } else { + Nil + }} ++ {if (hasShuffleWrite) { - - {writeTimeReadable} - - - {s"$shuffleWriteReadable / $shuffleWriteRecords"} - - }} + Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + } else { + Nil + }} ++ {if (hasBytesSpilled) { - - {memoryBytesSpilledReadable} - - - {diskBytesSpilledReadable} - - }} - {errorMessageCell(errorMessage)} - + Seq(("Shuffle 
Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + } else { + Nil + }} ++ + Seq(("Errors", "")) + + if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { + throw new IllegalArgumentException(s"Unknown column: $sortColumn") + } + + val headerRow: Seq[Node] = { + taskHeadersAndCssClasses.map { case (header, cssClass) => + if (header == sortColumn) { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" + + s"&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + val arrow = if (desc) "▾" else "▴" // UP or DOWN + + {header} +  {Unparsed(arrow)} + + } else { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + + {header} + + } + } } + {headerRow} } - private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { - val error = errorMessage.getOrElse("") + def row(task: TaskTableRowData): Seq[Node] = { + + {task.index} + {task.taskId} + {if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString} + {task.status} + {task.taskLocality} + {task.executorIdAndHost} + {UIUtils.formatDate(new Date(task.launchTime))} + {task.formatDuration} + + {UIUtils.formatDuration(task.schedulerDelay)} + + + {UIUtils.formatDuration(task.taskDeserializationTime)} + + + {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + + + {UIUtils.formatDuration(task.serializationTime)} + + + {UIUtils.formatDuration(task.gettingResultTime)} + + {if (displayPeakExecutionMemory) { + + {Utils.bytesToString(task.peakExecutionMemoryUsed)} + + }} + {if (task.accumulators.nonEmpty) { + {Unparsed(task.accumulators.get)} + }} + {if (task.input.nonEmpty) { + {task.input.get.inputReadable} + }} + {if (task.output.nonEmpty) { + {task.output.get.outputReadable} + }} + {if (task.shuffleRead.nonEmpty) { + + {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + + {task.shuffleRead.get.shuffleReadReadable} + + {task.shuffleRead.get.shuffleReadRemoteReadable} + + }} + {if (task.shuffleWrite.nonEmpty) { + {task.shuffleWrite.get.writeTimeReadable} + {task.shuffleWrite.get.shuffleWriteReadable} + }} + {if (task.bytesSpilled.nonEmpty) { + {task.bytesSpilled.get.memoryBytesSpilledReadable} + {task.bytesSpilled.get.diskBytesSpilledReadable} + }} + {errorMessageCell(task.error)} + + } + + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default val errorSummary = StringEscapeUtils.escapeHtml4( @@ -843,33 +1385,4 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } {errorSummary}{details} } - - private def getGettingResultTime(info: TaskInfo): Long = { - if (info.gettingResultTime > 0) { - if (info.finishTime > 0) { - info.finishTime - info.gettingResultTime - } else { - // The task is still fetching the result. 
- System.currentTimeMillis - info.gettingResultTime - } - } else { - 0L - } - } - - private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { - val totalExecutionTime = - if (info.gettingResult) { - info.gettingResultTime - info.launchTime - } else if (info.finished) { - info.finishTime - info.launchTime - } else { - 0 - } - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) - math.max( - 0, - totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) - } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 9bf67db8acde..d2dfc5a32915 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames { val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" + val PEAK_EXECUTION_MEMORY = "peak_execution_memory" } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index ffea9817c0b0..81f168a447ea 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.scope import scala.collection.mutable -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{StringBuilder, ListBuffer} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo @@ -167,7 +167,7 @@ private[ui] object RDDOperationGraph extends Logging { def makeDotFile(graph: RDDOperationGraph): String = { val dotFile = new StringBuilder dotFile.append("digraph G {\n") - dotFile.append(makeDotSubgraph(graph.rootCluster, indent = " ")) + makeDotSubgraph(dotFile, graph.rootCluster, indent = " ") graph.edges.foreach { edge => dotFile.append(s""" ${edge.fromId}->${edge.toId};\n""") } dotFile.append("}") val result = dotFile.toString() @@ -180,18 +180,19 @@ private[ui] object RDDOperationGraph extends Logging { s"""${node.id} [label="${node.name} [${node.id}]"]""" } - /** Return the dot representation of a subgraph in an RDDOperationGraph. */ - private def makeDotSubgraph(cluster: RDDOperationCluster, indent: String): String = { - val subgraph = new StringBuilder - subgraph.append(indent + s"subgraph cluster${cluster.id} {\n") - subgraph.append(indent + s""" label="${cluster.name}";\n""") + /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. 
*/ + private def makeDotSubgraph( + subgraph: StringBuilder, + cluster: RDDOperationCluster, + indent: String): Unit = { + subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n") + subgraph.append(indent).append(s""" label="${cluster.name}";\n""") cluster.childNodes.foreach { node => - subgraph.append(indent + s" ${makeDotNode(node)};\n") + subgraph.append(indent).append(s" ${makeDotNode(node)};\n") } cluster.childClusters.foreach { cscope => - subgraph.append(makeDotSubgraph(cscope, indent + " ")) + makeDotSubgraph(subgraph, cscope, indent + " ") } - subgraph.append(indent + "}\n") - subgraph.toString() + subgraph.append(indent).append("}\n") } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 36943978ff59..fd6cc3ed759b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -17,12 +17,13 @@ package org.apache.spark.ui.storage +import java.net.URLEncoder import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} import org.apache.spark.status.api.v1.{AllRDDResource, RDDDataDistribution, RDDPartitionInfo} -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.ui.{PagedDataSource, PagedTable, UIUtils, WebUIPage} import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ @@ -32,6 +33,17 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { def render(request: HttpServletRequest): Seq[Node] = { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val parameterBlockPage = request.getParameter("block.page") + val parameterBlockSortColumn = request.getParameter("block.sort") + val parameterBlockSortDesc = request.getParameter("block.desc") + val parameterBlockPageSize = request.getParameter("block.pageSize") + + val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) + val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") + val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false) + val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100) + val rddId = parameterId.toInt val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) .getOrElse { @@ -44,8 +56,34 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { rddStorageInfo.dataDistribution.get, id = Some("rdd-storage-by-worker-table")) // Block table - val blockTable = UIUtils.listingTable(blockHeader, blockRow, rddStorageInfo.partitions.get, - id = Some("rdd-storage-by-block-table")) + val (blockTable, blockTableHTML) = try { + val _blockTable = new BlockPagedTable( + UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + rddStorageInfo.partitions.get, + blockPageSize, + blockSortColumn, + blockSortDesc) + (_blockTable, _blockTable.table(blockPage)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + (null,
<div class="alert alert-error">{e.getMessage}</div>
    ) + } + + val jsForScrollingDownToBlockTable = + val content =
    @@ -85,11 +123,11 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
-          <h4> {rddStorageInfo.partitions.map(_.size).getOrElse(0)} Partitions </h4>
-          {blockTable}
+          <h4> {rddStorageInfo.partitions.map(_.size).getOrElse(0)} Partitions </h4>
+          {blockTableHTML ++ jsForScrollingDownToBlockTable}
    ; UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) @@ -101,14 +139,6 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { "Memory Usage", "Disk Usage") - /** Header fields for the block table */ - private def blockHeader = Seq( - "Block Name", - "Storage Level", - "Size in Memory", - "Size on Disk", - "Executors") - /** Render an HTML row representing a worker */ private def workerRow(worker: RDDDataDistribution): Seq[Node] = { @@ -120,23 +150,157 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { {Utils.bytesToString(worker.diskUsed)} } +} + +private[ui] case class BlockTableRowData( + blockName: String, + storageLevel: String, + memoryUsed: Long, + diskUsed: Long, + executors: String) + +private[ui] class BlockDataSource( + rddPartitions: Seq[RDDPartitionInfo], + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) { + + private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc)) + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[BlockTableRowData] = { + data.slice(from, to) + } + + private def blockRow(rddPartition: RDDPartitionInfo): BlockTableRowData = { + BlockTableRowData( + rddPartition.blockName, + rddPartition.storageLevel, + rddPartition.memoryUsed, + rddPartition.diskUsed, + rddPartition.executors.mkString(" ")) + } + + /** + * Return Ordering according to sortColumn and desc + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[BlockTableRowData] = { + val ordering = sortColumn match { + case "Block Name" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.String.compare(x.blockName, y.blockName) + } + case "Storage Level" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.String.compare(x.storageLevel, y.storageLevel) + } + case "Size in Memory" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.Long.compare(x.memoryUsed, y.memoryUsed) + } + case "Size on Disk" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.Long.compare(x.diskUsed, y.diskUsed) + } + case "Executors" => new Ordering[BlockTableRowData] { + override def compare(x: BlockTableRowData, y: BlockTableRowData): Int = + Ordering.String.compare(x.executors, y.executors) + } + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } +} + +private[ui] class BlockPagedTable( + basePath: String, + rddPartitions: Seq[RDDPartitionInfo], + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedTable[BlockTableRowData] { + + override def tableId: String = "rdd-storage-by-block-table" + + override def tableCssClass: String = "table table-bordered table-condensed table-striped" + + override val dataSource: BlockDataSource = new BlockDataSource( + rddPartitions, + pageSize, + sortColumn, + desc) + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + s"${basePath}&block.page=$page&block.sort=${encodedSortColumn}&block.desc=${desc}" + + s"&block.pageSize=${pageSize}" + } + + override def goButtonJavascriptFunction: (String, String) = { + val jsFuncName = 
"goToBlockPage" + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + val jsFunc = s""" + |currentBlockPageSize = ${pageSize} + |function goToBlockPage(page, pageSize) { + | // Set page to 1 if the page size changes + | page = pageSize == currentBlockPageSize ? page : 1; + | var url = "${basePath}&block.sort=${encodedSortColumn}&block.desc=${desc}" + + | "&block.page=" + page + "&block.pageSize=" + pageSize; + | window.location.href = url; + |} + """.stripMargin + (jsFuncName, jsFunc) + } - /** Render an HTML row representing a block */ - private def blockRow(row: RDDPartitionInfo): Seq[Node] = { + override def headers: Seq[Node] = { + val blockHeaders = Seq( + "Block Name", + "Storage Level", + "Size in Memory", + "Size on Disk", + "Executors") + + if (!blockHeaders.contains(sortColumn)) { + throw new IllegalArgumentException(s"Unknown column: $sortColumn") + } + + val headerRow: Seq[Node] = { + blockHeaders.map { header => + if (header == sortColumn) { + val headerLink = + s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}&block.desc=${!desc}" + + s"&block.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + val arrow = if (desc) "▾" else "▴" // UP or DOWN + + {header} +  {Unparsed(arrow)} + + } else { + val headerLink = + s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + + {header} + + } + } + } + {headerRow} + } + + override def row(block: BlockTableRowData): Seq[Node] = { - {row.blockName} - - {row.storageLevel} - - - {Utils.bytesToString(row.memoryUsed)} - - - {Utils.bytesToString(row.diskUsed)} - - - {row.executors.map(l => {l}
    )} - + {block.blockName} + {block.storageLevel} + {Utils.bytesToString(block.memoryUsed)} + {Utils.bytesToString(block.diskUsed)} + {block.executors} } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 07db783c572c..04f584621e71 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.RDDInfo +import org.apache.spark.storage._ import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -30,13 +30,25 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val rdds = listener.rddInfoList - val content = UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table")) + val content = rddTable(listener.rddInfoList) ++ + receiverBlockTables(listener.allExecutorStreamBlockStatus.sortBy(_.executorId)) UIUtils.headerSparkPage("Storage", content, parent) } + private[storage] def rddTable(rdds: Seq[RDDInfo]): Seq[Node] = { + if (rdds.isEmpty) { + // Don't show the rdd table if there is no RDD persisted. + Nil + } else { +
+      <div>
+        <h4>RDDs</h4>
+        {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))}
+      </div>
    + } + } + /** Header fields for the RDD table */ - private def rddHeader = Seq( + private val rddHeader = Seq( "RDD Name", "Storage Level", "Cached Partitions", @@ -56,7 +68,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.storageLevel.description} - {rdd.numCachedPartitions} + {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memSize)} {Utils.bytesToString(rdd.externalBlockStoreSize)} @@ -64,4 +76,130 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { // scalastyle:on } + + private[storage] def receiverBlockTables(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { + if (statuses.map(_.numStreamBlocks).sum == 0) { + // Don't show the tables if there is no stream block + Nil + } else { + val blocks = statuses.flatMap(_.blocks).groupBy(_.blockId).toSeq.sortBy(_._1.toString) + +
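The `blocks` value above groups every replica of a stream block under its block id and fixes a deterministic row order before the table is rendered. A standalone sketch of that shaping step, with `BlockEntry` as an illustrative stand-in for `BlockUIData`:

```scala
// Illustrative stand-in for the per-replica block status entries grouped above.
case class BlockEntry(blockId: String, location: String, memSize: Long)

object StreamBlockGrouping {
  // Group replicas of the same block together and give the table a stable,
  // deterministic row order by block id.
  def groupForTable(blocks: Seq[BlockEntry]): Seq[(String, Seq[BlockEntry])] =
    blocks.groupBy(_.blockId).toSeq.sortBy(_._1)
}
```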
+      <div>
+        <h4>Receiver Blocks</h4>
+        {executorMetricsTable(statuses)}
+        {streamBlockTable(blocks)}
+      </div>
    + } + } + + private def executorMetricsTable(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { +
+    <div>
+      <h5>Aggregated Block Metrics by Executor</h5>
+      {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, statuses,
+        id = Some("storage-by-executor-stream-blocks"))}
+    </div>
    + } + + private val executorMetricsTableHeader = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + + private def executorMetricsTableRow(status: ExecutorStreamBlockStatus): Seq[Node] = { + + + {status.executorId} + + + {status.location} + + + {Utils.bytesToString(status.totalMemSize)} + + + {Utils.bytesToString(status.totalExternalBlockStoreSize)} + + + {Utils.bytesToString(status.totalDiskSize)} + + + {status.numStreamBlocks.toString} + + + } + + private def streamBlockTable(blocks: Seq[(BlockId, Seq[BlockUIData])]): Seq[Node] = { + if (blocks.isEmpty) { + Nil + } else { +
+      <div>
+        <h5>Blocks</h5>
+        {UIUtils.listingTable(
+          streamBlockTableHeader,
+          streamBlockTableRow,
+          blocks,
+          id = Some("storage-by-block-table"),
+          sortable = false)}
+      </div>
    + } + } + + private val streamBlockTableHeader = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + + /** Render a stream block */ + private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { + val replications = block._2 + assert(replications.size > 0) // This must be true because it's the result of "groupBy" + if (replications.size == 1) { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) + } else { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) ++ + replications.tail.map(streamBlockTableSubrow(block._1, _, replications.size, false)).flatten + } + } + + private def streamBlockTableSubrow( + blockId: BlockId, block: BlockUIData, replication: Int, firstSubrow: Boolean): Seq[Node] = { + val (storageLevel, size) = streamBlockStorageLevelDescriptionAndSize(block) + + + { + if (firstSubrow) { + + {block.blockId.toString} + + + {replication.toString} + + } + } + {block.location} + {storageLevel} + {Utils.bytesToString(size)} + + } + + private[storage] def streamBlockStorageLevelDescriptionAndSize( + block: BlockUIData): (String, Long) = { + if (block.storageLevel.useDisk) { + ("Disk", block.diskSize) + } else if (block.storageLevel.useMemory && block.storageLevel.deserialized) { + ("Memory", block.memSize) + } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { + ("Memory Serialized", block.memSize) + } else if (block.storageLevel.useOffHeap) { + ("External", block.externalBlockStoreSize) + } else { + throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") + } + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 035174970096..22e2993b3b5b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,7 +39,8 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi -class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { +class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { + private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 96aa2fe16470..1738258a0c79 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,9 +17,7 @@ package org.apache.spark.util -import scala.collection.JavaConversions.mapAsJavaMap -import scala.concurrent.Await -import scala.concurrent.duration.FiniteDuration +import scala.collection.JavaConverters._ import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -28,6 +26,7 @@ import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} +import org.apache.spark.rpc.RpcTimeout /** * Various utility classes for working with Akka. 
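A minimal sketch of the new RpcTimeout calling pattern used in the AkkaUtils hunks below, assuming the askRpcTimeout helper that this patch adds to RpcUtils (illustrative only, not part of the diff):

    val timeout = RpcUtils.askRpcTimeout(conf)           // built from "spark.rpc.askTimeout" / "spark.network.timeout"
    val future = actor.ask(message)(timeout.duration)    // Akka's ask still takes a FiniteDuration
    val result = timeout.awaitResult(future)             // awaits the result within the configured timeout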
@@ -93,7 +92,7 @@ private[spark] object AkkaUtils extends Logging { val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig .getOrElse(ConfigFactory.empty()) - val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava) .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( s""" |akka.daemonic = on @@ -129,7 +128,7 @@ private[spark] object AkkaUtils extends Logging { /** Returns the configured max frame size for Akka messages in bytes. */ def maxFrameSizeBytes(conf: SparkConf): Int = { - val frameSizeInMB = conf.getInt("spark.akka.frameSize", 10) + val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128) if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { throw new IllegalArgumentException( s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB") @@ -147,7 +146,7 @@ private[spark] object AkkaUtils extends Logging { def askWithReply[T]( message: Any, actor: ActorRef, - timeout: FiniteDuration): T = { + timeout: RpcTimeout): T = { askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) } @@ -160,7 +159,7 @@ private[spark] object AkkaUtils extends Logging { actor: ActorRef, maxAttempts: Int, retryInterval: Long, - timeout: FiniteDuration): T = { + timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts if (actor == null) { throw new SparkException(s"Error sending message [message = $message]" + @@ -171,8 +170,8 @@ private[spark] object AkkaUtils extends Logging { while (attempts < maxAttempts) { attempts += 1 try { - val future = actor.ask(message)(timeout) - val result = Await.result(future, timeout) + val future = actor.ask(message)(timeout.duration) + val result = timeout.awaitResult(future) if (result == null) { throw new SparkException("Actor returned null") } @@ -198,9 +197,9 @@ private[spark] object AkkaUtils extends Logging { val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = RpcUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupRpcTimeout(conf) logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) } def makeExecutorRef( @@ -212,9 +211,9 @@ private[spark] object AkkaUtils extends Logging { val executorActorSystemName = SparkEnv.executorActorSystemName Utils.checkHost(host, "Expected hostname") val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = RpcUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupRpcTimeout(conf) logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) } def protocol(actorSystem: ActorSystem): String = { diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 305de4c75539..1b49dca9dc78 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -49,45 +49,28 @@ private[spark] object ClosureCleaner extends Logging { cls.getName.contains("$anonfun$") } - // Get a list of the classes of the outer 
objects of a given closure object, obj; + // Get a list of the outer objects and their classes of a given closure object, obj; // the outer objects are defined as any closures that obj is nested within, plus // possibly the class that the outermost closure is in, if any. We stop searching // for outer objects beyond that because cloning the user's object is probably // not a good idea (whereas we can clone closure objects just fine since we // understand how all their fields are used). - private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { f.setAccessible(true) val outer = f.get(obj) // The outer pointer may be null if we have cleaned this closure before if (outer != null) { if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(outer) + val recurRet = getOuterClassesAndObjects(outer) + return (f.getType :: recurRet._1, outer :: recurRet._2) } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure + return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure } } } - Nil + (Nil, Nil) } - - // Get a list of the outer objects for a given closure object. - private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - val outer = f.get(obj) - // The outer pointer may be null if we have cleaned this closure before - if (outer != null) { - if (isClosure(f.getType)) { - return outer :: getOuterObjects(outer) - } else { - return outer :: Nil // Stop at the first $outer that is not a closure - } - } - } - Nil - } - /** * Return a list of classes that represent closures enclosed in the given closure object. */ @@ -111,7 +94,7 @@ private[spark] object ClosureCleaner extends Logging { if (cls.isPrimitive) { cls match { case java.lang.Boolean.TYPE => new java.lang.Boolean(false) - case java.lang.Character.TYPE => new java.lang.Character('\0') + case java.lang.Character.TYPE => new java.lang.Character('\u0000') case java.lang.Void.TYPE => // This should not happen because `Foo(void x) {}` does not compile. throw new IllegalStateException("Unexpected void parameter in constructor") @@ -198,15 +181,14 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++") + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") // A list of classes that represents closures enclosed in the given one val innerClasses = getInnerClosureClasses(func) // A list of enclosing objects and their respective classes, from innermost to outermost // An outer object at a given index is of type outer class at the same index - val outerClasses = getOuterClasses(func) - val outerObjects = getOuterObjects(func) + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) // For logging purposes only val declaredFields = func.getClass.getDeclaredFields @@ -448,10 +430,12 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? 
&& argTypes(0).getInternalName == myName) { + // scalastyle:off classforname output += Class.forName( owner.replace('/', '.'), false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname } } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 1bab707235b8..950b69f7db64 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -52,9 +52,11 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va } def showQuantiles(out: PrintStream = System.out): Unit = { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) @@ -64,8 +66,10 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va * @param out */ def summary(out: PrintStream = System.out) { + // scalastyle:off println out.println(statCounter) showQuantiles(out) + // scalastyle:on println } } @@ -80,8 +84,10 @@ private[spark] object Distribution { } def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") quantiles.foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala index 17e55f7996bf..53934ad4ce47 100644 --- a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala +++ b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala @@ -22,10 +22,10 @@ import java.util.concurrent.atomic.AtomicInteger /** * A util used to get a unique generation ID. This is a wrapper around Java's * AtomicInteger. An example usage is in BlockManager, where each BlockManager - * instance would start an Akka actor and we use this utility to assign the Akka - * actors unique names. + * instance would start an RpcEndpoint and we use this utility to assign the RpcEndpoints' + * unique names. 
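+ * For example (illustrative sketch, not part of this patch; the endpoint name is made up):
+ * {{{
+ *   val idGenerator = new IdGenerator
+ *   val endpointName = "BlockManagerEndpoint-" + idGenerator.next
+ * }}}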
*/ private[spark] class IdGenerator { - private var id = new AtomicInteger + private val id = new AtomicInteger def next: Int = id.incrementAndGet } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index adf69a4e78e7..99614a786bd9 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -92,8 +92,10 @@ private[spark] object JsonProtocol { executorRemovedToJson(executorRemoved) case logStart: SparkListenerLogStart => logStartToJson(logStart) - // These aren't used, but keeps compiler happy - case SparkListenerExecutorMetricsUpdate(_, _) => JNothing + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + executorMetricsUpdateToJson(metricsUpdate) + case blockUpdated: SparkListenerBlockUpdated => + throw new MatchError(blockUpdated) // TODO(ekl) implement this } } @@ -224,6 +226,19 @@ private[spark] object JsonProtocol { ("Spark Version" -> SPARK_VERSION) } + def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { + val execId = metricsUpdate.execId + val taskMetrics = metricsUpdate.taskMetrics + ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ + ("Executor ID" -> execId) ~ + ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) => + ("Task ID" -> taskId) ~ + ("Stage ID" -> stageId) ~ + ("Stage Attempt ID" -> stageAttemptId) ~ + ("Task Metrics" -> taskMetricsToJson(metrics)) + }) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -251,7 +266,7 @@ private[spark] object JsonProtocol { def taskInfoToJson(taskInfo: TaskInfo): JValue = { ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ - ("Attempt" -> taskInfo.attempt) ~ + ("Attempt" -> taskInfo.attemptNumber) ~ ("Launch Time" -> taskInfo.launchTime) ~ ("Executor ID" -> taskInfo.executorId) ~ ("Host" -> taskInfo.host) ~ @@ -347,8 +362,9 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) - case ExecutorLostFailure(executorId) => - ("Executor ID" -> executorId) + case ExecutorLostFailure(executorId, isNormalExit) => + ("Executor ID" -> executorId) ~ + ("Normal Exit" -> isNormalExit) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -463,6 +479,7 @@ private[spark] object JsonProtocol { val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) + val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -481,6 +498,7 @@ private[spark] object JsonProtocol { case `executorAdded` => executorAddedFromJson(json) case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) + case `metricsUpdate` => executorMetricsUpdateFromJson(json) } } @@ -598,6 +616,18 @@ private[spark] object JsonProtocol { SparkListenerLogStart(sparkVersion) } + def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = { + val execInfo = (json \ "Executor ID").extract[String] + val 
taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json => + val taskId = (json \ "Task ID").extract[Long] + val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] + val metrics = taskMetricsFromJson(json \ "Task Metrics") + (taskId, stageId, stageAttemptId, metrics) + } + SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ @@ -761,12 +791,14 @@ private[spark] object JsonProtocol { val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => + val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). + map(_.extract[Boolean]) val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - ExecutorLostFailure(executorId.getOrElse("Unknown")) + ExecutorLostFailure(executorId.getOrElse("Unknown"), isNormalExit.getOrElse(false)) case `unknownReason` => UnknownReason } } diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index a725767d08cc..13cb516b583e 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -19,12 +19,11 @@ package org.apache.spark.util import java.util.concurrent.CopyOnWriteArrayList -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.scheduler.SparkListener /** * An event bus which posts events to its listeners. @@ -46,7 +45,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * `postToAll` in the same thread for all events. */ final def postToAll(event: E): Unit = { - // JavaConversions will create a JIterableWrapper if we use some Scala collection functions. + // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here ewe use // Java Iterator directly. 
val iter = listeners.iterator @@ -69,7 +68,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass - listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq } } diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index 171855406198..e7a65d74a440 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -58,7 +58,7 @@ private[spark] class ManualClock(private var time: Long) extends Clock { */ def waitTillTime(targetTime: Long): Long = synchronized { while (time < targetTime) { - wait(100) + wait(10) } getTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 169489df6c1e..a1c33212cdb2 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -21,8 +21,6 @@ import java.net.{URLClassLoader, URL} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions._ - /** * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. */ diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index f16cc8e7e42c..7578a3b1d85f 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,11 +17,11 @@ package org.apache.spark.util -import scala.concurrent.duration._ +import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} object RpcUtils { @@ -47,14 +47,22 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC ask operations. */ + private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = { + RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") + } + + @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0") def askTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.askTimeout", - conf.get("spark.network.timeout", "120s")) seconds + askRpcTimeout(conf).duration } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. 
*/ + private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { + RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") + } + + @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0") def lookupTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.lookupTimeout", - conf.get("spark.network.timeout", "120s")) seconds + lookupRpcTimeout(conf).duration } } diff --git a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala index 30bcf1d2f24d..3354a923273f 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala @@ -20,8 +20,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils - private[spark] class SerializableConfiguration(@transient var value: Configuration) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala index afbcc6efc850..cadae472b3f8 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala @@ -21,8 +21,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.mapred.JobConf -import org.apache.spark.util.Utils - private[spark] class SerializableJobConf(@transient var value: JobConf) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala new file mode 100644 index 000000000000..db4a8b304ec3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.File +import java.util.PriorityQueue + +import scala.util.{Failure, Success, Try} +import tachyon.client.TachyonFile + +import org.apache.hadoop.fs.FileSystem +import org.apache.spark.Logging + +/** + * Various utility methods used by Spark. + */ +private[spark] object ShutdownHookManager extends Logging { + val DEFAULT_SHUTDOWN_PRIORITY = 100 + + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. 
+ */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. + */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + + private lazy val shutdownHooks = { + val manager = new SparkShutdownHookManager() + manager.install() + manager + } + + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + + // Add a shutdown hook to delete the temp dirs when the JVM exits + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + logInfo("Deleting directory " + dirPath) + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Register the tachyon path to be deleted via shutdown hook + def registerShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the path to be deleted via shutdown hook + def removeShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(absolutePath) + } + } + + // Remove the tachyon path to be deleted via shutdown hook + def removeShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.remove(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. 
+ def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + val retval = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + false + } + + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. 
+ */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem] + .getField("SHUTDOWN_HOOK_PRIORITY") + .get(null) // static field, the value is not used + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Try(Utils.logUncaughtExceptions(hooks.poll().run())) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + +} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 0180399c9dad..14b1f2a17e70 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -124,9 +124,11 @@ object SizeEstimator extends Logging { val server = ManagementFactory.getPlatformMBeanServer() // NOTE: This should throw an exception in non-Sun JVMs + // scalastyle:off classforname val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", Class.forName("java.lang.String")) + // scalastyle:on classforname val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) @@ -215,10 +217,10 @@ object SizeEstimator extends Logging { var arrSize: Long = alignSize(objectSize + INT_SIZE) if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) + arrSize += alignSize(length.toLong * primitiveSize(elementClass)) state.size += arrSize } else { - arrSize += alignSize(length * pointerSize) + arrSize += alignSize(length.toLong * pointerSize) state.size += arrSize if (length <= ARRAY_SIZE_FOR_SAMPLING) { @@ -334,7 +336,7 @@ object SizeEstimator extends Logging { // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp var alignedSize = shellSize for (size <- fieldSizes if sizeCount(size) > 0) { - val count = sizeCount(size) + val count = sizeCount(size).toLong // If there are internal gaps, smaller field can fit in. 
alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count) shellSize += size * count diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index ad3db1fbb57e..724818724733 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -33,7 +33,7 @@ private[spark] object SparkUncaughtExceptionHandler // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { if (exception.isInstanceOf[OutOfMemoryError]) { System.exit(SparkExitCode.OOM) } else { diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 8de75ba9a9c9..d7e5143c3095 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -21,7 +21,8 @@ import java.util.Set import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.{JavaConversions, mutable} +import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.Logging @@ -50,8 +51,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } def iterator: Iterator[(A, B)] = { - val jIterator = getEntrySet.iterator - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) + getEntrySet.iterator.asScala.map(kv => (kv.getKey, kv.getValue.value)) } def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet @@ -90,9 +90,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { - JavaConversions.mapAsScalaConcurrentMap(internalMap) - .map { case (k, TimeStampedValue(v, t)) => (k, v) } - .filter(p) + internalMap.asScala.map { case (k, TimeStampedValue(v, t)) => (k, v) }.filter(p) } override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala index 7cd8f28b12dd..65efeb1f4c19 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions +import scala.collection.JavaConverters._ import scala.collection.mutable.Set private[spark] class TimeStampedHashSet[A] extends Set[A] { @@ -31,7 +31,7 @@ private[spark] class TimeStampedHashSet[A] extends Set[A] { def iterator: Iterator[A] = { val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(_.getKey) + jIterator.asScala.map(_.getKey) } override def + (elem: A): Set[A] = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 19157af5b6f4..2bab4af2e73a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,11 +21,11 @@ import java.io._ import 
java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{PriorityQueue, Properties, Locale, Random, UUID} +import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -65,28 +65,16 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() - val DEFAULT_SHUTDOWN_PRIORITY = 100 - - /** - * The shutdown priority of the SparkContext instance. This is lower than the default - * priority, so that by default hooks are run before the context is shut down. - */ - val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 - /** - * The shutdown priority of temp directory must be lower than the SparkContext shutdown - * priority. Otherwise cleaning the temp directories while Spark jobs are running can - * throw undesirable errors at the time of shutdown. + * Define a default value for driver memory here since this value is referenced across the code + * base and nearly all files already use Utils.scala */ - val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null - private val shutdownHooks = new SparkShutdownHookManager() - shutdownHooks.install() - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -107,8 +95,11 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } ois.readObject.asInstanceOf[T] } @@ -171,12 +162,16 @@ private[spark] object Utils extends Logging { /** Determines whether the provided class is loadable in the current thread. 
*/ def classIsLoadable(clazz: String): Boolean = { + // scalastyle:off classforname Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess + // scalastyle:on classforname } + // scalastyle:off classforname /** Preferred alternative to Class.forName(className) */ def classForName(className: String): Class[_] = { Class.forName(className, true, getContextOrSparkClassLoader) + // scalastyle:on classforname } /** @@ -192,86 +187,6 @@ private[spark] object Utils extends Logging { } } - private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() - - // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => - logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - logInfo("Deleting directory " + dirPath) - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } - } - } - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * JDK equivalent of `chmod 700 file`. 
* @@ -320,7 +235,7 @@ private[spark] object Utils extends Logging { root: String = System.getProperty("java.io.tmpdir"), namePrefix: String = "spark"): File = { val dir = createDirectory(root, namePrefix) - registerShutdownDeleteDir(dir) + ShutdownHookManager.registerShutdownDeleteDir(dir) dir } @@ -430,11 +345,11 @@ private[spark] object Utils extends Logging { val lockFileName = s"${url.hashCode}${timestamp}_lock" val localDir = new File(getLocalDir(conf)) val lockFile = new File(localDir, lockFileName) - val raf = new RandomAccessFile(lockFile, "rw") + val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() // Only one executor entry. // The FileLock is only used to control synchronization for executors download file, // it's always safe regardless of lock type (mandatory or advisory). - val lock = raf.getChannel().lock() + val lock = lockFileChannel.lock() val cachedFile = new File(localDir, cachedFileName) try { if (!cachedFile.exists()) { @@ -442,6 +357,7 @@ private[spark] object Utils extends Logging { } } finally { lock.release() + lockFileChannel.close() } copyFile( url, @@ -727,7 +643,12 @@ private[spark] object Utils extends Logging { localRootDirs } - private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + /** + * Return the configured local directories where Spark can write files. This + * method does not create any directories on its own, it only encapsulates the + * logic of locating the local directories according to deployment mode. + */ + def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has @@ -743,27 +664,29 @@ private[spark] object Utils extends Logging { Option(conf.getenv("SPARK_LOCAL_DIRS")) .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) .split(",") - .flatMap { root => - try { - val rootDir = new File(root) - if (rootDir.exists || rootDir.mkdirs()) { - val dir = createTempDir(root) - chmod700(dir) - Some(dir.getAbsolutePath) - } else { - logError(s"Failed to create dir in $root. Ignoring this directory.") - None - } - } catch { - case e: IOException => - logError(s"Failed to create local root dir in $root. Ignoring this directory.") - None - } - } - .toArray } } + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + getConfiguredLocalDirs(conf).flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + }.toArray + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -825,12 +748,12 @@ private[spark] object Utils extends Logging { // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order // on unix-like system. On windows, it returns in index order. // It's more proper to pick ip address following system output order. 
- val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse for (ni <- reOrderedNetworkIFs) { - val addresses = ni.getInetAddresses.toList - .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq if (addresses.nonEmpty) { val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) // because of Inet6Address.toHostName may add interface at the end if it knows about it @@ -952,9 +875,7 @@ private[spark] object Utils extends Logging { if (savedIOException != null) { throw savedIOException } - shutdownDeletePaths.synchronized { - shutdownDeletePaths.remove(file.getAbsolutePath) - } + ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { if (!file.delete()) { @@ -1445,7 +1366,7 @@ private[spark] object Utils extends Logging { file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) } sum += fileToLength(file) - logDebug(s"After processing file $file, string built is ${stringBuffer.toString}}") + logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") } stringBuffer.toString } @@ -1457,27 +1378,6 @@ private[spark] object Utils extends Logging { serializer.deserialize[T](serializer.serialize(value)) } - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - false - } - private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -1566,14 +1466,40 @@ private[spark] object Utils extends Logging { hashAbs } + /** + * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN double. + */ + def nanSafeCompareDoubles(x: Double, y: Double): Int = { + val xIsNan: Boolean = java.lang.Double.isNaN(x) + val yIsNan: Boolean = java.lang.Double.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + + /** + * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN float. + */ + def nanSafeCompareFloats(x: Float, y: Float): Int = { + val xIsNan: Boolean = java.lang.Float.isNaN(x) + val yIsNan: Boolean = java.lang.Float.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + /** Returns the system properties map that is thread-safe to iterator over. It gets the * properties which have been set explicitly, as well as those for which only a default value * has been defined. 
*/ def getSystemProperties: Map[String, String] = { - val sysProps = for (key <- System.getProperties.stringPropertyNames()) yield - (key, System.getProperty(key)) - - sysProps.toMap + System.getProperties.stringPropertyNames().asScala + .map(key => (key, System.getProperty(key))).toMap } /** @@ -1884,7 +1810,8 @@ private[spark] object Utils extends Logging { try { val properties = new Properties() properties.load(inReader) - properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap + properties.stringPropertyNames().asScala.map( + k => (k, properties.getProperty(k).trim)).toMap } catch { case e: IOException => throw new SparkException(s"Failed when loading Spark properties from $filename", e) @@ -2013,7 +1940,8 @@ private[spark] object Utils extends Logging { return true } isBindCollision(e.getCause) - case e: MultiException => e.getThrowables.exists(isBindCollision) + case e: MultiException => + e.getThrowables.asScala.exists(isBindCollision) case e: Exception => isBindCollision(e.getCause) case _ => false } @@ -2172,37 +2100,6 @@ private[spark] object Utils extends Logging { msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) } - /** - * Adds a shutdown hook with default priority. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) - } - - /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run - * first. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { - shutdownHooks.add(priority, hook) - } - - /** - * Remove a previously installed shutdown hook. - * - * @param ref A handle returned by `addShutdownHook`. - * @return Whether the hook was removed. - */ - def removeShutdownHook(ref: AnyRef): Boolean = { - shutdownHooks.remove(ref) - } - /** * To avoid calling `Utils.getCallSite` for every single RDD we create in the body, * set a dummy call site that RDDs use instead. This is for performance optimization. @@ -2237,72 +2134,19 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } -} - -private [util] class SparkShutdownHookManager { - - private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false - /** - * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not - * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for - * the best. + * Return whether dynamic allocation is enabled in the given conf + * Dynamic allocation and explicitly setting the number of executors are inherently + * incompatible. In environments where dynamic allocation is turned on by default, + * the latter should override the former (SPARK-9092). 
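+ * For example (illustrative, not part of this patch):
+ * {{{
+ *   conf.set("spark.dynamicAllocation.enabled", "true")  // isDynamicAllocationEnabled(conf) == true
+ *   conf.set("spark.executor.instances", "4")            // now false: the explicit executor count wins
+ * }}}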
*/ - def install(): Unit = { - val hookTask = new Runnable() { - override def run(): Unit = runAll() - } - Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - } - } - - def runAll(): Unit = synchronized { - shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) - } - } - - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) - } - - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + conf.getBoolean("spark.dynamicAllocation.enabled", false) && + conf.getInt("spark.executor.instances", 0) == 0 } } -private class SparkShutdownHook(private val priority: Int, hook: () => Unit) - extends Comparable[SparkShutdownHook] { - - override def compareTo(other: SparkShutdownHook): Int = { - other.priority - priority - } - - def run(): Unit = hook() - -} - /** * A utility class to redirect the child process's stdout or stderr. */ @@ -2333,3 +2177,36 @@ private[spark] class RedirectThread( } } } + +/** + * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it + * in a circular buffer. The current contents of the buffer can be accessed using + * the toString method. + */ +private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](sizeInBytes) + + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.length + } + + override def toString: String = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next() else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while (line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 9c15b1188d91..7ab67fc3a2de 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -32,6 +32,17 @@ class BitSet(numBits: Int) extends Serializable { */ def capacity: Int = numWords * 64 + /** + * Clear all set bits. 
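+   * For example (illustrative, not part of this patch):
+   * {{{
+   *   val bits = new BitSet(64)
+   *   bits.set(3)
+   *   bits.clear()   // bits.get(3) is now false; capacity is unchanged
+   * }}}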
+ */ + def clear(): Unit = { + var i = 0 + while (i < numWords) { + words(i) = 0L + i += 1 + } + } + /** * Set all the bits up to a given index */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index 516aaa44d03f..ae60f3b0cb55 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) { private var _size: Long = 0 /** - * Feed bytes from this buffer into a BlockObjectWriter. + * Feed bytes from this buffer into a DiskBlockObjectWriter. * * @param pos Offset in the buffer to read from. * @param os OutputStream to read into. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1e4531ef395a..f929b12606f0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} @@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = @@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ + // Peak size of the in-memory map observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() @@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (maybeSpill(currentMap, currentMap.estimateSize())) { + val estimatedSize = currentMap.estimateSize() + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } + if (maybeSpill(currentMap, estimatedSize)) { currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) @@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C]( spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } - def diskBytesSpilled: Long = _diskBytesSpilled - /** * Return an iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. @@ -470,14 +477,27 @@ class ExternalAppendOnlyMap[K, V, C]( item } - // TODO: Ensure this gets called even if the iterator isn't drained. 
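The `insertAll` hunk above samples the map's estimated size exactly once per inserted record, so the same number feeds both the new `_peakMemoryUsedBytes` high-water mark and the existing spill decision. Condensed, the pattern is:

```scala
// Per record, inside ExternalAppendOnlyMap.insertAll (names as in the hunk above):
val estimatedSize = currentMap.estimateSize()
if (estimatedSize > _peakMemoryUsedBytes) {
  _peakMemoryUsedBytes = estimatedSize        // surfaced via peakMemoryUsedBytes
}
if (maybeSpill(currentMap, estimatedSize)) {  // reuse the same estimate for the spill check
  currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
```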
private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch val ds = deserializeStream - deserializeStream = null - fileStream = null - ds.close() - file.delete() + if (ds != null) { + ds.close() + deserializeStream = null + } + if (fileStream != null) { + fileStream.close() + fileStream = null + } + if (file.exists()) { + file.delete() + } + } + + val context = TaskContext.get() + // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in + // a TaskContext. + if (context != null) { + context.addTaskCompletionListener(context => cleanup()) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 757dec66c203..31230d5978b2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -30,7 +30,7 @@ import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} -import org.apache.spark.storage.{BlockId, BlockObjectWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C]( private var _diskBytesSpilled = 0L def diskBytesSpilled: Long = _diskBytesSpilled + // Peak size of the in-memory data structure observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -185,6 +188,12 @@ private[spark] class ExternalSorter[K, V, C]( private val spills = new ArrayBuffer[SpilledFile] + /** + * Number of files this sorter has spilled so far. + * Exposed for testing. 
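The rewritten `cleanup()` above now null-checks each resource, and the registration below it guarantees cleanup runs even when the merging iterator is never fully drained, by hooking task completion (the null guard exists because some `ExternalAppendOnlyMapSuite` tests run without a `TaskContext`). A hedged sketch of the same idiom with illustrative names (`openSpillIterator` and `releaseResources` are not Spark APIs):

```scala
import org.apache.spark.TaskContext

// Tie resource release to task completion so it happens even if the caller
// stops consuming the iterator early.
def openSpillIterator[T](open: () => Iterator[T], releaseResources: () => Unit): Iterator[T] = {
  val context = TaskContext.get()   // null when running outside a task, e.g. in some unit tests
  if (context != null) {
    context.addTaskCompletionListener(_ => releaseResources())
  }
  open()
}
```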
+ */ + private[spark] def numSpills: Int = spills.size + override def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -224,15 +233,22 @@ private[spark] class ExternalSorter[K, V, C]( return } + var estimatedSize = 0L if (usingMap) { - if (maybeSpill(map, map.estimateSize())) { + estimatedSize = map.estimateSize() + if (maybeSpill(map, estimatedSize)) { map = new PartitionedAppendOnlyMap[K, C] } } else { - if (maybeSpill(buffer, buffer.estimateSize())) { + estimatedSize = buffer.estimateSize() + if (maybeSpill(buffer, estimatedSize)) { buffer = newBuffer() } } + + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } } /** @@ -250,7 +266,7 @@ private[spark] class ExternalSorter[K, V, C]( // These variables are reset after each flush var objectsWritten: Long = 0 var spillMetrics: ShuffleWriteMetrics = null - var writer: BlockObjectWriter = null + var writer: DiskBlockObjectWriter = null def openWriter(): Unit = { assert (writer == null && spillMetrics == null) spillMetrics = new ShuffleWriteMetrics @@ -281,6 +297,8 @@ private[spark] class ExternalSorter[K, V, C]( val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { val partitionId = it.nextPartition() + require(partitionId >= 0 && partitionId < numPartitions, + s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 @@ -684,8 +702,10 @@ private[spark] class ExternalSorter[K, V, C]( } } - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) lengths } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index 04bb7fc78c13..f5844d5353be 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,7 +19,6 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index ae9a48729e20..87a786b02d65 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -21,9 +21,8 @@ import java.io.InputStream import java.nio.IntBuffer import java.util.Comparator -import org.apache.spark.SparkEnv import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ /** @@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( // 
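The `require` added to the spill loop above makes an out-of-range partition id fail immediately with a readable message instead of corrupting the per-partition bookkeeping. A hedged illustration of the kind of bug it catches; `BrokenPartitioner` is hypothetical:

```scala
import org.apache.spark.Partitioner

// Hypothetical user partitioner: hashCode can be negative or >= numPartitions.
class BrokenPartitioner(override val numPartitions: Int) extends Partitioner {
  override def getPartition(key: Any): Int = key.hashCode
}

// When ExternalSorter spills records routed by such a partitioner, the new check trips:
//   requirement failed: partition Id: -123 should be in the range [0, 8)
```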
current position in the meta buffer in ints var pos = 0 - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { val keyStart = getKeyStartPos(metaBuffer, pos) val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index bdbca00a0062..4939b600dbfb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import scala.collection.JavaConversions.{collectionAsScalaIterable, asJavaIterator} +import scala.collection.JavaConverters._ import com.google.common.collect.{Ordering => GuavaOrdering} @@ -34,6 +34,6 @@ private[spark] object Utils { val ordering = new GuavaOrdering[T] { override def compare(l: T, r: T): Int = ord.compare(l, r) } - collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator + ordering.leastOf(input.asJava, num).iterator.asScala } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 7bc59898658e..38848e9018c6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -19,7 +19,7 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that @@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection { } /** - * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element + * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element * has an associated partition. 
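The `collection/Utils.scala` hunk above is typical of this patch's move from implicit `JavaConversions` to explicit `JavaConverters`: each crossing of the Java/Scala boundary becomes a visible `.asJava` or `.asScala` call. A self-contained sketch mirroring that helper:

```scala
import scala.collection.JavaConverters._
import com.google.common.collect.{Ordering => GuavaOrdering}

// Returns the `num` smallest elements; conversions are explicit at the Guava boundary.
def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = {
  val ordering = new GuavaOrdering[T] {
    override def compare(l: T, r: T): Int = ord.compare(l, r)
  }
  ordering.leastOf(input.asJava, num).iterator.asScala
}
```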
*/ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: BlockObjectWriter): Unit + def writeNext(writer: DiskBlockObjectWriter): Unit def hasNext(): Boolean diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 7138b4b8e453..1e8476c4a047 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -79,32 +79,30 @@ private[spark] class RollingFileAppender( val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() val rolloverFile = new File( activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile - try { - logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") - if (activeFile.exists) { - if (!rolloverFile.exists) { - Files.move(activeFile, rolloverFile) - logInfo(s"Rolled over $activeFile to $rolloverFile") - } else { - // In case the rollover file name clashes, make a unique file name. - // The resultant file names are long and ugly, so this is used only - // if there is a name collision. This can be avoided by the using - // the right pattern such that name collisions do not occur. - var i = 0 - var altRolloverFile: File = null - do { - altRolloverFile = new File(activeFile.getParent, - s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile - i += 1 - } while (i < 10000 && altRolloverFile.exists) - - logWarning(s"Rollover file $rolloverFile already exists, " + - s"rolled over $activeFile to file $altRolloverFile") - Files.move(activeFile, altRolloverFile) - } + logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") + if (activeFile.exists) { + if (!rolloverFile.exists) { + Files.move(activeFile, rolloverFile) + logInfo(s"Rolled over $activeFile to $rolloverFile") } else { - logWarning(s"File $activeFile does not exist") + // In case the rollover file name clashes, make a unique file name. + // The resultant file names are long and ugly, so this is used only + // if there is a name collision. This can be avoided by the using + // the right pattern such that name collisions do not occur. + var i = 0 + var altRolloverFile: File = null + do { + altRolloverFile = new File(activeFile.getParent, + s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile + i += 1 + } while (i < 10000 && altRolloverFile.exists) + + logWarning(s"Rollover file $rolloverFile already exists, " + + s"rolled over $activeFile to file $altRolloverFile") + Files.move(activeFile, altRolloverFile) } + } else { + logWarning(s"File $activeFile does not exist") } } diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 786b97ad7b9e..c156b03cdb7c 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T * A sampler for sampling with replacement, based on values drawn from Poisson distribution. * * @param fraction the sampling fraction (with replacement) + * @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low. 
* @tparam T item type */ @DeveloperApi -class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { +class PoissonSampler[T: ClassTag]( + fraction: Double, + useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] { + + def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true) /** Epsilon slop to avoid failure from floating point jitter. */ require( @@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { Iterator.empty - } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + } else if (useGapSamplingIfPossible && + fraction <= RandomSampler.defaultMaxGapSamplingFraction) { + new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) } else { - items.flatMap { item => { + items.flatMap { item => val count = rng.sample() if (count == 0) Iterator.empty else Iterator.fill(count)(item) - }} + } } } - override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible) } diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index c4a7b4441c85..85fb923cd9bc 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -70,12 +70,14 @@ private[spark] object XORShiftRandom { * @param args takes one argument - the number of random numbers to generate */ def main(args: Array[String]): Unit = { + // scalastyle:off println if (args.length != 1) { println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") println("Usage: XORShiftRandom number_of_random_numbers_to_generate") System.exit(1) } println(benchmark(args(0).toInt)) + // scalastyle:on println } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index dfd86d3e51e7..fd8f7f39b7cc 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -24,10 +24,10 @@ import java.util.*; import java.util.concurrent.*; -import scala.collection.JavaConversions; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; +import scala.collection.JavaConverters; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -51,7 +51,6 @@ import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; @@ -91,7 +90,7 @@ public void sparkContextUnion() { JavaRDD sUnion = sc.union(s1, s2); Assert.assertEquals(4, sUnion.count()); // List - List> list = new ArrayList>(); + List> list = new ArrayList<>(); list.add(s2); sUnion = sc.union(s1, list); Assert.assertEquals(4, sUnion.count()); @@ -104,9 +103,9 @@ public void sparkContextUnion() { Assert.assertEquals(4, dUnion.count()); // Union of JavaPairRDDs - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new 
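A hedged sketch of the new `useGapSamplingIfPossible` switch: the single-argument constructor kept above preserves the old behaviour (gap sampling for small fractions), while passing `false` forces a Poisson draw for every element. The fraction and seed are illustrative; `setSeed` is assumed from the sampler's `Pseudorandom` base.

```scala
import org.apache.spark.util.random.PoissonSampler

val data = (1 to 100000).iterator

// Default path: gap sampling is used because 0.01 is below the gap-sampling threshold.
val fast = new PoissonSampler[Int](0.01)

// Opt out of gap sampling and draw a Poisson count for every element.
val exact = new PoissonSampler[Int](0.01, useGapSamplingIfPossible = false)
exact.setSeed(42L)
val sampled: Iterator[Int] = exact.sample(data)
```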
Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pUnion = sc.union(p1, p2); @@ -134,9 +133,9 @@ public void intersection() { JavaDoubleRDD dIntersection = d1.intersection(d2); Assert.assertEquals(2, dIntersection.count()); - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pIntersection = p1.intersection(p2); @@ -166,47 +165,49 @@ public void randomSplit() { @Test public void sortByKey() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaPairRDD rdd = sc.parallelizePairs(pairs); // Default comparator JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 5)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(2, 6)); - pairs.add(new Tuple2(0, 8)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(1, 3)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 5)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(2, 6)); + pairs.add(new Tuple2<>(0, 8)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(1, 3)); JavaPairRDD rdd = sc.parallelizePairs(pairs); Partitioner partitioner = new Partitioner() { + @Override public int numPartitions() { return 2; } + @Override public int getPartition(Object key) { - return ((Integer)key).intValue() % 2; + return (Integer) key % 2; } }; @@ -215,10 +216,10 @@ public int getPartition(Object key) { Assert.assertTrue(repartitioned.partitioner().isPresent()); Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); - Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), - new Tuple2(0, 8), new Tuple2(2, 6))); - Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), - new Tuple2(3, 8), new Tuple2(3, 8))); + Assert.assertEquals(partitions.get(0), + Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6))); + Assert.assertEquals(partitions.get(1), + Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8))); } @Test @@ -229,35 +230,37 @@ public void emptyRDD() 
{ @Test public void sortBy() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaRDD> rdd = sc.parallelize(pairs); // compare on first value JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._1(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._2(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); } @Test @@ -266,7 +269,7 @@ public void foreach() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override - public void call(String s) throws IOException { + public void call(String s) { accum.add(1); } }); @@ -279,7 +282,7 @@ public void foreachPartition() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreachPartition(new VoidFunction>() { @Override - public void call(Iterator iter) throws IOException { + public void call(Iterator iter) { while (iter.hasNext()) { iter.next(); accum.add(1); @@ -302,7 +305,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertEquals(4, new HashSet(indexes.collect()).size()); + Assert.assertEquals(4, new HashSet<>(indexes.collect()).size()); } @Test @@ -318,10 +321,10 @@ public void zipWithIndex() { @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") + )); Assert.assertEquals(2, categories.lookup("Oranges").size()); Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @@ -391,18 +394,17 @@ public String call(Tuple2 x) { @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD, Iterable>> cogrouped = 
categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -412,23 +414,22 @@ public void cogroup() { @Test public void cogroup3() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); @@ -440,27 +441,26 @@ public void cogroup3() { @Test public void cogroup4() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD countries = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", "BR"), - new Tuple2("Apples", "US") + new Tuple2<>("Oranges", "BR"), + new Tuple2<>("Apples", "US") )); JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities, countries); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); @@ -472,16 +472,16 @@ public void cogroup4() { @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) )); JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') )); List>>> joined = 
rdd1.leftOuterJoin(rdd2).collect(); @@ -549,11 +549,11 @@ public Integer call(Integer a, Integer b) { public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(5, 1), - new Tuple2(5, 3)), 2); + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(5, 1), + new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), new Function2, Integer, Set>() { @@ -571,20 +571,20 @@ public Set call(Set a, Set b) { } }).collectAsMap(); Assert.assertEquals(3, sets.size()); - Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); - Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); - Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD sums = rdd.foldByKey(0, @@ -603,11 +603,11 @@ public Integer call(Integer a, Integer b) { @Test public void reduceByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD counts = rdd.reduceByKey( @@ -691,7 +691,7 @@ public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); } @Test @@ -744,6 +744,7 @@ public void javaDoubleRDDHistoGram() { } private static class DoubleComparator implements Comparator, Serializable { + @Override public int compare(Double o1, Double o2) { return o1.compareTo(o2); } @@ -767,14 +768,14 @@ public void min() { public void naturalMax() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(); - Assert.assertTrue(4.0 == max); + Assert.assertEquals(4.0, max, 0.0); } @Test public void naturalMin() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(); - Assert.assertTrue(1.0 == max); + Assert.assertEquals(1.0, max, 0.0); } @Test @@ -810,7 +811,7 @@ public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double sum = rdd.reduce(new Function2() { @Override - public Double call(Double v1, Double v2) throws Exception { + public Double call(Double v1, Double v2) { return v1 + v2; } }); @@ -845,7 +846,7 @@ public double call(Integer x) { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, x); + return new Tuple2<>(x, x); } }).cache(); pairs.collect(); @@ -871,26 +872,25 @@ public Iterable call(String x) { Assert.assertEquals("Hello", words.first()); 
Assert.assertEquals(11, words.count()); - JavaPairRDD pairs = rdd.flatMapToPair( + JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { - @Override public Iterable> call(String s) { - List> pairs = new LinkedList>(); + List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { - pairs.add(new Tuple2(word, word)); + pairs.add(new Tuple2<>(word, word)); } return pairs; } } ); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); + Assert.assertEquals(11, pairsRDD.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override public Iterable call(String s) { - List lengths = new LinkedList(); + List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } @@ -898,36 +898,36 @@ public Iterable call(String s) { } }); Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(11, pairsRDD.count()); } @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); - } + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMapToPair( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) { + return Collections.singletonList(item.swap()); + } }); - swapped.collect(); + swapped.collect(); - // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + // There was never a bug here, but it's worth testing: + pairRDD.mapToPair(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) { + return item.swap(); + } + }).collect(); } @Test @@ -954,7 +954,7 @@ public void mapPartitionsWithIndex() { JavaRDD partitionSums = rdd.mapPartitionsWithIndex( new Function2, Iterator>() { @Override - public Iterator call(Integer index, Iterator iter) throws Exception { + public Iterator call(Integer index, Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); @@ -973,8 +973,8 @@ public void repartition() { JavaRDD repartitioned1 = in1.repartition(4); List> result1 = repartitioned1.glom().collect(); Assert.assertEquals(4, result1.size()); - for (List l: result1) { - Assert.assertTrue(l.size() > 0); + for (List l : result1) { + Assert.assertFalse(l.isEmpty()); } // Growing number of partitions @@ -983,7 +983,7 @@ public void repartition() { List> result2 = repartitioned2.glom().collect(); Assert.assertEquals(2, result2.size()); for (List l: result2) { - Assert.assertTrue(l.size() > 0); + Assert.assertFalse(l.isEmpty()); } } @@ -995,9 +995,9 @@ public void persist() { Assert.assertEquals(20, doubleRDD.sum(), 0.1); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new 
Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); @@ -1011,7 +1011,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); + TaskContext context = TaskContext$.MODULE$.empty(); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } @@ -1047,7 +1047,7 @@ public void wholeTextFiles() throws Exception { Files.write(content1, new File(tempDirName + "/part-00000")); Files.write(content2, new File(tempDirName + "/part-00001")); - Map container = new HashMap(); + Map container = new HashMap<>(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -1076,16 +1076,16 @@ public void textFilesCompressed() throws IOException { public void sequenceFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); @@ -1094,7 +1094,7 @@ public Tuple2 call(Tuple2 pair) { Text.class).mapToPair(new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); + return new Tuple2<>(pair._1().get(), pair._2().toString()); } }); Assert.assertEquals(pairs, readRDD.collect()); @@ -1111,7 +1111,7 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); @@ -1132,14 +1132,14 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(new VoidFunction>() { @Override - public void call(Tuple2 pair) throws Exception { + public void call(Tuple2 pair) { pair._2().toArray(); // force the file to read } }); @@ -1163,7 +1163,7 @@ public void binaryRecords() throws Exception { FileChannel channel1 = fos1.getChannel(); for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); } channel1.close(); @@ -1181,24 +1181,23 @@ public void binaryRecords() throws Exception { public void writeWithNewAPIHadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new 
Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + }).saveAsNewAPIHadoopFile( + outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1211,24 +1210,23 @@ public String call(Tuple2 x) { public void readWithNewAPIHadoopFile() throws IOException { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1252,9 +1250,9 @@ public void objectFilesOfInts() { public void objectFilesOfComplexTypes() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.saveAsObjectFile(outputDir); @@ -1268,23 +1266,22 @@ public void objectFilesOfComplexTypes() { public void hadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), 
output.map(new Function, - String>() { + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1297,16 +1294,16 @@ public String call(Tuple2 x) { public void hadoopFileCompressed() { String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); @@ -1314,8 +1311,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1415,8 +1411,8 @@ public String call(Integer t) { return t.toString(); } }).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test @@ -1449,20 +1445,20 @@ public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); Function keyFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1 % 3; } }; Function createCombinerFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1; } }; Function2 mergeValueFunction = new Function2() { @Override - public Integer call(Integer v1, Integer v2) throws Exception { + public Integer call(Integer v1, Integer v2) { return v1 + v2; } }; @@ -1474,7 +1470,9 @@ public Integer call(Integer v1, Integer v2) throws Exception { Assert.assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( - combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.>newArrayList())); + combinedRDD.rdd(), + JavaConverters.collectionAsScalaIterableConverter( + Collections.>emptyList()).asScala().toSeq()); combinedRDD = originalRDD.keyBy(keyFunction) .combineByKey( createCombinerFunction, @@ -1495,21 +1493,21 @@ public void mapOnPairRDD() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); JavaPairRDD rdd3 = rdd2.mapToPair( new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) { - return new Tuple2(in._2(), in._1()); - } - }); + @Override + public Tuple2 call(Tuple2 in) { + return new Tuple2<>(in._2(), in._1()); + } + }); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @@ -1522,7 +1520,7 @@ public void collectPartitions() { 
new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); @@ -1533,23 +1531,23 @@ public Tuple2 call(Integer i) { Assert.assertEquals(Arrays.asList(3, 4), parts[0]); Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), - new Tuple2(2, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), + new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), - new Tuple2(4, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), + new Tuple2<>(4, 0)), parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), - new Tuple2(6, 0), - new Tuple2(7, 1)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), parts2[1]); } @Test public void countApproxDistinct() { - List arrayData = new ArrayList(); + List arrayData = new ArrayList<>(); int size = 100; for (int i = 0; i < 100000; i++) { arrayData.add(i % size); @@ -1560,15 +1558,15 @@ public void countApproxDistinct() { @Test public void countApproxDistinctByKey() { - List> arrayData = new ArrayList>(); + List> arrayData = new ArrayList<>(); for (int i = 10; i < 100; i++) { for (int j = 0; j < i; j++) { - arrayData.add(new Tuple2(i, j)); + arrayData.add(new Tuple2<>(i, j)); } } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); for (Tuple2 resItem : res) { double count = (double)resItem._1(); Long resCount = (Long)resItem._2(); @@ -1586,7 +1584,7 @@ public void collectAsMapWithIntArrayValues() { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, new int[] { x }); + return new Tuple2<>(x, new int[]{x}); } }); pairRDD.collect(); // Works fine @@ -1597,7 +1595,7 @@ public Tuple2 call(Integer x) { @Test public void collectAsMapAndSerialize() throws Exception { JavaPairRDD rdd = - sc.parallelizePairs(Arrays.asList(new Tuple2("foo", 1))); + sc.parallelizePairs(Arrays.asList(new Tuple2<>("foo", 1))); Map map = rdd.collectAsMap(); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); new ObjectOutputStream(bytes).writeObject(map); @@ -1614,7 +1612,7 @@ public void sampleByKey() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1622,12 +1620,12 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); Map wrCounts = (Map) (Object) wr.countByKey(); - Assert.assertTrue(wrCounts.size() == 2); + Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); Map worCounts = (Map) (Object) wor.countByKey(); - Assert.assertTrue(worCounts.size() == 2); + Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); } @@ -1640,7 +1638,7 @@ public void sampleByKeyExact() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1648,25 +1646,25 @@ public Tuple2 call(Integer i) { fractions.put(1, 
1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); - Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); - Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); } private static class SomeCustomClass implements Serializable { - public SomeCustomClass() { + SomeCustomClass() { // Intentionally left blank } } @Test public void collectUnderlyingScalaRDD() { - List data = new ArrayList(); + List data = new ArrayList<>(); for (int i = 0; i < 100; i++) { data.add(new SomeCustomClass()); } @@ -1678,7 +1676,7 @@ public void collectUnderlyingScalaRDD() { private static final class BuggyMapFunction implements Function { @Override - public T call(T x) throws Exception { + public T call(T x) { throw new IllegalStateException("Custom exception!"); } } @@ -1715,7 +1713,7 @@ public void foreachAsync() throws Exception { JavaFutureAction future = rdd.foreachAsync( new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) { // intentionally left blank. } } @@ -1744,7 +1742,7 @@ public void testAsyncActionCancellation() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) throws InterruptedException { Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. } }); @@ -1783,7 +1781,7 @@ public void testGuavaOptional() { // Stop the context created in setUp() and start a local-cluster one, to force usage of the // assembly. 
sc.stop(); - JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite"); + JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite"); try { JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); JavaRDD> rdd2 = rdd1.map( diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java similarity index 63% rename from launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java rename to core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 252d5abae1ca..d0c26dd05679 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -20,6 +20,7 @@ import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -35,8 +36,54 @@ public class SparkLauncherSuite { private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + @Test + public void testSparkArgumentHandling() throws Exception { + SparkLauncher launcher = new SparkLauncher() + .setSparkHome(System.getProperty("spark.test.home")); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + + launcher.addSparkArg(opts.HELP); + try { + launcher.addSparkArg(opts.PROXY_USER); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg(opts.PROXY_USER, "someUser"); + try { + launcher.addSparkArg(opts.HELP, "someValue"); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. 
+ } + + launcher.addSparkArg("--future-argument"); + launcher.addSparkArg("--future-argument", "someValue"); + + launcher.addSparkArg(opts.MASTER, "myMaster"); + assertEquals("myMaster", launcher.builder.master); + + launcher.addJar("foo"); + launcher.addSparkArg(opts.JARS, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.jars); + + launcher.addFile("foo"); + launcher.addSparkArg(opts.FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.files); + + launcher.addPyFile("foo"); + launcher.addSparkArg(opts.PY_FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles); + + launcher.setConf("spark.foo", "foo"); + launcher.addSparkArg(opts.CONF, "spark.foo=bar"); + assertEquals("bar", launcher.builder.conf.get("spark.foo")); + } + @Test public void testChildProcLauncher() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); @@ -44,9 +91,12 @@ public void testChildProcLauncher() throws Exception { .setSparkHome(System.getProperty("spark.test.home")) .setMaster("local") .setAppResource("spark-internal") + .addSparkArg(opts.CONF, + String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dfoo=bar -Dtest.name=-testChildProcLauncher") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java index db9e82759090..934b7e03050b 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -32,8 +32,8 @@ public class PackedRecordPointerSuite { public void heap() { final TaskMemoryManager memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock page0 = memoryManager.allocatePage(100); - final MemoryBlock page1 = memoryManager.allocatePage(100); + final MemoryBlock page0 = memoryManager.allocatePage(128); + final MemoryBlock page1 = memoryManager.allocatePage(128); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); @@ -50,8 +50,8 @@ public void heap() { public void offHeap() { final TaskMemoryManager memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); - final MemoryBlock page0 = memoryManager.allocatePage(100); - final MemoryBlock page1 = memoryManager.allocatePage(100); + final MemoryBlock page0 = memoryManager.allocatePage(128); + final MemoryBlock page1 = memoryManager.allocatePage(128); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java index 8fa72597db24..40fefe2c9d14 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ 
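The new `testSparkArgumentHandling` above exercises `SparkLauncher.addSparkArg`, which forwards raw `spark-submit` flags while still rejecting malformed ones (a value-taking flag without a value, or a value on a no-value flag). A hedged Scala sketch of the call pattern, spelling out the literal flags the parser constants stand for:

```scala
import org.apache.spark.launcher.SparkLauncher

val launcher = new SparkLauncher()
  .setMaster("local")
  .setAppResource("spark-internal")
  .setConf("spark.foo", "foo")
  .addSparkArg("--conf", "spark.foo=bar")  // the raw arg overrides the earlier setConf, as asserted above
// launcher.addSparkArg("--proxy-user")    // throws IllegalArgumentException: the flag needs a value
// val app = launcher.launch()
```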
b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -24,7 +24,7 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -34,11 +34,7 @@ public class UnsafeShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); return new String(strBytes); } @@ -74,14 +70,10 @@ public void testBasicSorting() throws Exception { for (String str : dataToSort) { final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); } @@ -98,7 +90,7 @@ public void testBasicSorting() throws Exception { Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, partitionId >= prevPartitionId); final long recordAddress = iter.packedRecordPointer.getRecordPointer(); - final int recordLength = PlatformDependent.UNSAFE.getInt( + final int recordLength = Platform.getInt( memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); final String str = getStringFromDataPage( memoryManager.getPage(recordAddress), diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 83d109115aa5..a266b0c36e0f 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -111,10 +111,11 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf(); + conf = new SparkConf().set("spark.buffer.pageSize", "128m"); taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -190,6 +191,7 @@ public Tuple2 answer( }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.internalMetricsToAccumulators()).thenReturn(null); when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); @@ -253,6 +255,23 @@ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { createWriter(false).stop(false); } + class 
PandaException extends RuntimeException { + } + + @Test(expected=PandaException.class) + public void writeFailurePropagates() throws Exception { + class BadRecords extends scala.collection.AbstractIterator> { + @Override public boolean hasNext() { + throw new PandaException(); + } + @Override public Product2 next() { + return null; + } + } + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(new BadRecords()); + } + @Test public void writeEmptyIterator() throws Exception { final UnsafeShuffleWriter writer = createWriter(true); @@ -456,62 +475,22 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception @Test public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { - // Use a custom serializer so that we have exact control over the size of serialized data. - final Serializer byteArraySerializer = new Serializer() { - @Override - public SerializerInstance newInstance() { - return new SerializerInstance() { - @Override - public SerializationStream serializeStream(final OutputStream s) { - return new SerializationStream() { - @Override - public void flush() { } - - @Override - public SerializationStream writeObject(T t, ClassTag ev1) { - byte[] bytes = (byte[]) t; - try { - s.write(bytes); - } catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } - - @Override - public void close() { } - }; - } - public ByteBuffer serialize(T t, ClassTag ev1) { return null; } - public DeserializationStream deserializeStream(InputStream s) { return null; } - public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; } - public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; } - }; - } - }; - when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer)); final UnsafeShuffleWriter writer = createWriter(false); - // Insert a record and force a spill so that there's something to clean up: - writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); - writer.forceSorterToSpill(); + final ArrayList> dataToWrite = new ArrayList>(); + dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(new byte[1]))); // We should be able to write a record that's right _at_ the max record size - final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; new Random(42).nextBytes(atMaxRecordSize); - writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); - writer.forceSorterToSpill(); - // Inserting a record that's larger than the max record size should fail: - final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + dataToWrite.add(new Tuple2(2, ByteBuffer.wrap(atMaxRecordSize))); + // Inserting a record that's larger than the max record size + final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; new Random(42).nextBytes(exceedsMaxRecordSize); - Product2 hugeRecord = - new Tuple2(new byte[0], exceedsMaxRecordSize); - try { - // Here, we write through the public `write()` interface instead of the test-only - // `insertRecordIntoSorter` interface: - writer.write(Collections.singletonList(hugeRecord).iterator()); - fail("Expected exception to be thrown"); - } catch (IOException e) { - // Pass - } + dataToWrite.add(new Tuple2(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + 
HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } @@ -525,4 +504,58 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { writer.stop(false); assertSpillFilesWereCleanedUp(); } + + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); + final UnsafeShuffleWriter writer = + new UnsafeShuffleWriter( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + shuffleMemoryManager, + new UnsafeShuffleHandle<>(0, 1, shuffleDep), + 0, // map id + taskContext, + conf); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = writer.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0 && i != 0) { + // The first page is allocated in constructor, another page will be allocated after + // every numRecordsPerPage records (peak memory should change). + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + writer.forceSorterToSpill(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + } + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + // Closing the writer should not change peak memory + writer.closeAndWriteOutput(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + writer.stop(false); + } + } } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java new file mode 100644 index 000000000000..ab480b60adae --- /dev/null +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -0,0 +1,573 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe.map; + +import java.lang.Exception; +import java.nio.ByteBuffer; +import java.util.*; + +import org.junit.*; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.*; +import static org.mockito.AdditionalMatchers.geq; +import static org.mockito.Mockito.*; + +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.Platform; + + +public abstract class AbstractBytesToBytesMapSuite { + + private final Random rand = new Random(42); + + private ShuffleMemoryManager shuffleMemoryManager; + private TaskMemoryManager taskMemoryManager; + private TaskMemoryManager sizeLimitedTaskMemoryManager; + private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + + @Before + public void setup() { + shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES); + taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + // Mocked memory manager for tests that check the maximum array size, since actually allocating + // such large arrays will cause us to run out of memory in our tests. + sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class); + when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer( + new Answer() { + @Override + public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { + if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { + throw new OutOfMemoryError("Requested array size exceeds VM limit"); + } + return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]); + } + } + ); + } + + @After + public void tearDown() { + Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); + if (shuffleMemoryManager != null) { + long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); + shuffleMemoryManager = null; + Assert.assertEquals(0L, leakedShuffleMemory); + } + } + + protected abstract MemoryAllocator getMemoryAllocator(); + + private static byte[] getByteArray(MemoryLocation loc, int size) { + final byte[] arr = new byte[size]; + Platform.copyMemory( + loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); + return arr; + } + + private byte[] getRandomByteArray(int numWords) { + Assert.assertTrue(numWords >= 0); + final int lengthInBytes = numWords * 8; + final byte[] bytes = new byte[lengthInBytes]; + rand.nextBytes(bytes); + return bytes; + } + + /** + * Fast equality checking for byte arrays, since these comparisons are a bottleneck + * in our stress tests. 
+ */ + private static boolean arrayEquals( + byte[] expected, + MemoryLocation actualAddr, + long actualLengthBytes) { + return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( + expected, + Platform.BYTE_ARRAY_OFFSET, + actualAddr.getBaseObject(), + actualAddr.getBaseOffset(), + expected.length + ); + } + + @Test + public void emptyMap() { + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + try { + Assert.assertEquals(0, map.numElements()); + final int keyLengthInWords = 10; + final int keyLengthInBytes = keyLengthInWords * 8; + final byte[] key = getRandomByteArray(keyLengthInWords); + Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.iterator().hasNext()); + } finally { + map.free(); + } + } + + @Test + public void setAndRetrieveAKey() { + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + final int recordLengthWords = 10; + final int recordLengthBytes = recordLengthWords * 8; + final byte[] keyData = getRandomByteArray(recordLengthWords); + final byte[] valueData = getRandomByteArray(recordLengthWords); + try { + final BytesToBytesMap.Location loc = + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); + Assert.assertFalse(loc.isDefined()); + Assert.assertTrue(loc.putNewKey( + keyData, + Platform.BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + Platform.BYTE_ARRAY_OFFSET, + recordLengthBytes + )); + // After storing the key and value, the other location methods should return results that + // reflect the result of this store without us having to call lookup() again on the same key. + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + // After calling lookup() the location should still point to the correct data. + Assert.assertTrue( + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + try { + Assert.assertTrue(loc.putNewKey( + keyData, + Platform.BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + Platform.BYTE_ARRAY_OFFSET, + recordLengthBytes + )); + Assert.fail("Should not be able to set a new value for a key"); + } catch (AssertionError e) { + // Expected exception; do nothing. 
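+ // BytesToBytesMap is append-only: once a key is present, calling putNewKey() again for the
+ // same location is expected to fail (here, with an AssertionError) rather than overwrite
+ // the stored value, which is exactly what this try/catch verifies.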
+ } + } finally { + map.free(); + } + } + + private void iteratorTestBase(boolean destructive) throws Exception { + final int size = 4096; + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES); + try { + for (long i = 0; i < size; i++) { + final long[] value = new long[] { i }; + final BytesToBytesMap.Location loc = + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8); + Assert.assertFalse(loc.isDefined()); + // Ensure that we store some zero-length keys + if (i % 5 == 0) { + Assert.assertTrue(loc.putNewKey( + null, + Platform.LONG_ARRAY_OFFSET, + 0, + value, + Platform.LONG_ARRAY_OFFSET, + 8 + )); + } else { + Assert.assertTrue(loc.putNewKey( + value, + Platform.LONG_ARRAY_OFFSET, + 8, + value, + Platform.LONG_ARRAY_OFFSET, + 8 + )); + } + } + final java.util.BitSet valuesSeen = new java.util.BitSet(size); + final Iterator iter; + if (destructive) { + iter = map.destructiveIterator(); + } else { + iter = map.iterator(); + } + int numPages = map.getNumDataPages(); + int countFreedPages = 0; + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + final long value = Platform.getLong( + valueAddress.getBaseObject(), valueAddress.getBaseOffset()); + final long keyLength = loc.getKeyLength(); + if (keyLength == 0) { + Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); + } else { + final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + Assert.assertEquals(value, key); + } + valuesSeen.set((int) value); + if (destructive) { + // The iterator moves onto next page and frees previous page + if (map.getNumDataPages() < numPages) { + numPages = map.getNumDataPages(); + countFreedPages++; + } + } + } + if (destructive) { + // Latest page is not freed by iterator but by map itself + Assert.assertEquals(countFreedPages, numPages - 1); + } + Assert.assertEquals(size, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + + @Test + public void iteratorTest() throws Exception { + iteratorTestBase(false); + } + + @Test + public void destructiveIteratorTest() throws Exception { + iteratorTestBase(true); + } + + @Test + public void iteratingOverDataPagesWithWastedSpace() throws Exception { + final int NUM_ENTRIES = 1000 * 1000; + final int KEY_LENGTH = 24; + final int VALUE_LENGTH = 40; + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); + // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte + // pages won't be evenly-divisible by records of this size, which will cause us to waste some + // space at the end of the page. This is necessary in order for us to take the end-of-record + // handling branch in iterator(). 
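+ // Rough arithmetic behind the comment above (a sketch that ignores any per-page bookkeeping
+ // overhead): a 64-megabyte page is 67,108,864 bytes and 67,108,864 % 72 = 40, so about 40
+ // bytes at the end of each full page cannot fit another 72-byte record, forcing the iterator
+ // to detect the end of the page. The 1,000,000 records total roughly 72 MB, which is why the
+ // assertion below expects the entries to span exactly two data pages.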
+ try { + for (int i = 0; i < NUM_ENTRIES; i++) { + final long[] key = new long[] { i, i, i }; // 3 * 8 = 24 bytes + final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes + final BytesToBytesMap.Location loc = map.lookup( + key, + Platform.LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + Assert.assertFalse(loc.isDefined()); + Assert.assertTrue(loc.putNewKey( + key, + Platform.LONG_ARRAY_OFFSET, + KEY_LENGTH, + value, + Platform.LONG_ARRAY_OFFSET, + VALUE_LENGTH + )); + } + Assert.assertEquals(2, map.getNumDataPages()); + + final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); + final Iterator iter = map.iterator(); + final long key[] = new long[KEY_LENGTH / 8]; + final long value[] = new long[VALUE_LENGTH / 8]; + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); + Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); + Platform.copyMemory( + loc.getKeyAddress().getBaseObject(), + loc.getKeyAddress().getBaseOffset(), + key, + Platform.LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + Platform.copyMemory( + loc.getValueAddress().getBaseObject(), + loc.getValueAddress().getBaseOffset(), + value, + Platform.LONG_ARRAY_OFFSET, + VALUE_LENGTH + ); + for (long j : key) { + Assert.assertEquals(key[0], j); + } + for (long j : value) { + Assert.assertEquals(key[0], j); + } + valuesSeen.set((int) key[0]); + } + Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + + @Test + public void randomizedStressTest() { + final int size = 65536; + // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays + // into ByteBuffers in order to use them as keys here. 
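+ // Concretely: two distinct byte[] instances with equal contents are different HashMap keys,
+ // because arrays inherit identity-based equals()/hashCode() from Object, while ByteBuffer
+ // exposes content-based equality. For example:
+ //   byte[] a = {1, 2, 3}, b = {1, 2, 3};
+ //   a.equals(b)                                    // false (identity)
+ //   ByteBuffer.wrap(a).equals(ByteBuffer.wrap(b))  // true  (contents)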
+ final Map expected = new HashMap(); + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES); + + try { + // Fill the map to 90% full so that we can trigger probing + for (int i = 0; i < size * 0.9; i++) { + final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); + final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); + if (!expected.containsKey(ByteBuffer.wrap(key))) { + expected.put(ByteBuffer.wrap(key), value); + final BytesToBytesMap.Location loc = map.lookup( + key, + Platform.BYTE_ARRAY_OFFSET, + key.length + ); + Assert.assertFalse(loc.isDefined()); + Assert.assertTrue(loc.putNewKey( + key, + Platform.BYTE_ARRAY_OFFSET, + key.length, + value, + Platform.BYTE_ARRAY_OFFSET, + value.length + )); + // After calling putNewKey, the following should be true, even before calling + // lookup(): + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(key.length, loc.getKeyLength()); + Assert.assertEquals(value.length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + } + } + + for (Map.Entry entry : expected.entrySet()) { + final byte[] key = entry.getKey().array(); + final byte[] value = entry.getValue(); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); + Assert.assertTrue(loc.isDefined()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + } + } finally { + map.free(); + } + } + + @Test + public void randomizedTestWithRecordsLargerThanPageSize() { + final long pageSizeBytes = 128; + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes); + // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays + // into ByteBuffers in order to use them as keys here. 
+ final Map expected = new HashMap(); + try { + for (int i = 0; i < 1000; i++) { + final byte[] key = getRandomByteArray(rand.nextInt(128)); + final byte[] value = getRandomByteArray(rand.nextInt(128)); + if (!expected.containsKey(ByteBuffer.wrap(key))) { + expected.put(ByteBuffer.wrap(key), value); + final BytesToBytesMap.Location loc = map.lookup( + key, + Platform.BYTE_ARRAY_OFFSET, + key.length + ); + Assert.assertFalse(loc.isDefined()); + Assert.assertTrue(loc.putNewKey( + key, + Platform.BYTE_ARRAY_OFFSET, + key.length, + value, + Platform.BYTE_ARRAY_OFFSET, + value.length + )); + // After calling putNewKey, the following should be true, even before calling + // lookup(): + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(key.length, loc.getKeyLength()); + Assert.assertEquals(value.length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + } + } + for (Map.Entry entry : expected.entrySet()) { + final byte[] key = entry.getKey().array(); + final byte[] value = entry.getValue(); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); + Assert.assertTrue(loc.isDefined()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + } + } finally { + map.free(); + } + } + + @Test + public void failureToAllocateFirstPage() { + shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024); + BytesToBytesMap map = + new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + try { + final long[] emptyArray = new long[0]; + final BytesToBytesMap.Location loc = + map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); + Assert.assertFalse(loc.isDefined()); + Assert.assertFalse(loc.putNewKey( + emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0)); + } finally { + map.free(); + } + } + + + @Test + public void failureToGrow() { + shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024); + try { + boolean success = true; + int i; + for (i = 0; i < 1024; i++) { + final long[] arr = new long[]{i}; + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + success = + loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); + if (!success) { + break; + } + } + Assert.assertThat(i, greaterThan(0)); + Assert.assertFalse(success); + } finally { + map.free(); + } + } + + @Test + public void initialCapacityBoundsChecking() { + try { + new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + try { + new BytesToBytesMap( + sizeLimitedTaskMemoryManager, + shuffleMemoryManager, + BytesToBytesMap.MAX_CAPACITY + 1, + PAGE_SIZE_BYTES); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager + // Can allocate _at_ the max capacity + // BytesToBytesMap map = new BytesToBytesMap( + // sizeLimitedTaskMemoryManager, + // shuffleMemoryManager, + // 
BytesToBytesMap.MAX_CAPACITY, + // PAGE_SIZE_BYTES); + // map.free(); + } + + // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager + @Ignore + public void resizingLargeMap() { + // As long as a map's capacity is below the max, we should be able to resize up to the max + BytesToBytesMap map = new BytesToBytesMap( + sizeLimitedTaskMemoryManager, + shuffleMemoryManager, + BytesToBytesMap.MAX_CAPACITY - 64, + PAGE_SIZE_BYTES); + map.growAndRehash(); + map.free(); + } + + @Test + public void testPeakMemoryUsed() { + final long recordLengthBytes = 24; + final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker + final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes); + + // Since BytesToBytesMap is append-only, we expect the total memory consumption to be + // monotonically increasing. More specifically, every time we allocate a new page it + // should increase by exactly the size of the page. In this regard, the memory usage + // at any given time is also the peak memory used. + long previousPeakMemory = map.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (long i = 0; i < numRecordsPerPage * 10; i++) { + final long[] value = new long[]{i}; + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( + value, + Platform.LONG_ARRAY_OFFSET, + 8, + value, + Platform.LONG_ARRAY_OFFSET, + 8); + newPeakMemory = map.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0 && i > 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Freeing the map should not change the peak memory + map.free(); + newPeakMemory = map.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + } finally { + map.free(); + } + } + + @Test + public void testAcquirePageInConstructor() { + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + assertEquals(1, map.getNumDataPages()); + map.free(); + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java rename to core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java rename to core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java new file mode 100644 index 000000000000..445a37b83e98 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.UUID; + +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeExternalSorterSuite { + + final LinkedList spillFilesCreated = new LinkedList(); + final TaskMemoryManager taskMemoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + + SparkConf sparkConf; + File tempDir; + ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + + + private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m"); + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + + @Before + public void setUp() { + 
MockitoAnnotations.initMocks(this); + sparkConf = new SparkConf(); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); + shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes); + spillFilesCreated.clear(); + taskContext = mock(TaskContext.class); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); + } + + @After + public void tearDown() { + try { + long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (shuffleMemoryManager != null) { + long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); + shuffleMemoryManager = null; + assertEquals(0L, leakedShuffleMemory); + } + assertEquals(0, leakedUnsafeMemory); + } finally { + Utils.deleteRecursively(tempDir); + tempDir = null; + } + } + + private void assertSpillFilesWereCleanedUp() { + for (File spillFile : spillFilesCreated) { + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + + private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { + final int[] arr = new int[]{ value }; + sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); + } + + private static void insertRecord( + UnsafeExternalSorter sorter, + int[] record, + long prefix) throws IOException { + sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); + } + + private UnsafeExternalSorter newSorter() throws IOException { + return UnsafeExternalSorter.create( + taskMemoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + /* initialSize */ 1024, + pageSizeBytes); + } + + @Test + public void testSortingOnlyByPrefix() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + insertNumber(sorter, 5); + insertNumber(sorter, 1); + insertNumber(sorter, 3); + sorter.spill(); + insertNumber(sorter, 4); + sorter.spill(); + insertNumber(sorter, 2); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(i, iter.getKeyPrefix()); + assertEquals(4, iter.getRecordLength()); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + } + + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void testSortingEmptyArrays() throws Exception { + final UnsafeExternalSorter 
sorter = newSorter(); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(0, iter.getKeyPrefix()); + assertEquals(0, iter.getRecordLength()); + } + + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void spillingOccursInResponseToMemoryPressure() throws Exception { + shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes); + final UnsafeExternalSorter sorter = newSorter(); + final int numRecords = (int) pageSizeBytes / 4; + for (int i = 0; i <= numRecords; i++) { + insertNumber(sorter, numRecords - i); + } + // Ensure that spill files were created + assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1)); + // Read back the sorted data: + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + int i = 0; + while (iter.hasNext()) { + iter.loadNext(); + assertEquals(i, iter.getKeyPrefix()); + assertEquals(4, iter.getRecordLength()); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + i++; + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void testFillingPage() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + byte[] record = new byte[16]; + while (sorter.getNumberOfAllocatedPages() < 2) { + sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0); + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void sortingRecordsThatExceedPageSize() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + final int[] largeRecord = new int[(int) pageSizeBytes + 16]; + Arrays.fill(largeRecord, 456); + final int[] smallRecord = new int[100]; + Arrays.fill(smallRecord, 123); + + insertRecord(sorter, largeRecord, 456); + sorter.spill(); + insertRecord(sorter, smallRecord, 123); + sorter.spill(); + insertRecord(sorter, smallRecord, 123); + insertRecord(sorter, largeRecord, 456); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + // Small record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(123, iter.getKeyPrefix()); + assertEquals(smallRecord.length * 4, iter.getRecordLength()); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + // Small record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(123, iter.getKeyPrefix()); + assertEquals(smallRecord.length * 4, iter.getRecordLength()); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + // Large record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(456, iter.getKeyPrefix()); + assertEquals(largeRecord.length * 4, iter.getRecordLength()); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + // Large record + assertTrue(iter.hasNext()); + iter.loadNext(); + assertEquals(456, iter.getKeyPrefix()); + assertEquals(largeRecord.length * 4, iter.getRecordLength()); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); + + assertFalse(iter.hasNext()); + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final 
long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + pageSizeBytes); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = sorter.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + insertNumber(sorter, i); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + // The first page is pre-allocated on instantiation + if (i % numRecordsPerPage == 0 && i > 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + sorter.spill(); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + insertNumber(sorter, i); + } + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + } + + @Test + public void testReservePageOnInstantiation() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + try { + assertEquals(1, sorter.getNumberOfAllocatedPages()); + // Inserting a new record doesn't allocate more memory since we already have a page + long peakMemory = sorter.getPeakMemoryUsedBytes(); + insertNumber(sorter, 100); + assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes()); + assertEquals(1, sorter.getNumberOfAllocatedPages()); + } finally { + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + } + +} + diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java new file mode 100644 index 000000000000..778e813df6b5 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Arrays; + +import org.junit.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { + final byte[] strBytes = new byte[length]; + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + mock(RecordComparator.class), + mock(PrefixComparator.class), + 100); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testSortingOnlyByIntegerPrefix() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + Platform.putInt(baseObject, position, strBytes.length); + position += 4; + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + prefixComparator, dataToSort.length); + // Given a page of records, insert those records into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). 
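+ // Per the write loop above, each record in the page is laid out as
+ //   [4-byte length][`length` bytes of UTF-8 data]
+ // so the read side fetches the int length at `position`, decodes the string from the next
+ // `length` bytes, and advances `position` by 4 + length before inserting the packed
+ // (page, offset) address with the partition id as its sort prefix.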
+ final int recordLength = Platform.getInt(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + final String str = getStringFromDataPage(baseObject, position + 4, recordLength); + final int partitionId = hashPartitioner.getPartition(str); + sorter.insertRecord(address, partitionId); + position += 4 + recordLength; + } + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + iter.loadNext(); + final String str = + getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength()); + final long keyPrefix = iter.getKeyPrefix(); + assertThat(str, isIn(Arrays.asList(dataToSort))); + assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); + prevPrefix = keyPrefix; + iterLength++; + } + assertEquals(dataToSort.length, iterLength); + } +} diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index e942d6579b2f..5b84acf40be4 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference import org.scalatest.Matchers +import org.scalatest.exceptions.TestFailedException +import org.apache.spark.scheduler._ -class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { + import InternalAccumulator._ implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -155,4 +159,223 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } + test("internal accumulators in TaskContext") { + sc = new SparkContext("local", "test") + val accums = InternalAccumulator.create(sc) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) + val internalMetricsToAccums = taskContext.internalMetricsToAccumulators + val collectedInternalAccums = taskContext.collectInternalAccumulators() + val collectedAccums = taskContext.collectAccumulators() + assert(internalMetricsToAccums.size > 0) + assert(internalMetricsToAccums.values.forall(_.isInternal)) + assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) + val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) + assert(collectedInternalAccums.size === internalMetricsToAccums.size) + assert(collectedInternalAccums.size === collectedAccums.size) + assert(collectedInternalAccums.contains(testAccum.id)) + assert(collectedAccums.contains(testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // 
The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + val rdd = sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + rdd.count() + } + + test("internal accumulators in fully resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + } + + test("internal accumulators in partially resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findAccumulableInfo( + accums: Iterable[AccumulableInfo], + name: String): AccumulableInfo = { + accums.find { a => a.name == name }.getOrElse { + throw new TestFailedException(s"internal accumulator '$name' not found", 0) + } + } + + /** + * Test whether internal accumulators are merged properly if some tasks fail. 
+ */ + private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { + val listener = new SaveInfoListener + val numPartitions = 10 + val numFailedPartitions = (0 until numPartitions).count(failCondition) + // This says use 1 core and retry tasks up to 2 times + sc = new SparkContext("local[1, 2]", "test") + sc.addSparkListener(listener) + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + // Fail the first attempts of a subset of the tasks + if (failCondition(i) && taskContext.attemptNumber() == 0) { + throw new Exception("Failing a task intentionally.") + } + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } + } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + } + +} + +private[spark] object AccumulatorSuite { + + /** + * Run one or more Spark jobs and verify that the peak execution memory accumulator + * is updated afterwards. + */ + def verifyPeakExecutionMemorySet( + sc: SparkContext, + testName: String)(testBody: => Unit): Unit = { + val listener = new SaveInfoListener + sc.addSparkListener(listener) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { jobId => + if (jobId == 0) { + // The first job is a dummy one to verify that the accumulator does not already exist + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + } else { + // In the subsequent jobs, verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } + } + // Run the jobs + sc.parallelize(1 to 10).count() + testBody + } +} + +/** + * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. 
+ */ +private class SaveInfoListener extends SparkListener { + private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] + private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + + def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + + /** Register a callback to be called on job end. */ + def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { + jobCompletionCallback = callback + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (jobCompletionCallback != null) { + jobCompletionCallback(jobEnd.jobId) + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + completedStageInfos += stageCompleted.stageInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + completedTaskInfos += taskEnd.taskInfo + } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index af81e46a657d..cb8bd04e496a 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.{DataReadMethod, TaskMetrics} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = TaskContext.empty() val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = TaskContext.empty() val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
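// The runningLocally = true flag passed to TaskContextImpl below is what routes
// CacheManager.getOrCompute onto the local-computation path that returns the computed
// values without persisting them, so the mocked block manager should never see a put().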
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = TaskContext.empty() cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d1761a48babb..4d70bfed909b 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -25,11 +25,15 @@ import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +/** + * Test suite for end-to-end checkpointing functionality. + * This tests both reliable checkpoints and local checkpoints. + */ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { - var checkpointDir: File = _ - val partitioner = new HashPartitioner(2) + private var checkpointDir: File = _ + private val partitioner = new HashPartitioner(2) - override def beforeEach() { + override def beforeEach(): Unit = { super.beforeEach() checkpointDir = File.createTempFile("temp", "", Utils.createTempDir()) checkpointDir.delete() @@ -37,40 +41,43 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging sc.setCheckpointDir(checkpointDir.toString) } - override def afterEach() { + override def afterEach(): Unit = { super.afterEach() Utils.deleteRecursively(checkpointDir) } - test("basic checkpointing") { + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) - flatMappedRDD.checkpoint() - assert(flatMappedRDD.dependencies.head.rdd == parCollection) + checkpoint(flatMappedRDD, reliableCheckpoint) + assert(flatMappedRDD.dependencies.head.rdd === parCollection) val result = flatMappedRDD.collect() assert(flatMappedRDD.dependencies.head.rdd != parCollection) assert(flatMappedRDD.collect() === result) } - test("RDDs with one-to-one dependencies") { - testRDD(_.map(x => x.toString)) - testRDD(_.flatMap(x => 1 to x)) - testRDD(_.filter(_ % 2 == 0)) - testRDD(_.sample(false, 0.5, 0)) - testRDD(_.glom()) - testRDD(_.mapPartitions(_.map(_.toString))) - testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testRDD(_.pipe(Seq("cat"))) + runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => + testRDD(_.map(x => x.toString), reliableCheckpoint) + testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) + testRDD(_.filter(_ % 2 == 0), reliableCheckpoint) + testRDD(_.sample(false, 0.5, 0), reliableCheckpoint) + testRDD(_.glom(), reliableCheckpoint) + testRDD(_.mapPartitions(_.map(_.toString)), reliableCheckpoint) + testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), reliableCheckpoint) + testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), + reliableCheckpoint) + testRDD(_.pipe(Seq("cat")), 
reliableCheckpoint) } - test("ParallelCollection") { + runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4, 2) val numPartitions = parCollection.partitions.size - parCollection.checkpoint() + checkpoint(parCollection, reliableCheckpoint) assert(parCollection.dependencies === Nil) val result = parCollection.collect() - assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) + if (reliableCheckpoint) { + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) + } assert(parCollection.dependencies != Nil) assert(parCollection.partitions.length === numPartitions) assert(parCollection.partitions.toList === @@ -78,44 +85,46 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(parCollection.collect() === result) } - test("BlockRDD") { + runTest("BlockRDD") { reliableCheckpoint: Boolean => val blockId = TestBlockId("id") val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) val numPartitions = blockRDD.partitions.size - blockRDD.checkpoint() + checkpoint(blockRDD, reliableCheckpoint) val result = blockRDD.collect() - assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) + if (reliableCheckpoint) { + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) + } assert(blockRDD.dependencies != Nil) assert(blockRDD.partitions.length === numPartitions) assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList) assert(blockRDD.collect() === result) } - test("ShuffledRDD") { + runTest("ShuffleRDD") { reliableCheckpoint: Boolean => testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner) - }) + }, reliableCheckpoint) } - test("UnionRDD") { + runTest("UnionRDD") { reliableCheckpoint: Boolean => def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) - testRDD(_.union(otherRDD)) - testRDDPartitions(_.union(otherRDD)) + testRDD(_.union(otherRDD), reliableCheckpoint) + testRDDPartitions(_.union(otherRDD), reliableCheckpoint) } - test("CartesianRDD") { + runTest("CartesianRDD") { reliableCheckpoint: Boolean => def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) - testRDD(new CartesianRDD(sc, _, otherRDD)) - testRDDPartitions(new CartesianRDD(sc, _, otherRDD)) + testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint) + testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint) // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after // the parent RDD has been checkpointed and parent partitions have been changed. // Note that this test is very specific to the current implementation of CartesianRDD. 
val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) @@ -129,16 +138,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("CoalescedRDD") { - testRDD(_.coalesce(2)) - testRDDPartitions(_.coalesce(2)) + runTest("CoalescedRDD") { reliableCheckpoint: Boolean => + testRDD(_.coalesce(2), reliableCheckpoint) + testRDDPartitions(_.coalesce(2), reliableCheckpoint) // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) // after the parent RDD has been checkpointed and parent partitions have been changed. // Note that this test is very specific to the current implementation of // CoalescedRDDPartitions. val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) @@ -151,7 +160,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("CoGroupedRDD") { + runTest("CoGroupedRDD") { reliableCheckpoint: Boolean => val longLineageRDD1 = generateFatPairRDD() // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD @@ -160,26 +169,26 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging testRDD(rdd => { CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) - }, seqCollectFunc) + }, reliableCheckpoint, seqCollectFunc) val longLineageRDD2 = generateFatPairRDD() testRDDPartitions(rdd => { CheckpointSuite.cogroup( longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) - }, seqCollectFunc) + }, reliableCheckpoint, seqCollectFunc) } - test("ZippedPartitionsRDD") { - testRDD(rdd => rdd.zip(rdd.map(x => x))) - testRDDPartitions(rdd => rdd.zip(rdd.map(x => x))) + runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean => + testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint) + testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint) // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have // been checkpointed and parent partitions have been changed. // Note that this test is very specific to the implementation of ZippedPartitionsRDD. 
val rdd = generateFatRDD() val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]] - zippedRDD.rdd1.checkpoint() - zippedRDD.rdd2.checkpoint() + checkpoint(zippedRDD.rdd1, reliableCheckpoint) + checkpoint(zippedRDD.rdd2, reliableCheckpoint) val partitionBeforeCheckpoint = serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition]) zippedRDD.count() @@ -194,27 +203,27 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("PartitionerAwareUnionRDD") { + runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean => testRDD(rdd => { new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( generateFatPairRDD(), rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) )) - }) + }, reliableCheckpoint) testRDDPartitions(rdd => { new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( generateFatPairRDD(), rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) )) - }) + }, reliableCheckpoint) // Test that the PartitionerAwareUnionRDD updates parent partitions // (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent // partitions have been changed. Note that this test is very specific to the current // implementation of PartitionerAwareUnionRDD. val pairRDD = generateFatPairRDD() - pairRDD.checkpoint() + checkpoint(pairRDD, reliableCheckpoint) val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD)) val partitionBeforeCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) @@ -228,17 +237,34 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging ) } - test("CheckpointRDD with zero partitions") { + runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean => val rdd = new BlockRDD[Int](sc, Array[BlockId]()) assert(rdd.partitions.size === 0) assert(rdd.isCheckpointed === false) - rdd.checkpoint() + checkpoint(rdd, reliableCheckpoint) assert(rdd.count() === 0) assert(rdd.isCheckpointed === true) assert(rdd.partitions.size === 0) } - def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + // Utility test methods + + /** Checkpoint the RDD either locally or reliably. */ + private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + private def runTest(name: String)(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + test(name + " [local checkpoint]")(body(false)) + } + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() /** * Test checkpointing of the RDD generated by the given operation. It tests whether the @@ -246,11 +272,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). 
* * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints * @param collectFunc a function for collecting the values in the RDD, in case there are * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U], - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { + private def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -267,14 +296,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging // Find serialized sizes before and after the checkpoint logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - operatedRDD.checkpoint() + checkpoint(operatedRDD, reliableCheckpoint) val result = collectFunc(operatedRDD) operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the checkpoint file has been created - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + if (reliableCheckpoint) { + assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + } // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) @@ -310,11 +341,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * partitions (i.e., do not call it on simple RDD like MappedRDD). 
* * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints * @param collectFunc a function for collecting the values in the RDD, in case there are * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U], - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { + private def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -328,7 +362,10 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging // Find serialized sizes before and after the checkpoint logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } val result = collectFunc(operatedRDD) // force checkpointing operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) @@ -350,7 +387,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging /** * Generate an RDD such that both the RDD and its partitions have large size. */ - def generateFatRDD(): RDD[Int] = { + private def generateFatRDD(): RDD[Int] = { new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) } @@ -358,7 +395,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * Generate an pair RDD (with partitioner) such that both the RDD and its partitions * have large size. */ - def generateFatPairRDD(): RDD[(Int, Int)] = { + private def generateFatPairRDD(): RDD[(Int, Int)] = { new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) } @@ -366,7 +403,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. */ - def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { val rddSize = Utils.serialize(rdd).size val rddCpDataSize = Utils.serialize(rdd.checkpointData).size val rddPartitionSize = Utils.serialize(rdd.partitions).size @@ -394,7 +431,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * contents after deserialization (e.g., the contents of an RDD split after * it is sent to a slave along with a task) */ - def serializeDeserialize[T](obj: T): T = { + private def serializeDeserialize[T](obj: T): T = { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } @@ -402,10 +439,11 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging /** * Recursively force the initialization of the all members of an RDD and it parents. 
*/ - def initializeRdd(rdd: RDD[_]) { + private def initializeRdd(rdd: RDD[_]): Unit = { rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd(_)) + rdd.dependencies.map(_.rdd).foreach(initializeRdd) } + } /** RDD partition that has large serialized size. */ @@ -445,7 +483,7 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + def cogroup[K: ClassTag, V: ClassTag](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 501fe186bfd7..0c14bef7befd 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -24,12 +24,11 @@ import scala.language.existentials import scala.util.Random import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.{PatienceConfiguration, Eventually} +import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDDCheckpointData, RDD} +import org.apache.spark.rdd.{ReliableRDDCheckpointData, RDD} import org.apache.spark.storage._ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager @@ -52,6 +51,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") + .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") .set("spark.shuffle.manager", shuffleManager.getName) before { @@ -209,11 +209,11 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { postGCTester.assertCleanup() } - test("automatically cleanup checkpoint") { + test("automatically cleanup normal checkpoint") { val checkpointDir = java.io.File.createTempFile("temp", "") checkpointDir.deleteOnExit() checkpointDir.delete() - var rdd = newPairRDD + var rdd = newPairRDD() sc.setCheckpointDir(checkpointDir.toString) rdd.checkpoint() rdd.cache() @@ -221,23 +221,26 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { var rddId = rdd.id // Confirm the checkpoint directory exists - assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined) - val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get + assert(ReliableRDDCheckpointData.checkpointPath(sc, rddId).isDefined) + val path = ReliableRDDCheckpointData.checkpointPath(sc, rddId).get val fs = path.getFileSystem(sc.hadoopConfiguration) assert(fs.exists(path)) // the checkpoint is not cleaned by default (without the configuration set) - var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) rdd = null // Make RDD out of scope, ok if collected earlier runGC() postGCTester.assertCleanup() - assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + 
assert(!fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + // Verify that checkpoints are NOT cleaned up if the config is not enabled sc.stop() - val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint"). - set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("cleanupCheckpoint") + .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false") sc = new SparkContext(conf) - rdd = newPairRDD + rdd = newPairRDD() sc.setCheckpointDir(checkpointDir.toString) rdd.checkpoint() rdd.cache() @@ -245,17 +248,40 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { rddId = rdd.id // Confirm the checkpoint directory exists - assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) // Reference rdd to defeat any early collection by the JVM rdd.count() // Test that GC causes checkpoint data cleanup after dereferencing the RDD - postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) + postGCTester = new CleanerTester(sc, Seq(rddId)) rdd = null // Make RDD out of scope runGC() postGCTester.assertCleanup() - assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + } + + test("automatically clean up local checkpoint") { + // Note that this test is similar to the RDD cleanup + // test because the same underlying mechanism is used! + var rdd = newPairRDD().localCheckpoint() + assert(rdd.checkpointData.isDefined) + assert(rdd.checkpointData.get.checkpointRDD.isEmpty) + rdd.count() + assert(rdd.checkpointData.get.checkpointRDD.isDefined) + + // Test that GC does not cause checkpoint cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that RDD going out of scope does cause the checkpoint blocks to be cleaned up + val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + rdd = null + runGC() + postGCTester.assertCleanup() } test("automatically cleanup RDD + shuffle + broadcast") { @@ -292,7 +318,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") @@ -370,7 +396,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") @@ -408,7 +434,10 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor } -/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */ +/** + * Class to test whether RDDs, shuffles, etc. have been successfully cleaned. + * The checkpoint here refers only to normal (reliable) checkpoints, not local checkpoints. 
+ */ class CleanerTester( sc: SparkContext, rddIds: Seq[Int] = Seq.empty, diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 9c191ed52206..600c1403b034 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -29,7 +29,7 @@ class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() { class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { - val clusterUrl = "local-cluster[2,1,512]" + val clusterUrl = "local-cluster[2,1,1024]" test("task throws not serializable exception") { // Ensures that executors do not crash when an exn is not serializable. If executors crash, @@ -40,7 +40,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val numSlaves = 3 val numPartitions = 10 - sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + sc = new SparkContext("local-cluster[%s,1,1024]".format(numSlaves), "test") val data = sc.parallelize(1 to 100, numPartitions). map(x => throw new NotSerializableExn(new NotSerializableClass)) intercept[SparkException] { @@ -50,16 +50,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("local-cluster format") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") + sc = new SparkContext("local-cluster[2 , 1 , 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2, 1, 512]", "test") + sc = new SparkContext("local-cluster[2, 1, 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") + sc = new SparkContext("local-cluster[ 2, 1, 1024 ]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() } @@ -107,7 +107,9 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext(clusterUrl, "test") val accum = sc.accumulator(0) val thrown = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) + // scalastyle:on println } assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("failed 4 times")) @@ -274,7 +276,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex DistributedSuite.amMaster = true // Using more than two nodes so we don't have a symmetric communication pattern and might // cache a partially correct list of peers. 
- sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) @@ -292,7 +294,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex test("unpersist RDDs") { DistributedSuite.amMaster = true - sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) data.count diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index b2262033ca23..454b7e607a51 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -29,7 +29,7 @@ class DriverSuite extends SparkFunSuite with Timeouts { ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val masters = Table("master", "local", "local-cluster[2,1,512]") + val masters = Table("master", "local", "local-cluster[2,1,1024]") forAll(masters) { (master: String) => val process = Utils.executeCommand( Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 803e1831bb26..116f027a0f98 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -206,8 +206,8 @@ class ExecutorAllocationManagerSuite val task2Info = createTaskInfo(1, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -751,6 +751,60 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 2) } + test("get pending task number and related locality preference") { + sc = createSparkContext(2, 5, 3) + val manager = sc.executorAllocationManager.get + + val localityPreferences1 = Seq( + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host3")), + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host4")), + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host4")), + Seq.empty, + Seq.empty + ) + val stageInfo1 = createStageInfo(1, 5, localityPreferences1) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + + assert(localityAwareTasks(manager) === 3) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 3, "host3" -> 2, "host4" -> 2)) + + val localityPreferences2 = Seq( + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host5")), + Seq(TaskLocation("host3"), TaskLocation("host4"), TaskLocation("host5")), + Seq.empty + ) + val stageInfo2 = createStageInfo(2, 3, localityPreferences2) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2)) + + 
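+    // With both stages pending, the locality-aware task count is 3 (stage 1) + 2 (stage 2) = 5,
+    // and each host's count is the number of pending tasks listing it as a preferred location,
+    // e.g. host2: 3 from stage 1 + 1 from stage 2 = 4, and host5: 0 + 2 = 2.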
assert(localityAwareTasks(manager) === 5) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) + + sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1)) + assert(localityAwareTasks(manager) === 2) + assert(hostToLocalTaskCount(manager) === + Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) + } + + test("SPARK-8366: maxNumExecutorsNeeded should properly handle failed tasks") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(maxNumExecutorsNeeded(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + assert(maxNumExecutorsNeeded(manager) === 1) + + val taskInfo = createTaskInfo(1, 1, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + assert(maxNumExecutorsNeeded(manager) === 1) + + // If the task is failed, we expect it to be resubmitted later. + val taskEndReason = ExceptionFailure(null, null, null, null, null, None) + sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -784,8 +838,13 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val sustainedSchedulerBacklogTimeout = 2L private val executorIdleTimeout = 3L - private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { - new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details") + private def createStageInfo( + stageId: Int, + numTasks: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { + new StageInfo( + stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { @@ -815,6 +874,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _onSchedulerQueueEmpty = PrivateMethod[Unit]('onSchedulerQueueEmpty) private val _onExecutorIdle = PrivateMethod[Unit]('onExecutorIdle) private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy) + private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) + private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate _numExecutorsToAdd() @@ -885,4 +946,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private def onExecutorBusy(manager: ExecutorAllocationManager, id: String): Unit = { manager invokePrivate _onExecutorBusy(id) } + + private def localityAwareTasks(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _localityAwareTasks() + } + + private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = { + manager invokePrivate _hostToLocalTaskCount() + } } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 140012226fdb..e846a72c888c 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -36,7 +36,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { val transportConf = 
SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) - rpcHandler = new ExternalShuffleBlockHandler(transportConf) + rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() @@ -51,7 +51,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // This test ensures that the external shuffle service is actually in use for the other tests. test("using external shuffle service") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index a8c8c6f73fb5..f58756e6f617 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.apache.spark.util.NonSerializable -import java.io.NotSerializableException +import java.io.{IOException, NotSerializableException, ObjectInputStream} // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully @@ -130,7 +130,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // Non-serializable closure in foreach function val thrown2 = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 2).foreach(x => println(a)) + // scalastyle:on println } assert(thrown2.getClass === classOf[SparkException]) assert(thrown2.getMessage.contains("NotSerializableException") || @@ -139,5 +141,115 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + test("managed memory leak error should not mask other failures (SPARK-9266") { + val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true") + sc = new SparkContext("local[1,1]", "test", conf) + + // If a task leaks memory but fails due to some other cause, then make sure that the original + // cause is preserved + val thrownDueToTaskFailure = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + throw new Exception("intentional task failure") + iter + }.count() + } + assert(thrownDueToTaskFailure.getMessage.contains("intentional task failure")) + + // If the task succeeded but memory was leaked, then the task should fail due to that leak + val thrownDueToMemoryLeak = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + iter + }.count() + } + assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) + } + + // Run a 3-task map job in which task 1 always fails with a exception message that + // depends on the failure number, and check that we get the last failure. 
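+  // (With master local[1,2] the always-failing task is attempted twice before the job is aborted,
+  // so two successful tasks plus two attempts of the failing one give the tasksRun === 4 below.)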
+ test("last failure cause is sent back to driver") { + sc = new SparkContext("local[1,2]", "test") + val data = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 3) { + FailureSuiteState.tasksFailed += 1 + throw new UserException("oops", + new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed)) + } + } + x * x + } + val thrown = intercept[SparkException] { + data.collect() + } + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause.getClass === classOf[UserException]) + assert(thrown.getCause.getMessage === "oops") + assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException]) + assert(thrown.getCause.getCause.getMessage === "failed=2") + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not serializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonSerializableUserException")) + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not deserializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonDeserializableUserException")) + FailureSuiteState.clear() + } + + // Run a 3-task map stage where one task fails once. + test("failure in tasks in a submitMapStage") { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + sc.submitMapStage(dep).get() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. 
} + +class UserException(message: String, cause: Throwable) + extends RuntimeException(message, cause) + +class NonSerializableUserException extends RuntimeException { + val nonSerializableInstanceVariable = new NonSerializable +} + +class NonDeserializableUserException extends RuntimeException { + private def readObject(in: ObjectInputStream): Unit = { + throw new IOException("Intentional exception during deserialization.") + } +} diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 6e65b0a8f6c7..1255e71af6c0 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -51,7 +51,9 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { val textFile = new File(testTempDir, "FileServerSuite.txt") val pw = new PrintWriter(textFile) + // scalastyle:off println pw.println("100") + // scalastyle:on println pw.close() val jarFile = new File(testTempDir, "test.jar") @@ -137,7 +139,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { @@ -151,7 +153,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => @@ -162,7 +164,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 1d8fade90f39..fdb00aafc4a4 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel @@ -179,6 +180,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test("object files of classes from a JAR") { + // scalastyle:off classforname val original = Thread.currentThread().getContextClassLoader val className = "FileSuiteObjectFileTest" val jar = TestUtils.createJarWithClasses(Seq(className)) @@ -201,6 +203,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { finally { Thread.currentThread().setContextClassLoader(original) } + // scalastyle:on classforname } test("write SequenceFile using new Hadoop API") { @@ -504,8 +507,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, 
String]]) - job.getConfiguration.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") - randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 911b3bddd183..139b8dc25f4b 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -17,64 +17,255 @@ package org.apache.spark -import scala.concurrent.duration._ +import java.util.concurrent.{ExecutorService, TimeUnit} + +import scala.collection.mutable import scala.language.postfixOps -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.storage.BlockManagerId +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.RpcUtils -import org.scalatest.concurrent.Eventually._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ManualClock + +/** + * A test suite for the heartbeating behavior between the driver and the executors. + */ +class HeartbeatReceiverSuite + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester + with LocalSparkContext { -class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext { + private val executorId1 = "executor-1" + private val executorId2 = "executor-2" - test("HeartbeatReceiver") { - sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + // Shared state that must be reset before and after each test + private var scheduler: TaskSchedulerImpl = null + private var heartbeatReceiver: HeartbeatReceiver = null + private var heartbeatReceiverRef: RpcEndpointRef = null + private var heartbeatReceiverClock: ManualClock = null + + // Helper private method accessors for HeartbeatReceiver + private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) + private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) + private val _killExecutorThread = PrivateMethod[ExecutorService]('killExecutorThread) + + /** + * Before each test, set up the SparkContext and a custom [[HeartbeatReceiver]] + * that uses a manual clock. 
+ */ + override def beforeEach(): Unit = { + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.dynamicAllocation.testing", "true") + sc = spy(new SparkContext(conf)) + scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.sc).thenReturn(sc) + heartbeatReceiverClock = new ManualClock + heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) + heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + /** + * After each test, clean up all state and stop the [[SparkContext]]. + */ + override def afterEach(): Unit = { + super.afterEach() + scheduler = null + heartbeatReceiver = null + heartbeatReceiverRef = null + heartbeatReceiverClock = null + } - val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + test("task scheduler is set correctly") { + assert(heartbeatReceiver.scheduler === null) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + assert(heartbeatReceiver.scheduler !== null) + } - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) - assert(false === response.reregisterBlockManager) + test("normal heartbeat") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + assert(trackedExecutors.size === 2) + assert(trackedExecutors.contains(executorId1)) + assert(trackedExecutors.contains(executorId2)) } - test("HeartbeatReceiver re-register") { - sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) - when(sc.taskScheduler).thenReturn(scheduler) + test("reregister if scheduler is not ready yet") { + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + // Task scheduler is not set yet in HeartbeatReceiver, so executors should reregister + triggerHeartbeat(executorId1, executorShouldReregister = true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + test("reregister if heartbeat from unregistered executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + // Received heartbeat from unknown executor, so we ask it to re-register + 
triggerHeartbeat(executorId1, executorShouldReregister = true) + assert(heartbeatReceiver.invokePrivate(_executorLastSeen()).isEmpty) + } + + test("reregister if heartbeat from removed executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + // Remove the second executor but not the first + heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy")) + // Now trigger the heartbeats + // A heartbeat from the second executor should require reregistering + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = true) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + + test("expire dead hosts") { + val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + // Advance the clock and only trigger a heartbeat for the first executor + heartbeatReceiverClock.advance(executorTimeout / 2) + triggerHeartbeat(executorId1, executorShouldReregister = false) + heartbeatReceiverClock.advance(executorTimeout) + heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + // Only the second executor should be expired as a dead host + verify(scheduler).executorLost(Matchers.eq(executorId2), any()) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + test("expire dead hosts should kill executors with replacement (SPARK-8119)") { + // Set up a fake backend and cluster manager to simulate killing executors + val rpcEnv = sc.env.rpcEnv + val fakeClusterManager = new FakeClusterManager(rpcEnv) + val fakeClusterManagerRef = rpcEnv.setupEndpoint("fake-cm", fakeClusterManager) + val fakeSchedulerBackend = new FakeSchedulerBackend(scheduler, rpcEnv, fakeClusterManagerRef) + when(sc.schedulerBackend).thenReturn(fakeSchedulerBackend) + + // Register fake executors with our fake scheduler backend + // This is necessary because the backend refuses to kill executors it does not know about + fakeSchedulerBackend.start() + val dummyExecutorEndpoint1 = new FakeExecutorEndpoint(rpcEnv) + val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) + val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) + val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty)) + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) + 
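+    // Register the same executors with the HeartbeatReceiver and send an initial heartbeat for each,
+    // so both are tracked in executorLastSeen and can be expired once the manual clock is advanced.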
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + + // Adjust the target number of executors on the cluster manager side + assert(fakeClusterManager.getTargetNumExecutors === 0) + sc.requestTotalExecutors(2, 0, Map.empty) + assert(fakeClusterManager.getTargetNumExecutors === 2) + assert(fakeClusterManager.getExecutorIdsToKill.isEmpty) + + // Expire the executors. This should trigger our fake backend to kill the executors. + // Since the kill request is sent to the cluster manager asynchronously, we need to block + // on the kill thread to ensure that the cluster manager actually received our requests. + // Here we use a timeout of O(seconds), but in practice this whole test takes O(10ms). + val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) + heartbeatReceiverClock.advance(executorTimeout * 2) + heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + val killThread = heartbeatReceiver.invokePrivate(_killExecutorThread()) + killThread.shutdown() // needed for awaitTermination + killThread.awaitTermination(10L, TimeUnit.SECONDS) + + // The target number of executors should not change! Otherwise, having an expired + // executor means we permanently adjust the target number downwards until we + // explicitly request new executors. For more detail, see SPARK-8119. + assert(fakeClusterManager.getTargetNumExecutors === 2) + assert(fakeClusterManager.getExecutorIdsToKill === Set(executorId1, executorId2)) + } + + /** Manually send a heartbeat and return the response. */ + private def triggerHeartbeat( + executorId: String, + executorShouldReregister: Boolean): Unit = { val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + val blockManagerId = BlockManagerId(executorId, "localhost", 12345) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + Heartbeat(executorId, Array(1L -> metrics), blockManagerId)) + if (executorShouldReregister) { + assert(response.reregisterBlockManager) + } else { + assert(!response.reregisterBlockManager) + // Additionally verify that the scheduler callback is called with the correct parameters + verify(scheduler).executorHeartbeatReceived( + Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + } + } + +} + +// TODO: use these classes to add end-to-end tests for dynamic allocation! + +/** + * Dummy RPC endpoint to simulate executors. + */ +private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint + +/** + * Dummy scheduler backend to simulate executor allocation requests to the cluster manager. 
+ */ +private class FakeSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + clusterManagerEndpoint: RpcEndpointRef) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + clusterManagerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) + } + + protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { + clusterManagerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) + } +} + +/** + * Dummy cluster manager to simulate responses to executor allocation requests. + */ +private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoint { + private var targetNumExecutors = 0 + private val executorIdsToKill = new mutable.HashSet[String] + + def getTargetNumExecutors: Int = targetNumExecutors + def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) - assert(true === response.reregisterBlockManager) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestExecutors(requestedTotal, _, _) => + targetNumExecutors = requestedTotal + context.reply(true) + case KillExecutors(executorIds) => + executorIdsToKill ++= executorIds + context.reply(true) } } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 340a9e327107..1168eb0b802f 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -64,7 +64,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft test("cluster mode, FIFO scheduler") { val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -75,7 +75,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() conf.set("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. 
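// The recurring "512" -> "1024" edits above only raise the executor memory in the local-cluster
// master URL, whose three fields are the number of workers, cores per worker, and memory per
// worker in MB. A minimal, self-contained sketch of a driver using such a master is below; the
// object and app names are illustrative, and local-cluster mode assumes a built Spark distribution.
import org.apache.spark.{SparkConf, SparkContext}

object LocalClusterSketch {
  def main(args: Array[String]): Unit = {
    // Two workers, one core each, 1024 MB per worker -- the same shape the updated suites use.
    val conf = new SparkConf().setMaster("local-cluster[2,1,1024]").setAppName("local-cluster-sketch")
    val sc = new SparkContext(conf)
    try {
      // Mirrors the "local-cluster format" check in DistributedSuite: two partitions, two elements.
      assert(sc.parallelize(1 to 2, 2).count() == 2)
    } finally {
      sc.stop()
    }
  }
}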
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7a1961137cce..af4e68950f75 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark +import scala.collection.mutable.ArrayBuffer + import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -55,9 +57,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) - val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), - (BlockManagerId("b", "hostB", 1000), size10000))) + val statuses = tracker.getMapSizesByExecutorId(10, 0) + assert(statuses.toSet === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + .toSet) tracker.stop() rpcEnv.shutdown() } @@ -75,10 +79,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).nonEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).isEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -104,7 +108,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
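For orientation while reading the rewritten assertions in this suite: getMapSizesByExecutorId groups (ShuffleBlockId, size) pairs under the executor that serves them, rather than returning one flat (BlockManagerId, size) pair per map output. A small sketch of consuming that shape, using plain stand-in case classes instead of Spark's types (names are illustrative only):

    object MapSizesShapeSketch {
      // Stand-ins for BlockManagerId and ShuffleBlockId, for illustration only.
      final case class ExecutorLoc(execId: String, host: String, port: Int)
      final case class ShuffleBlock(shuffleId: Int, mapId: Int, reduceId: Int)

      def main(args: Array[String]): Unit = {
        val statuses: Seq[(ExecutorLoc, Seq[(ShuffleBlock, Long)])] = Seq(
          ExecutorLoc("a", "hostA", 1000) -> Seq(ShuffleBlock(10, 0, 0) -> 1000L),
          ExecutorLoc("b", "hostB", 1000) -> Seq(ShuffleBlock(10, 1, 0) -> 10000L))

        // Flatten to raw block sizes when only the sizes matter.
        val blockSizes: Seq[Long] = statuses.flatMap(_._2.map(_._2))
        assert(blockSizes.forall(_ > 0))

        // Or total the bytes served by each executor.
        val bytesByExecutor: Map[String, Long] =
          statuses.map { case (loc, blocks) => loc.execId -> blocks.map(_._2).sum }.toMap
        assert(bytesByExecutor == Map("a" -> 1000L, "b" -> 10000L))
      }
    }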
- intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } tracker.stop() rpcEnv.shutdown() @@ -126,23 +130,23 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } masterTracker.stop() slaveTracker.stop() diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 3316f561a494..aa8028792cb4 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -91,13 +91,13 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. 
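As the comment above notes, RangePartitioner only requires an implicit Ordering for the key type, not Comparable. A compact, Spark-independent sketch of supplying one; Ordering.by is used here because it delegates to the element ordering and avoids the overflow risk of subtraction-based compare methods:

    // Spark-independent sketch: a key type with an implicit Ordering but no Comparable.
    object OrderingSketch {
      final case class Item(value: Int)
      // Ordering.by delegates to Int's ordering, so no overflow-prone subtraction.
      implicit val itemOrdering: Ordering[Item] = Ordering.by(_.value)

      def main(args: Array[String]): Unit = {
        val items = Seq(Item(3), Item(1), Item(2)).sorted // uses the implicit Ordering
        assert(items == Seq(Item(1), Item(2), Item(3)))
      }
    }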
- implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row): Int = x.value - y.value + implicit object RowOrdering extends Ordering[Item] { + override def compare(x: Item, y: Item): Int = x.value - y.value } - val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) + val rdd = sc.parallelize(1 to 4500).map(x => (Item(x), Item(x))) val partitioner = new RangePartitioner(1500, rdd) - partitioner.getPartition(Row(100)) + partitioner.getPartition(Item(100)) } test("RangPartitioner.sketch") { @@ -252,4 +252,4 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva } -private sealed case class Row(value: Int) +private sealed case class Item(value: Int) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 376481ba541f..25b79bce6ab9 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import javax.net.ssl.SSLContext import com.google.common.io.Files import org.apache.spark.util.Utils @@ -29,6 +30,15 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + // Pick two cipher suites that the provider knows about + val sslContext = SSLContext.getInstance("TLSv1.2") + sslContext.init(null, null, null) + val algorithms = sslContext + .getServerSocketFactory + .getDefaultCipherSuites + .take(2) + .toSet + val conf = new SparkConf conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -36,9 +46,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ssl.protocol", "SSLv3") + conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) + conf.set("spark.ssl.protocol", "TLSv1.2") val opts = SSLOptions.parse(conf, "spark.ssl") @@ -52,9 +61,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.trustStorePassword === Some("password")) assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) - assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.protocol === Some("TLSv1.2")) + assert(opts.enabledAlgorithms === algorithms) } test("test resolving property with defaults specified ") { diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 1a099da2c6c8..33270bec6247 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -25,6 +25,20 @@ object SSLSampleConfigs { this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + val enabledAlgorithms = + // A reasonable set of TLSv1.2 Oracle security provider suites + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, 
" + + "TLS_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + + // and their equivalent names in the IBM Security provider + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "SSL_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") @@ -33,9 +47,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } @@ -47,9 +60,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e9b64aa82a17..f29160d83408 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -125,8 +125,60 @@ class SecurityManagerSuite extends SparkFunSuite { } + test("set security with * in acls") { + val conf = new SparkConf + conf.set("spark.ui.acls.enable", "true") + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "*") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf) + assert(securityManager.aclsEnabled() === true) + + // check for viewAcls with * + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for modifyAcls with * + securityManager.setModifyAcls(Set("user4"), "*") + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + + securityManager.setAdminAcls("user1,user2") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions("user6") === false) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for adminAcls with * + securityManager.setAdminAcls("user1,*") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === true) + 
assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + } + test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() + val expectedAlgorithms = Set( + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "TLS_RSA_WITH_AES_256_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "SSL_RSA_WITH_AES_256_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") val securityManager = new SecurityManager(conf) @@ -143,9 +195,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -154,9 +205,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) - assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index c3c2b1ffc1ef..d91b799ecfc0 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -47,7 +47,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val NUM_BLOCKS = 3 val a = sc.parallelize(1 to 10, 2) @@ -66,14 +66,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - assert(statuses.forall(s => s._2 > 0)) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0))) } } test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new 
SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new NonJavaSerializableClass(x * 2)) @@ -89,7 +89,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -105,8 +105,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -116,7 +116,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks without kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -130,8 +130,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -141,7 +141,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -154,7 +154,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. 
// Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -168,7 +168,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) @@ -195,7 +195,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) @@ -210,7 +210,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - sc = new SparkContext("local-cluster[2,1,512]", "test", myConf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", myConf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) @@ -223,7 +223,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Java") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 9fbaeb33f97c..ff9a92cc0a42 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.util.concurrent.{TimeUnit, Executors} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Try, Random} @@ -148,7 +149,6 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } test("Thread safeness - SPARK-5425") { - import scala.collection.JavaConversions._ val executor = Executors.newSingleThreadScheduledExecutor() val sf = executor.scheduleAtFixedRate(new Runnable { override def run(): Unit = @@ -163,8 +163,9 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } finally { executor.shutdownNow() - for 
(key <- System.getProperties.stringPropertyNames() if key.startsWith("spark.5425.")) - System.getProperties.remove(key) + val sysProps = System.getProperties + for (key <- sysProps.stringPropertyNames().asScala if key.startsWith("spark.5425.")) + sysProps.remove(key) } } @@ -260,10 +261,10 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst assert(RpcUtils.retryWaitMs(conf) === 2L) conf.set("spark.akka.askTimeout", "3") - assert(RpcUtils.askTimeout(conf) === (3 seconds)) + assert(RpcUtils.askRpcTimeout(conf).duration === (3 seconds)) conf.set("spark.akka.lookupTimeout", "4") - assert(RpcUtils.lookupTimeout(conf) === (4 seconds)) + assert(RpcUtils.lookupRpcTimeout(conf).duration === (4 seconds)) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index f89e3d0a4992..e5a14a69ef05 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.scalatest.PrivateMethodTester +import org.apache.spark.util.Utils import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -122,7 +123,7 @@ class SparkContextSchedulerCreationSuite } test("local-cluster") { - createTaskScheduler("local-cluster[3, 14, 512]").backend match { + createTaskScheduler("local-cluster[3, 14, 1024]").backend match { case s: SparkDeploySchedulerBackend => // OK case _ => fail() } @@ -131,7 +132,7 @@ class SparkContextSchedulerCreationSuite def testYarn(master: String, expectedClassName: String) { try { val sched = createTaskScheduler(master) - assert(sched.getClass === Class.forName(expectedClassName)) + assert(sched.getClass === Utils.classForName(expectedClassName)) } catch { case e: SparkException => assert(e.getMessage.contains("YARN mode not available")) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 6838b35ab4cc..d4f2ea87650a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.Utils import scala.concurrent.Await import scala.concurrent.duration.Duration +import org.scalatest.Matchers._ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { @@ -272,4 +273,24 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } } + + test("calling multiple sc.stop() must not throw any exception") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val cnt = sc.parallelize(1 to 4).count() + sc.cancelAllJobs() + sc.stop() + // call stop second time + sc.stop() + } + } + + test("No exception when both num-executors and dynamic allocation set.") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.dynamicAllocation.enabled", "true").set("spark.executor.instances", "6")) + assert(sc.executorAllocationManager.isEmpty) + assert(sc.getConf.getInt("spark.executor.instances", 0) === 6) + } + } } diff 
--git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 6580139df6c6..54c131cdae36 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -36,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends SparkFunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") @@ -119,30 +119,38 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { val nums = sc.parallelize(1 to 2, 2) val sem = new Semaphore(0) ThreadingSuiteState.clear() + var throwable: Option[Throwable] = None for (i <- 0 until 2) { new Thread { override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - println("Waited 1 second without seeing runningThreads = 4 (it was " + - running.get() + "); failing test") - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() + try { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } }.start() } sem.acquire(2) + throwable.foreach { t => throw improveStackTrace(t) } if (ThreadingSuiteState.failed.get()) { + logError("Waited 1 second without seeing runningThreads = 4 (it was " + + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } } @@ -150,13 +158,19 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { test("set local properties in different thread") { sc = new SparkContext("local", "test") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -164,6 +178,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { threads.foreach(_.start()) sem.acquire(5) + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === null) } @@ -171,14 +186,20 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") sc.setLocalProperty("test", "parent") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - assert(sc.getLocalProperty("test") === "parent") - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + 
assert(sc.getLocalProperty("test") === "parent") + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -186,50 +207,41 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { threads.foreach(_.start()) sem.acquire(5) + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } - test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { - val jobStarted = new Semaphore(0) - val jobEnded = new Semaphore(0) - @volatile var jobResult: JobResult = null - + test("mutation in parent local property does not affect child (SPARK-10563)") { sc = new SparkContext("local", "test") - sc.setJobGroup("originalJobGroupId", "description") - sc.addSparkListener(new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobStarted.release() - } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - jobResult = jobEnd.jobResult - jobEnded.release() - } - }) - - // Create a new thread which will inherit the current thread's properties - val thread = new Thread() { + val originalTestValue: String = "original-value" + var threadTestValue: String = null + sc.setLocalProperty("test", originalTestValue) + var throwable: Option[Throwable] = None + val thread = new Thread { override def run(): Unit = { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) - } + threadTestValue = sc.getLocalProperty("test") } catch { - case s: SparkException => // ignored so that we don't print noise in test logs + case t: Throwable => + throwable = Some(t) } } } + sc.setLocalProperty("test", "this-should-not-be-inherited") thread.start() - // Wait for the job to start, then mutate the original properties, which should have been - // inherited by the running job but hopefully defensively copied or snapshotted: - jobStarted.tryAcquire(10, TimeUnit.SECONDS) - sc.setJobGroup("modifiedJobGroupId", "description") - // Canceling the original job group should cancel the running job. In other words, the - // modification of the properties object should not affect the properties of running jobs - sc.cancelJobGroup("originalJobGroupId") - jobEnded.tryAcquire(10, TimeUnit.SECONDS) - assert(jobResult.isInstanceOf[JobFailed]) + thread.join() + throwable.foreach { t => throw improveStackTrace(t) } + assert(threadTestValue === originalTestValue) + } + + /** + * Improve the stack trace of an error thrown from within a thread. + * Otherwise it's difficult to tell which line in the test the error came from. 
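The try/catch/finally blocks threaded through these tests follow one pattern: an assertion failing inside a spawned thread would otherwise kill only that thread, so the failure is captured, the semaphore is still released, and the main thread rethrows once it stops waiting (with improveStackTrace below splicing in the caller's stack so the report points at the test). A generic sketch of the same pattern, independent of Spark:

    // Capture a worker-thread failure and rethrow it where the test runner can see it.
    object ThreadFailureSketch {
      def main(args: Array[String]): Unit = {
        import java.util.concurrent.Semaphore
        val sem = new Semaphore(0)
        var failure: Option[Throwable] = None // visibility guaranteed by join() below

        val worker = new Thread {
          override def run(): Unit = {
            try {
              require(1 + 1 == 2, "worker-side check failed")
            } catch {
              case t: Throwable => failure = Some(t)
            } finally {
              sem.release() // always unblock the waiting thread, pass or fail
            }
          }
        }
        worker.start()
        sem.acquire()
        worker.join()
        failure.foreach(t => throw t) // rethrow in the main thread
      }
    }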
+ */ + private def improveStackTrace(t: Throwable): Throwable = { + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + t } + } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c054c718075f..fb7a8ae3f9d4 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -69,7 +69,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = httpConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -97,7 +97,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = torrentConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -125,7 +125,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val rdd = sc.parallelize(1 to numSlaves) val results = new DummyBroadcastClass(rdd).doSomething() @@ -308,10 +308,16 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) - _sc + try { + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc + } catch { + case e: Throwable => + _sc.stop() + throw e + } } else { new SparkContext("local", "test", broadcastConf) } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala new file mode 100644 index 000000000000..967aa0976f0c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.File +import java.util.Date + +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.{SecurityManager, SparkConf} + +private[deploy] object DeployTestUtils { + def createAppDesc(): ApplicationDescription = { + val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) + new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") + } + + def createAppInfo() : ApplicationInfo = { + val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, + "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) + appInfo.endTime = JsonConstants.currTimeInMillis + appInfo + } + + def createDriverCommand(): Command = new Command( + "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") + ) + + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) + + def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", + createDriverDesc(), new Date()) + + def createWorkerInfo(): WorkerInfo = { + val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") + workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis + workerInfo + } + + def createExecutorRunner(execId: Int): ExecutorRunner = { + new ExecutorRunner( + "appId", + execId, + createAppDesc(), + 4, + 1234, + null, + "workerId", + "host", + 123, + "publicAddress", + new File("sparkHome"), + new File("workDir"), + "akka://worker", + new SparkConf, + Seq("localDir"), + ExecutorState.RUNNING) + } + + def createDriverRunner(driverId: String): DriverRunner = { + val conf = new SparkConf() + new DriverRunner( + conf, + driverId, + new File("workDir"), + new File("sparkHome"), + createDriverDesc(), + null, + "akka://worker", + new SecurityManager(conf)) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index 823050b0aabb..d93febcfd23f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -19,6 +19,10 @@ package org.apache.spark.deploy import java.io.{File, FileInputStream, FileOutputStream} import java.util.jar.{JarEntry, JarOutputStream} +import java.util.jar.Attributes.Name +import java.util.jar.Manifest + +import scala.collection.mutable.ArrayBuffer import com.google.common.io.{Files, ByteStreams} @@ -35,7 +39,7 @@ private[deploy] object IvyTestUtils { * Create the path for the jar and pom from the maven coordinate. Extension should be `jar` * or `pom`. */ - private def pathFromCoordinate( + private[deploy] def pathFromCoordinate( artifact: MavenCoordinate, prefix: File, ext: String, @@ -52,7 +56,7 @@ private[deploy] object IvyTestUtils { } /** Returns the artifact naming based on standard ivy or maven format. 
*/ - private def artifactName( + private[deploy] def artifactName( artifact: MavenCoordinate, useIvyLayout: Boolean, ext: String = ".jar"): String = { @@ -73,7 +77,7 @@ private[deploy] object IvyTestUtils { } /** Write the contents to a file to the supplied directory. */ - private def writeFile(dir: File, fileName: String, contents: String): File = { + private[deploy] def writeFile(dir: File, fileName: String, contents: String): File = { val outputFile = new File(dir, fileName) val outputStream = new FileOutputStream(outputFile) outputStream.write(contents.toCharArray.map(_.toByte)) @@ -90,6 +94,42 @@ private[deploy] object IvyTestUtils { writeFile(dir, "mylib.py", contents) } + /** Create an example R package that calls the given Java class. */ + private def createRFiles( + dir: File, + className: String, + packageName: String): Seq[(String, File)] = { + val rFilesDir = new File(dir, "R" + File.separator + "pkg") + Files.createParentDirs(new File(rFilesDir, "R" + File.separator + "mylib.R")) + val contents = + s"""myfunc <- function(x) { + | SparkR:::callJStatic("$packageName.$className", "myFunc", x) + |} + """.stripMargin + val source = writeFile(new File(rFilesDir, "R"), "mylib.R", contents) + val description = + """Package: sparkPackageTest + |Type: Package + |Title: Test for building an R package + |Version: 0.1 + |Date: 2015-07-08 + |Author: Burak Yavuz + |Imports: methods, SparkR + |Depends: R (>= 3.1), methods, SparkR + |Suggests: testthat + |Description: Test for building an R package within a jar + |License: Apache License (== 2.0) + |Collate: 'mylib.R' + """.stripMargin + val descFile = writeFile(rFilesDir, "DESCRIPTION", description) + val namespace = + """import(SparkR) + |export("myfunc") + """.stripMargin + val nameFile = writeFile(rFilesDir, "NAMESPACE", namespace) + Seq(("R/pkg/R/mylib.R", source), ("R/pkg/DESCRIPTION", descFile), ("R/pkg/NAMESPACE", nameFile)) + } + /** Create a simple testable Class. */ private def createJavaClass(dir: File, className: String, packageName: String): File = { val contents = @@ -97,17 +137,14 @@ private[deploy] object IvyTestUtils { | |import java.lang.Integer; | - |class $className implements java.io.Serializable { - | - | public $className() {} - | - | public Integer myFunc(Integer x) { + |public class $className implements java.io.Serializable { + | public static Integer myFunc(Integer x) { | return x + 1; | } |} """.stripMargin val sourceFile = - new JavaSourceFromString(new File(dir, className + ".java").getAbsolutePath, contents) + new JavaSourceFromString(new File(dir, className).getAbsolutePath, contents) createCompiledClass(className, dir, sourceFile, Seq.empty) } @@ -199,14 +236,25 @@ private[deploy] object IvyTestUtils { } /** Create the jar for the given maven coordinate, using the supplied files. 
*/ - private def packJar( + private[deploy] def packJar( dir: File, artifact: MavenCoordinate, files: Seq[(String, File)], - useIvyLayout: Boolean): File = { + useIvyLayout: Boolean, + withR: Boolean, + withManifest: Option[Manifest] = None): File = { val jarFile = new File(dir, artifactName(artifact, useIvyLayout)) val jarFileStream = new FileOutputStream(jarFile) - val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) + val manifest = withManifest.getOrElse { + val mani = new Manifest() + if (withR) { + val attr = mani.getMainAttributes + attr.put(Name.MANIFEST_VERSION, "1.0") + attr.put(new Name("Spark-HasRPackage"), "true") + } + mani + } + val jarStream = new JarOutputStream(jarFileStream, manifest) for (file <- files) { val jarEntry = new JarEntry(file._1) @@ -239,7 +287,8 @@ private[deploy] object IvyTestUtils { dependencies: Option[Seq[MavenCoordinate]] = None, tempDir: Option[File] = None, useIvyLayout: Boolean = false, - withPython: Boolean = false): File = { + withPython: Boolean = false, + withR: Boolean = false): File = { // Where the root of the repository exists, and what Ivy will search in val tempPath = tempDir.getOrElse(Files.createTempDir()) // Create directory if it doesn't exist @@ -255,14 +304,16 @@ private[deploy] object IvyTestUtils { val javaClass = createJavaClass(root, className, artifact.groupId) // A tuple of files representation in the jar, and the file val javaFile = (artifact.groupId.replace(".", "/") + "/" + javaClass.getName, javaClass) - val allFiles = - if (withPython) { - val pythonFile = createPythonFile(root) - Seq(javaFile, (pythonFile.getName, pythonFile)) - } else { - Seq(javaFile) - } - val jarFile = packJar(jarPath, artifact, allFiles, useIvyLayout) + val allFiles = ArrayBuffer[(String, File)](javaFile) + if (withPython) { + val pythonFile = createPythonFile(root) + allFiles.append((pythonFile.getName, pythonFile)) + } + if (withR) { + val rFiles = createRFiles(root, className, artifact.groupId) + allFiles.append(rFiles: _*) + } + val jarFile = packJar(jarPath, artifact, allFiles, useIvyLayout, withR) assert(jarFile.exists(), "Problem creating Jar file") val descriptor = createDescriptor(tempPath, artifact, dependencies, useIvyLayout) assert(descriptor.exists(), "Problem creating Pom file") @@ -286,9 +337,10 @@ private[deploy] object IvyTestUtils { dependencies: Option[String], rootDir: Option[File], useIvyLayout: Boolean = false, - withPython: Boolean = false): File = { + withPython: Boolean = false, + withR: Boolean = false): File = { val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) - val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython) + val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython, withR) deps.foreach { seq => seq.foreach { dep => createLocalRepository(dep, None, Some(mainRepo), useIvyLayout, withPython = false) }} @@ -311,11 +363,12 @@ private[deploy] object IvyTestUtils { rootDir: Option[File], useIvyLayout: Boolean = false, withPython: Boolean = false, + withR: Boolean = false, ivySettings: IvySettings = new IvySettings)(f: String => Unit): Unit = { val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) purgeLocalIvyCache(artifact, deps, ivySettings) val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, - withPython) + withPython, withR) try { f(repo.toURI.toString) } finally { diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 08529e0ef280..0a9f128a3a6b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy -import java.io.File import java.util.Date import com.fasterxml.jackson.core.JsonParseException @@ -25,12 +24,14 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} -import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} +import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { + import org.apache.spark.deploy.DeployTestUtils._ + test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) assertValidJson(output) @@ -50,7 +51,7 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { } test("writeExecutorRunner") { - val output = JsonProtocol.writeExecutorRunner(createExecutorRunner()) + val output = JsonProtocol.writeExecutorRunner(createExecutorRunner(123)) assertValidJson(output) assertValidDataInJson(output, JsonMethods.parse(JsonConstants.executorRunnerJsonStr)) } @@ -77,9 +78,10 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeWorkerState") { val executors = List[ExecutorRunner]() - val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner()) - val drivers = List(createDriverRunner()) - val finishedDrivers = List(createDriverRunner(), createDriverRunner()) + val finishedExecutors = List[ExecutorRunner](createExecutorRunner(123), + createExecutorRunner(123)) + val drivers = List(createDriverRunner("driverId")) + val finishedDrivers = List(createDriverRunner("driverId"), createDriverRunner("driverId")) val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors, finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) @@ -87,47 +89,6 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr)) } - def createAppDesc(): ApplicationDescription = { - val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) - new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") - } - - def createAppInfo() : ApplicationInfo = { - val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, - "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) - appInfo.endTime = JsonConstants.currTimeInMillis - appInfo - } - - def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), - Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") - ) - - def createDriverDesc(): DriverDescription = - new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) - - def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new 
Date()) - - def createWorkerInfo(): WorkerInfo = { - val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") - workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis - workerInfo - } - - def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, - "publicAddress", new File("sparkHome"), new File("workDir"), "akka://worker", - new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - } - - def createDriverRunner(): DriverRunner = { - val conf = new SparkConf() - new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) - } - def assertValidJson(json: JValue) { try { JsonMethods.parse(JsonMethods.compact(json)) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index ddc92814c0ac..86eb41dd7e5d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy import java.net.URL -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source @@ -33,7 +32,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { private val WAIT_TIMEOUT_MILLIS = 10000 test("verify that correct log urls get propagated from workers") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") val listener = new SaveExecutorInfo sc.addSparkListener(listener) @@ -66,7 +65,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { } val conf = new MySparkConf().set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala new file mode 100644 index 000000000000..1ed4bae3ca21 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io.{PrintStream, OutputStream, File} +import java.net.URI +import java.util.jar.Attributes.Name +import java.util.jar.{JarFile, Manifest} +import java.util.zip.ZipFile + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.api.r.RUtils +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate + +class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { + + private val main = MavenCoordinate("a", "b", "c") + private val dep1 = MavenCoordinate("a", "dep1", "c") + private val dep2 = MavenCoordinate("a", "dep2", "d") + + private def getJarPath(coord: MavenCoordinate, repo: File): File = { + new File(IvyTestUtils.pathFromCoordinate(coord, repo, "jar", useIvyLayout = false), + IvyTestUtils.artifactName(coord, useIvyLayout = false, ".jar")) + } + + private val lineBuffer = ArrayBuffer[String]() + + private val noOpOutputStream = new OutputStream { + def write(b: Int) = {} + } + + /** Simple PrintStream that reads data into a buffer */ + private class BufferPrintStream extends PrintStream(noOpOutputStream) { + // scalastyle:off println + override def println(line: String) { + // scalastyle:on println + lineBuffer += line + } + } + + def beforeAll() { + System.setProperty("spark.testing", "true") + } + + override def beforeEach(): Unit = { + lineBuffer.clear() + } + + test("pick which jars to unpack using the manifest") { + val deps = Seq(dep1, dep2).mkString(",") + IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => + val jars = Seq(main, dep1, dep2).map(c => new JarFile(getJarPath(c, new File(new URI(repo))))) + assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") + assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") + assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + } + } + + test("build an R package from a jar end to end") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + val deps = Seq(dep1, dep2).mkString(",") + IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => + val jars = Seq(main, dep1, dep2).map { c => + getJarPath(c, new File(new URI(repo))) + }.mkString(",") + RPackageUtils.checkAndBuildRPackage(jars, new BufferPrintStream, verbose = true) + val firstJar = jars.substring(0, jars.indexOf(",")) + val output = lineBuffer.mkString("\n") + assert(output.contains("Building R package")) + assert(output.contains("Extracting")) + assert(output.contains(s"$firstJar contains R source code. 
Now installing package.")) + assert(output.contains("doesn't contain R source code, skipping...")) + } + } + + test("jars that don't exist are skipped and print warning") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + val deps = Seq(dep1, dep2).mkString(",") + IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => + val jars = Seq(main, dep1, dep2).map { c => + getJarPath(c, new File(new URI(repo))) + "dummy" + }.mkString(",") + RPackageUtils.checkAndBuildRPackage(jars, new BufferPrintStream, verbose = true) + val individualJars = jars.split(",") + val output = lineBuffer.mkString("\n") + individualJars.foreach { jarFile => + assert(output.contains(s"$jarFile")) + } + } + } + + test("faulty R package shows documentation") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + IvyTestUtils.withRepository(main, None, None) { repo => + val manifest = new Manifest + val attr = manifest.getMainAttributes + attr.put(Name.MANIFEST_VERSION, "1.0") + attr.put(new Name("Spark-HasRPackage"), "true") + val jar = IvyTestUtils.packJar(new File(new URI(repo)), dep1, Nil, + useIvyLayout = false, withR = false, Some(manifest)) + RPackageUtils.checkAndBuildRPackage(jar.getAbsolutePath, new BufferPrintStream, + verbose = true) + val output = lineBuffer.mkString("\n") + assert(output.contains(RPackageUtils.RJarDoc)) + } + } + + test("SparkR zipping works properly") { + val tempDir = Files.createTempDir() + try { + IvyTestUtils.writeFile(tempDir, "test.R", "abc") + val fakeSparkRDir = new File(tempDir, "SparkR") + assert(fakeSparkRDir.mkdirs()) + IvyTestUtils.writeFile(fakeSparkRDir, "abc.R", "abc") + IvyTestUtils.writeFile(fakeSparkRDir, "DESCRIPTION", "abc") + IvyTestUtils.writeFile(tempDir, "package.zip", "abc") // fake zip file :) + val fakePackageDir = new File(tempDir, "packageTest") + assert(fakePackageDir.mkdirs()) + IvyTestUtils.writeFile(fakePackageDir, "def.R", "abc") + IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc") + val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip") + assert(finalZip.exists()) + val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq + assert(entries.contains("/test.R")) + assert(entries.contains("/SparkR/abc.R")) + assert(entries.contains("/SparkR/DESCRIPTION")) + assert(!entries.contains("/package.zip")) + assert(entries.contains("/packageTest/def.R")) + assert(entries.contains("/packageTest/DESCRIPTION")) + } finally { + FileUtils.deleteDirectory(tempDir) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 357ed90be3f5..1110ca6051a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -51,9 +51,11 @@ class SparkSubmitSuite /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } /** Returns true if the script exits and the given search string is printed. 
*/ @@ -81,6 +83,7 @@ class SparkSubmitSuite } } + // scalastyle:off println test("prints usage on empty input") { testPrematureExit(Array[String](), "Usage: spark-submit") } @@ -156,7 +159,6 @@ class SparkSubmitSuite childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include ("--queue thequeue") - childArgsStr should include ("--num-executors 6") childArgsStr should include regex ("--jar .*thejar.jar") childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") @@ -243,7 +245,7 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 8 + sysProps should have size 9 sysProps.keys should contain ("SPARK_SUBMIT") sysProps.keys should contain ("spark.master") sysProps.keys should contain ("spark.app.name") @@ -252,6 +254,7 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.submit.deployMode") sysProps("spark.shuffle.spill") should be ("false") } @@ -321,6 +324,8 @@ class SparkSubmitSuite "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -333,7 +338,9 @@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -348,16 +355,41 @@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) } } + test("correctly builds R packages included in a jar with --packages") { + // TODO(SPARK-9603): Building a package to $SPARK_HOME/R/lib is unavailable on Jenkins. 
+ // It's hard to write the test in SparkR (because we can't create the repository dynamically) + /* + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + val main = MavenCoordinate("my.great.lib", "mylib", "0.1") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val rScriptDir = + Seq(sparkHome, "R", "pkg", "inst", "tests", "packageInAJarTest.R").mkString(File.separator) + assert(new File(rScriptDir).exists) + IvyTestUtils.withRepository(main, None, None, withR = true) { repo => + val args = Seq( + "--name", "testApp", + "--master", "local-cluster[2,1,1024]", + "--packages", main.toString, + "--repositories", repo, + "--verbose", + "--conf", "spark.ui.enabled=false", + rScriptDir) + runSparkSubmit(args) + } + */ + } + test("resolves command line argument paths correctly") { val jars = "/jar1,/jar2" // --jars val files = "hdfs:/file1,file2" // --files @@ -473,6 +505,8 @@ class SparkSubmitSuite "--master", "local", "--conf", "spark.driver.extraClassPath=" + systemJar, "--conf", "spark.driver.userClassPathFirst=true", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", userJar.toString) runSparkSubmit(args) } @@ -491,6 +525,7 @@ class SparkSubmitSuite appArgs.executorMemory should be ("2.3g") } } + // scalastyle:on println // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { @@ -536,8 +571,8 @@ object JarCreationTest extends Logging { val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var exception: String = null try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString @@ -548,6 +583,7 @@ object JarCreationTest extends Logging { if (result.nonEmpty) { throw new Exception("Could not load user class from jar:\n" + result(0)) } + sc.stop() } } @@ -573,6 +609,7 @@ object SimpleApplicationTest { s"Master had $config=$masterValue but executor had $config=$executorValue") } } + sc.stop() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 12c40f0b7d65..63c346c1b890 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -41,9 +41,11 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } override def beforeAll() { @@ -77,9 +79,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => - if (i > 3) { - assert(resolver.getName === s"repo-${i - 3}") - assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 4)) + if (i < 3) { + assert(resolver.getName === s"repo-${i + 1}") + assert(resolver.asInstanceOf[IBiblioResolver].getRoot === 
expected(i)) } } } @@ -93,6 +95,25 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(md.getDependencies.length === 2) } + test("excludes works correctly") { + val md = SparkSubmitUtils.getModuleDescriptor + val excludes = Seq("a:b", "c:d") + excludes.foreach { e => + md.addExcludeRule(SparkSubmitUtils.createExclusion(e + ":*", new IvySettings, "default")) + } + val rules = md.getAllExcludeRules + assert(rules.length === 2) + val rule1 = rules(0).getId.getModuleId + assert(rule1.getOrganisation === "a") + assert(rule1.getName === "b") + val rule2 = rules(1).getId.getModuleId + assert(rule2.getOrganisation === "c") + assert(rule2.getName === "d") + intercept[IllegalArgumentException] { + SparkSubmitUtils.createExclusion("e:f:g:h", new IvySettings, "default") + } + } + test("ivy path works correctly") { val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") @@ -166,4 +187,15 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } + + test("exclude dependencies end to end") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = "my.great.dep:mydep:0.5" + IvyTestUtils.withRepository(main, Some(dep), None) { repo => + val files = SparkSubmitUtils.resolveMavenCoordinates(main.toString, + Some(repo), None, Seq("my.great.dep:mydep"), isTest = true) + assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") + assert(files.indexOf("my.great.dep") < 0, "Returned excluded artifact") + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala new file mode 100644 index 000000000000..1f2a0f0d309c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import org.mockito.Mockito.{mock, when} +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ +import org.apache.spark.deploy.master.Master +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.cluster._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor + +/** + * End-to-end tests for dynamic allocation in standalone mode. 
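+ *
+ * Each test creates a SparkContext against the in-process Master started in `beforeAll` and
+ * then asserts on the Master's view of the application, for example:
+ * {{{
+ *   sc = new SparkContext(appConf)
+ *   assert(sc.requestExecutors(1))
+ *   assert(master.apps.head.getExecutorLimit === 1)
+ * }}}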
+ */ +class StandaloneDynamicAllocationSuite + extends SparkFunSuite + with LocalSparkContext + with BeforeAndAfterAll { + + private val numWorkers = 2 + private val conf = new SparkConf() + private val securityManager = new SecurityManager(conf) + + private var masterRpcEnv: RpcEnv = null + private var workerRpcEnvs: Seq[RpcEnv] = null + private var master: Master = null + private var workers: Seq[Worker] = null + + /** + * Start the local cluster. + * Note: local-cluster mode is insufficient because we want a reference to the Master. + */ + override def beforeAll(): Unit = { + super.beforeAll() + masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) + workerRpcEnvs = (0 until numWorkers).map { i => + RpcEnv.create(Worker.SYSTEM_NAME + i, "localhost", 0, conf, securityManager) + } + master = makeMaster() + workers = makeWorkers(10, 2048) + } + + override def afterAll(): Unit = { + masterRpcEnv.shutdown() + workerRpcEnvs.foreach(_.shutdown()) + master.stop() + workers.foreach(_.stop()) + masterRpcEnv = null + workerRpcEnvs = null + master = null + workers = null + super.afterAll() + } + + test("dynamic allocation default behavior") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // kill all executors + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request 1 + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.getExecutorLimit === 1) + // request 1 more + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === 2) + // request 1 more; this one won't go through + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === 3) + // kill all existing executors; we should end up with 3 - 2 = 1 executor + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.getExecutorLimit === 1) + // kill all executors again; this time we'll have 1 - 1 = 0 executors left + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request many more; this increases the limit well beyond the cluster capacity + assert(sc.requestExecutors(1000)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === 1000) + } + + test("dynamic allocation with max cores <= cores per worker") { + sc = new SparkContext(appConf.set("spark.cores.max", "8")) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // kill all executors + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request 1 + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.executors.values.head.cores === 8) + assert(master.apps.head.getExecutorLimit === 1) + // request 1 more; this one won't go through because we're already at max 
cores. + // This highlights a limitation of using dynamic allocation with max cores WITHOUT + // setting cores per executor: once an application scales down and then scales back + // up, its executors may not be spread out anymore! + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.getExecutorLimit === 2) + // request 1 more; this one also won't go through for the same reason + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.getExecutorLimit === 3) + // kill all existing executors; we should end up with 3 - 1 = 2 executors + // Note: we scheduled these executors together, so their cores should be evenly distributed + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) + assert(master.apps.head.getExecutorLimit === 2) + // kill all executors again; this time we'll have 2 - 2 = 0 executors left + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request many more; this increases the limit well beyond the cluster capacity + assert(sc.requestExecutors(1000)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.executors.values.map(_.cores).toArray === Array(4, 4)) + assert(master.apps.head.getExecutorLimit === 1000) + } + + test("dynamic allocation with max cores > cores per worker") { + sc = new SparkContext(appConf.set("spark.cores.max", "16")) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.executors.values.map(_.cores).toArray === Array(8, 8)) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // kill all executors + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request 1 + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.executors.values.head.cores === 10) + assert(master.apps.head.getExecutorLimit === 1) + // request 1 more + // Note: the cores are not evenly distributed because we scheduled these executors 1 by 1 + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.executors.values.map(_.cores).toSet === Set(10, 6)) + assert(master.apps.head.getExecutorLimit === 2) + // request 1 more; this one won't go through + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === 3) + // kill all existing executors; we should end up with 3 - 2 = 1 executor + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.executors.values.head.cores === 10) + assert(master.apps.head.getExecutorLimit === 1) + // kill all executors again; this time we'll have 1 - 1 = 0 executors left + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request many more; this increases the limit well beyond the cluster capacity + assert(sc.requestExecutors(1000)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.executors.values.map(_.cores).toArray === Array(8, 8)) + assert(master.apps.head.getExecutorLimit === 1000) + } + + test("dynamic allocation with cores 
per executor") { + sc = new SparkContext(appConf.set("spark.executor.cores", "2")) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 10) // 20 cores total + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // kill all executors + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request 1 + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.getExecutorLimit === 1) + // request 3 more + assert(sc.requestExecutors(3)) + assert(master.apps.head.executors.size === 4) + assert(master.apps.head.getExecutorLimit === 4) + // request 10 more; only 6 will go through + assert(sc.requestExecutors(10)) + assert(master.apps.head.executors.size === 10) + assert(master.apps.head.getExecutorLimit === 14) + // kill 2 executors; we should get 2 back immediately + assert(killNExecutors(sc, 2)) + assert(master.apps.head.executors.size === 10) + assert(master.apps.head.getExecutorLimit === 12) + // kill 4 executors; we should end up with 12 - 4 = 8 executors + assert(killNExecutors(sc, 4)) + assert(master.apps.head.executors.size === 8) + assert(master.apps.head.getExecutorLimit === 8) + // kill all executors; this time we'll have 8 - 8 = 0 executors left + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request many more; this increases the limit well beyond the cluster capacity + assert(sc.requestExecutors(1000)) + assert(master.apps.head.executors.size === 10) + assert(master.apps.head.getExecutorLimit === 1000) + } + + test("dynamic allocation with cores per executor AND max cores") { + sc = new SparkContext(appConf + .set("spark.executor.cores", "2") + .set("spark.cores.max", "8")) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 4) // 8 cores total + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // kill all executors + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request 1 + assert(sc.requestExecutors(1)) + assert(master.apps.head.executors.size === 1) + assert(master.apps.head.getExecutorLimit === 1) + // request 3 more + assert(sc.requestExecutors(3)) + assert(master.apps.head.executors.size === 4) + assert(master.apps.head.getExecutorLimit === 4) + // request 10 more; none will go through + assert(sc.requestExecutors(10)) + assert(master.apps.head.executors.size === 4) + assert(master.apps.head.getExecutorLimit === 14) + // kill all executors; 4 executors will be launched immediately + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 4) + assert(master.apps.head.getExecutorLimit === 10) + // ... and again + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 4) + assert(master.apps.head.getExecutorLimit === 6) + // ... and again; now we end up with 6 - 4 = 2 executors left + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === 2) + // ... 
and again; this time we have 2 - 2 = 0 executors left + assert(killAllExecutors(sc)) + assert(master.apps.head.executors.size === 0) + assert(master.apps.head.getExecutorLimit === 0) + // request many more; this increases the limit well beyond the cluster capacity + assert(sc.requestExecutors(1000)) + assert(master.apps.head.executors.size === 4) + assert(master.apps.head.getExecutorLimit === 1000) + } + + test("kill the same executor twice (SPARK-9795)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + // kill the same executor twice + val executors = getExecutorIds(sc) + assert(executors.size === 2) + assert(sc.killExecutor(executors.head)) + assert(sc.killExecutor(executors.head)) + assert(master.apps.head.executors.size === 1) + // The limit should not be lowered twice + assert(master.apps.head.getExecutorLimit === 1) + } + + // =============================== + // | Utility methods for testing | + // =============================== + + /** Return a SparkConf for applications that want to talk to our Master. */ + private def appConf: SparkConf = { + new SparkConf() + .setMaster(masterRpcEnv.address.toSparkURL) + .setAppName("test") + .set("spark.executor.memory", "256m") + } + + /** Make a master to which our application will send executor requests. */ + private def makeMaster(): Master = { + val master = new Master(masterRpcEnv, masterRpcEnv.address, 0, securityManager, conf) + masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + master + } + + /** Make a few workers that talk to our master. */ + private def makeWorkers(cores: Int, memory: Int): Seq[Worker] = { + (0 until numWorkers).map { i => + val rpcEnv = workerRpcEnvs(i) + val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), + Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) + worker + } + } + + /** Kill all executors belonging to this application. */ + private def killAllExecutors(sc: SparkContext): Boolean = { + killNExecutors(sc, Int.MaxValue) + } + + /** Kill N executors belonging to this application. */ + private def killNExecutors(sc: SparkContext, n: Int): Boolean = { + syncExecutors(sc) + sc.killExecutors(getExecutorIds(sc).take(n)) + } + + /** + * Return a list of executor IDs belonging to this application. + * + * Note that we must use the executor IDs according to the Master, which has the most + * updated view. We cannot rely on the executor IDs according to the driver because we + * don't wait for executors to register. Otherwise the tests will take much longer to run. + */ + private def getExecutorIds(sc: SparkContext): Seq[String] = { + assert(master.idToApp.contains(sc.applicationId)) + master.idToApp(sc.applicationId).executors.keys.map(_.toString).toSeq + } + + /** + * Sync executor IDs between the driver and the Master. + * + * This allows us to avoid waiting for new executors to register with the driver before + * we submit a request to kill them. This must be called before each kill request. 
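+ *
+ * Under the hood this fakes a `RegisterExecutor` message to the CoarseGrainedSchedulerBackend
+ * driver endpoint for every executor the Master knows about but the driver does not, using a
+ * mocked RpcEndpointRef so that no real executor process needs to register.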
+ */ + private def syncExecutors(sc: SparkContext): Unit = { + val driverExecutors = sc.getExecutorStorageStatus + .map(_.blockManagerId.executorId) + .filter { _ != SparkContext.DRIVER_IDENTIFIER} + val masterExecutors = getExecutorIds(sc) + val missingExecutors = masterExecutors.toSet.diff(driverExecutors.toSet).toSeq.sorted + missingExecutors.foreach { id => + // Fake an executor registration so the driver knows about us + val port = System.currentTimeMillis % 65536 + val endpointRef = mock(classOf[RpcEndpointRef]) + val mockAddress = mock(classOf[RpcAddress]) + when(endpointRef.address).thenReturn(mockAddress) + val message = RegisterExecutor(id, endpointRef, s"localhost:$port", 10, Map.empty) + val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] + backend.driverEndpoint.askWithRetry[CoarseGrainedClusterMessage](message) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index d3a6db5f260d..73cff89544dc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -39,6 +39,8 @@ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { + import FsHistoryProvider._ + private var testDir: File = null before { @@ -67,8 +69,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) writeFile(newAppComplete, true, None, - SparkListenerApplicationStart( - "new-app-complete", Some("new-app-complete"), 1L, "test", None), + SparkListenerApplicationStart(newAppComplete.getName(), Some("new-app-complete"), 1L, "test", + None), SparkListenerApplicationEnd(5L) ) @@ -76,39 +78,30 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val newAppCompressedComplete = newLogFile("new1compressed", None, inProgress = false, Some("lzf")) writeFile(newAppCompressedComplete, true, None, - SparkListenerApplicationStart( - "new-app-compressed-complete", Some("new-app-compressed-complete"), 1L, "test", None), + SparkListenerApplicationStart(newAppCompressedComplete.getName(), Some("new-complete-lzf"), + 1L, "test", None), SparkListenerApplicationEnd(4L)) // Write an unfinished app, new-style. val newAppIncomplete = newLogFile("new2", None, inProgress = true) writeFile(newAppIncomplete, true, None, - SparkListenerApplicationStart( - "new-app-incomplete", Some("new-app-incomplete"), 1L, "test", None) + SparkListenerApplicationStart(newAppIncomplete.getName(), Some("new-incomplete"), 1L, "test", + None) ) // Write an old-style application log. 
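+    // (The old-style format is a directory containing a LOG_PREFIX event log file plus
+    // SPARK_VERSION_PREFIX and APPLICATION_COMPLETE marker files; the writeOldLog helper
+    // added at the bottom of this suite builds that layout.)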
- val oldAppComplete = new File(testDir, "old1") - oldAppComplete.mkdir() - createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart( - "old-app-complete", Some("old-app-complete"), 2L, "test", None), + val oldAppComplete = writeOldLog("old1", "1.0", None, true, + SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) // Check for logs so that we force the older unfinished app to be loaded, to make // sure unfinished apps are also sorted correctly. provider.checkForLogs() // Write an unfinished app, old-style. - val oldAppIncomplete = new File(testDir, "old2") - oldAppIncomplete.mkdir() - createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart( - "old-app-incomplete", Some("old-app-incomplete"), 2L, "test", None) + val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, + SparkListenerApplicationStart("old2", None, 2L, "test", None) ) // Force a reload of data from the log directory, and check that both logs are loaded. @@ -129,16 +122,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) } - list(0) should be (makeAppInfo("new-app-complete", "new-app-complete", 1L, 5L, + list(0) should be (makeAppInfo("new-app-complete", newAppComplete.getName(), 1L, 5L, newAppComplete.lastModified(), "test", true)) - list(1) should be (makeAppInfo("new-app-compressed-complete", - "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test", - true)) - list(2) should be (makeAppInfo("old-app-complete", "old-app-complete", 2L, 3L, + list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), + 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo("old-app-incomplete", "old-app-incomplete", 2L, -1L, - oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo("new-app-incomplete", "new-app-incomplete", 1L, -1L, + list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, + -1L, oldAppIncomplete.lastModified(), "test", false)) + list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. 
@@ -160,12 +152,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null val logDir = new File(testDir, codecName) logDir.mkdir() - createEmptyFile(new File(logDir, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, provider.LOG_PREFIX + "1"), false, Option(codec), - SparkListenerApplicationStart("app2", Some("app2"), 2L, "test", None), + createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), + SparkListenerApplicationStart("app2", None, 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(logDir, provider.COMPRESSION_CODEC_PREFIX + codecName)) + createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) val logPath = new Path(logDir.getAbsolutePath()) try { @@ -251,13 +243,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc appListAfterRename.size should be (1) } - test("apps with multiple attempts") { + test("apps with multiple attempts with order") { val provider = new FsHistoryProvider(createTestConf()) - val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = false) + val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = true) writeFile(attempt1, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")), - SparkListenerApplicationEnd(2L) + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")) ) updateAndCheck(provider) { list => @@ -267,7 +258,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val attempt2 = newLogFile("app1", Some("attempt2"), inProgress = true) writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")) + SparkListenerApplicationStart("app1", Some("app1"), 2L, "test", Some("attempt2")) ) updateAndCheck(provider) { list => @@ -276,22 +267,21 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc list.head.attempts.head.attemptId should be (Some("attempt2")) } - val completedAttempt2 = newLogFile("app1", Some("attempt2"), inProgress = false) - attempt2.delete() - writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")), + val attempt3 = newLogFile("app1", Some("attempt3"), inProgress = false) + writeFile(attempt3, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt3")), SparkListenerApplicationEnd(4L) ) updateAndCheck(provider) { list => list should not be (null) list.size should be (1) - list.head.attempts.size should be (2) - list.head.attempts.head.attemptId should be (Some("attempt2")) + list.head.attempts.size should be (3) + list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt2, true, None, + writeFile(attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -299,7 +289,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc updateAndCheck(provider) { list => list.size should be (2) list.head.attempts.size should be (1) - list.last.attempts.size should be (2) + list.last.attempts.size should be (3) list.head.attempts.head.attemptId should be 
(Some("attempt1")) list.foreach { case app => @@ -390,6 +380,33 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("SPARK-8372: new logs with no app ID are ignored") { + val provider = new FsHistoryProvider(createTestConf()) + + // Write a new log file without an app id, to make sure it's ignored. + val logFile1 = newLogFile("app1", None, inProgress = true) + writeFile(logFile1, true, None, + SparkListenerLogStart("1.4") + ) + + // Write a 1.2 log file with no start event (= no app id), it should be ignored. + writeOldLog("v12Log", "1.2", None, false) + + // Write 1.0 and 1.1 logs, which don't have app ids. + writeOldLog("v11Log", "1.1", None, true, + SparkListenerApplicationStart("v11Log", None, 2L, "test", None), + SparkListenerApplicationEnd(3L)) + writeOldLog("v10Log", "1.0", None, true, + SparkListenerApplicationStart("v10Log", None, 2L, "test", None), + SparkListenerApplicationEnd(4L)) + + updateAndCheck(provider) { list => + list.size should be (2) + list(0).id should be ("v10Log") + list(1).id should be ("v11Log") + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -429,4 +446,23 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } + private def writeOldLog( + fname: String, + sparkVersion: String, + codec: Option[CompressionCodec], + completed: Boolean, + events: SparkListenerEvent*): File = { + val log = new File(testDir, fname) + log.mkdir() + + val oldEventLog = new File(log, LOG_PREFIX + "1") + createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) + writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) + if (completed) { + createEmptyFile(new File(log, APPLICATION_COMPLETE)) + } + + log + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index f4e56632e426..4b86da536768 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -19,18 +19,19 @@ // when they are outside of org.apache.spark. package other.supplier +import java.nio.ByteBuffer + import scala.collection.mutable import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.SparkConf import org.apache.spark.deploy.master._ +import org.apache.spark.serializer.Serializer class CustomRecoveryModeFactory( conf: SparkConf, - serialization: Serialization -) extends StandaloneRecoveryModeFactory(conf, serialization) { + serializer: Serializer +) extends StandaloneRecoveryModeFactory(conf, serializer) { CustomRecoveryModeFactory.instantiationAttempts += 1 @@ -40,7 +41,7 @@ class CustomRecoveryModeFactory( * */ override def createPersistenceEngine(): PersistenceEngine = - new CustomPersistenceEngine(serialization) + new CustomPersistenceEngine(serializer) /** * Create an instance of LeaderAgent that decides who gets elected as master. 
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory { @volatile var instantiationAttempts = 0 } -class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine { +class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine { val data = mutable.HashMap[String, Array[Byte]]() CustomPersistenceEngine.lastInstance = Some(this) @@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def persist(name: String, obj: Object): Unit = { CustomPersistenceEngine.persistAttempts += 1 - serialization.serialize(obj) match { - case util.Success(bytes) => data += name -> bytes - case util.Failure(cause) => throw new RuntimeException(cause) - } + val serialized = serializer.newInstance().serialize(obj) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + data += name -> bytes } /** @@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def read[T: ClassTag](prefix: String): Seq[T] = { CustomPersistenceEngine.readAttempts += 1 - val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] val results = for ((name, bytes) <- data; if name.startsWith(prefix)) - yield serialization.deserialize(bytes, clazz) - - results.find(_.isFailure).foreach { - case util.Failure(cause) => throw new RuntimeException(cause) - } - - results.flatMap(_.toOption).toSeq + yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + results.toSeq } } @@ -104,7 +99,7 @@ object CustomPersistenceEngine { @volatile var lastInstance: Option[CustomPersistenceEngine] = None } -class CustomLeaderElectionAgent(val masterActor: LeaderElectable) extends LeaderElectionAgent { - masterActor.electedLeader() +class CustomLeaderElectionAgent(val masterInstance: LeaderElectable) extends LeaderElectionAgent { + masterInstance.electedLeader() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 014e87bb4025..242bf4b5566e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,68 +19,28 @@ package org.apache.spark.deploy.master import java.util.Date -import scala.concurrent.Await import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps -import akka.actor.Address import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.Matchers +import org.scalatest.{Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ +import org.apache.spark.rpc.RpcEnv -class MasterSuite extends SparkFunSuite with Matchers with Eventually { - - test("toAkkaUrl") { - val conf = new SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp") - assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl with SSL") { - val conf = new SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl: a typo url") { - val conf = new 
SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } - - test("toAkkaAddress") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp") - assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress with SSL") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress: a typo url") { - val conf = new SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } +class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester { test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) conf.set("spark.deploy.recoveryMode", "CUSTOM") conf.set("spark.deploy.recoveryMode.factory", classOf[CustomRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts @@ -129,16 +89,16 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { port = 10000, cores = 0, memory = 0, - actor = null, + endpoint = null, webUiPort = 0, publicAddress = "" ) - val (actorSystem, port, uiPort, restPort) = - Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf) + val (rpcEnv, _, _) = + Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { - Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds) + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get @@ -147,21 +107,21 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { persistenceEngine.addDriver(driverToPersist) persistenceEngine.addWorker(workerToPersist) - val (apps, drivers, workers) = persistenceEngine.readPersistedData() + val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv) apps.map(_.id) should contain(appToPersist.id) drivers.map(_.id) should contain(driverToPersist.id) workers.map(_.id) should contain(workerToPersist.id) } finally { - actorSystem.shutdown() - actorSystem.awaitTermination() + rpcEnv.shutdown() + rpcEnv.awaitTermination() } CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts } - test("Master & worker web ui available") { + test("master/worker web ui available") { implicit val formats = org.json4s.DefaultFormats val conf = new SparkConf() val localCluster = new LocalSparkCluster(2, 2, 512, conf) @@ -184,4 +144,247 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { } } + test("basic scheduling - spread out") { + basicScheduling(spreadOut = true) + } + + test("basic scheduling - no spread out") { + basicScheduling(spreadOut = false) + } + + test("basic scheduling with more memory - spread out") { + basicSchedulingWithMoreMemory(spreadOut = true) + } + + test("basic scheduling with more memory - no spread out") { + basicSchedulingWithMoreMemory(spreadOut = false) + } + + test("scheduling with max 
cores - spread out") { + schedulingWithMaxCores(spreadOut = true) + } + + test("scheduling with max cores - no spread out") { + schedulingWithMaxCores(spreadOut = false) + } + + test("scheduling with cores per executor - spread out") { + schedulingWithCoresPerExecutor(spreadOut = true) + } + + test("scheduling with cores per executor - no spread out") { + schedulingWithCoresPerExecutor(spreadOut = false) + } + + test("scheduling with cores per executor AND max cores - spread out") { + schedulingWithCoresPerExecutorAndMaxCores(spreadOut = true) + } + + test("scheduling with cores per executor AND max cores - no spread out") { + schedulingWithCoresPerExecutorAndMaxCores(spreadOut = false) + } + + test("scheduling with executor limit - spread out") { + schedulingWithExecutorLimit(spreadOut = true) + } + + test("scheduling with executor limit - no spread out") { + schedulingWithExecutorLimit(spreadOut = false) + } + + test("scheduling with executor limit AND max cores - spread out") { + schedulingWithExecutorLimitAndMaxCores(spreadOut = true) + } + + test("scheduling with executor limit AND max cores - no spread out") { + schedulingWithExecutorLimitAndMaxCores(spreadOut = false) + } + + test("scheduling with executor limit AND cores per executor - spread out") { + schedulingWithExecutorLimitAndCoresPerExecutor(spreadOut = true) + } + + test("scheduling with executor limit AND cores per executor - no spread out") { + schedulingWithExecutorLimitAndCoresPerExecutor(spreadOut = false) + } + + test("scheduling with executor limit AND cores per executor AND max cores - spread out") { + schedulingWithEverything(spreadOut = true) + } + + test("scheduling with executor limit AND cores per executor AND max cores - no spread out") { + schedulingWithEverything(spreadOut = false) + } + + private def basicScheduling(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(1024) + val scheduledCores = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores === Array(10, 10, 10)) + } + + private def basicSchedulingWithMoreMemory(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(3072) + val scheduledCores = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores === Array(10, 10, 10)) + } + + private def schedulingWithMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(1024, maxCores = Some(8)) + val appInfo2 = makeAppInfo(1024, maxCores = Some(16)) + val scheduledCores1 = scheduleExecutorsOnWorkers(master, appInfo1, workerInfos, spreadOut) + val scheduledCores2 = scheduleExecutorsOnWorkers(master, appInfo2, workerInfos, spreadOut) + if (spreadOut) { + assert(scheduledCores1 === Array(3, 3, 2)) + assert(scheduledCores2 === Array(6, 5, 5)) + } else { + assert(scheduledCores1 === Array(8, 0, 0)) + assert(scheduledCores2 === Array(10, 6, 0)) + } + } + + private def schedulingWithCoresPerExecutor(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(1024, coresPerExecutor = Some(2)) + val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2)) + val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3)) + val scheduledCores1 = scheduleExecutorsOnWorkers(master, appInfo1, workerInfos, spreadOut) + val scheduledCores2 = scheduleExecutorsOnWorkers(master, appInfo2, workerInfos, spreadOut) + val scheduledCores3 = scheduleExecutorsOnWorkers(master, appInfo3, workerInfos, spreadOut) + assert(scheduledCores1 
=== Array(8, 8, 8)) // 4 * 2 because of memory limits + assert(scheduledCores2 === Array(10, 10, 10)) // 5 * 2 + assert(scheduledCores3 === Array(9, 9, 9)) // 3 * 3 + } + + // Sorry for the long method name! + private def schedulingWithCoresPerExecutorAndMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(4)) + val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(20)) + val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3), maxCores = Some(20)) + val scheduledCores1 = scheduleExecutorsOnWorkers(master, appInfo1, workerInfos, spreadOut) + val scheduledCores2 = scheduleExecutorsOnWorkers(master, appInfo2, workerInfos, spreadOut) + val scheduledCores3 = scheduleExecutorsOnWorkers(master, appInfo3, workerInfos, spreadOut) + if (spreadOut) { + assert(scheduledCores1 === Array(2, 2, 0)) + assert(scheduledCores2 === Array(8, 6, 6)) + assert(scheduledCores3 === Array(6, 6, 6)) + } else { + assert(scheduledCores1 === Array(4, 0, 0)) + assert(scheduledCores2 === Array(10, 10, 0)) + assert(scheduledCores3 === Array(9, 9, 0)) + } + } + + private def schedulingWithExecutorLimit(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(256) + appInfo.executorLimit = 0 + val scheduledCores1 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 2 + val scheduledCores2 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 5 + val scheduledCores3 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores1 === Array(0, 0, 0)) + assert(scheduledCores2 === Array(10, 10, 0)) + assert(scheduledCores3 === Array(10, 10, 10)) + } + + private def schedulingWithExecutorLimitAndMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(256, maxCores = Some(16)) + appInfo.executorLimit = 0 + val scheduledCores1 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 2 + val scheduledCores2 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 5 + val scheduledCores3 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores1 === Array(0, 0, 0)) + if (spreadOut) { + assert(scheduledCores2 === Array(8, 8, 0)) + assert(scheduledCores3 === Array(6, 5, 5)) + } else { + assert(scheduledCores2 === Array(10, 6, 0)) + assert(scheduledCores3 === Array(10, 6, 0)) + } + } + + private def schedulingWithExecutorLimitAndCoresPerExecutor(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(256, coresPerExecutor = Some(4)) + appInfo.executorLimit = 0 + val scheduledCores1 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 2 + val scheduledCores2 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 5 + val scheduledCores3 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores1 === Array(0, 0, 0)) + if (spreadOut) { + assert(scheduledCores2 === Array(4, 4, 0)) + } else { + assert(scheduledCores2 === Array(8, 0, 0)) + } + assert(scheduledCores3 === Array(8, 8, 4)) + } + + // Everything being: executor limit + cores per executor + max cores + private def schedulingWithEverything(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(256, 
coresPerExecutor = Some(4), maxCores = Some(18)) + appInfo.executorLimit = 0 + val scheduledCores1 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 2 + val scheduledCores2 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + appInfo.executorLimit = 5 + val scheduledCores3 = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores1 === Array(0, 0, 0)) + if (spreadOut) { + assert(scheduledCores2 === Array(4, 4, 0)) + assert(scheduledCores3 === Array(8, 4, 4)) + } else { + assert(scheduledCores2 === Array(8, 0, 0)) + assert(scheduledCores3 === Array(8, 8, 0)) + } + } + + // ========================================== + // | Utility methods and fields for testing | + // ========================================== + + private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers) + private val workerInfo = makeWorkerInfo(4096, 10) + private val workerInfos = Array(workerInfo, workerInfo, workerInfo) + + private def makeMaster(conf: SparkConf = new SparkConf): Master = { + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) + master + } + + private def makeAppInfo( + memoryPerExecutorMb: Int, + coresPerExecutor: Option[Int] = None, + maxCores: Option[Int] = None): ApplicationInfo = { + val desc = new ApplicationDescription( + "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor) + val appId = System.currentTimeMillis.toString + new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue) + } + + private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { + val workerId = System.currentTimeMillis.toString + new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, 101, "address") + } + + private def scheduleExecutorsOnWorkers( + master: Master, + appInfo: ApplicationInfo, + workerInfos: Array[WorkerInfo], + spreadOut: Boolean): Array[Int] = { + master.invokePrivate(_scheduleExecutorsOnWorkers(appInfo, workerInfos, spreadOut)) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala new file mode 100644 index 000000000000..34775577de8a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.deploy.master + +import java.net.ServerSocket + +import org.apache.commons.lang3.RandomUtils +import org.apache.curator.test.TestingServer + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.util.Utils + +class PersistenceEngineSuite extends SparkFunSuite { + + test("FileSystemPersistenceEngine") { + val dir = Utils.createTempDir() + try { + val conf = new SparkConf() + testPersistenceEngine(conf, serializer => + new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer) + ) + } finally { + Utils.deleteRecursively(dir) + } + } + + test("ZooKeeperPersistenceEngine") { + val conf = new SparkConf() + // TestingServer logs the port conflict exception rather than throwing an exception. + // So we have to find a free port by ourselves. This approach cannot guarantee always starting + // zkTestServer successfully because there is a time gap between finding a free port and + // starting zkTestServer. But the failure possibility should be very low. + val zkTestServer = new TestingServer(findFreePort(conf)) + try { + testPersistenceEngine(conf, serializer => { + conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString) + new ZooKeeperPersistenceEngine(conf, serializer) + }) + } finally { + zkTestServer.stop() + } + } + + private def testPersistenceEngine( + conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { + val serializer = new JavaSerializer(conf) + val persistenceEngine = persistenceEngineCreator(serializer) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = testRpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = + persistenceEngine.readPersistedData(testRpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === 
recoveryWorkerInfo.publicAddress) + } finally { + testRpcEnv.shutdown() + testRpcEnv.awaitTermination() + } + } + + private def findFreePort(conf: SparkConf): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, conf)._2 + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 197f68e7ec5e..9693e32bf6af 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -23,14 +23,14 @@ import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.google.common.base.Charsets import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} import org.apache.spark.deploy.master.DriverState._ @@ -39,11 +39,11 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. */ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { - private var actorSystem: Option[ActorSystem] = None + private var rpcEnv: Option[RpcEnv] = None private var server: Option[RestSubmissionServer] = None override def afterEach() { - actorSystem.foreach(_.shutdown()) + rpcEnv.foreach(_.shutdown()) server.foreach(_.stop()) } @@ -366,6 +366,18 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) } + test("client does not send 'SPARK_ENV_LOADED' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_ENV_LOADED" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + + test("client includes mesos env vars") { + val environmentVariables = Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1", "OTHER_VAR" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1")) + } + /* --------------------- * | Helper methods | * --------------------- */ @@ -377,31 +389,32 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { killMessage: String = "driver is killed", state: DriverState = FINISHED, exception: Option[Exception] = None): String = { - startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + startServer(new DummyMaster(_, submitId, submitMessage, killMessage, state, exception)) } /** Start a smarter dummy server that keeps track of submitted driver states. */ private def startSmartServer(): String = { - startServer(new SmarterMaster) + startServer(new SmarterMaster(_)) } /** Start a dummy server that is faulty in many ways... 
*/ private def startFaultyServer(): String = { - startServer(new DummyMaster, faulty = true) + startServer(new DummyMaster(_), faulty = true) } /** - * Start a [[StandaloneRestServer]] that communicates with the given actor. + * Start a [[StandaloneRestServer]] that communicates with the given endpoint. * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. * Return the master URL that corresponds to the address of this server. */ - private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + private def startServer( + makeFakeMaster: RpcEnv => RpcEndpoint, faulty: Boolean = false): String = { val name = "test-standalone-rest-protocol" val conf = new SparkConf val localhost = Utils.localHostName() val securityManager = new SecurityManager(conf) - val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) - val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _rpcEnv = RpcEnv.create(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _rpcEnv.setupEndpoint("fake-master", makeFakeMaster(_rpcEnv)) val _server = if (faulty) { new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") @@ -410,7 +423,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { } val port = _server.start() // set these to clean them up after every test - actorSystem = Some(_actorSystem) + rpcEnv = Some(_rpcEnv) server = Some(_server) s"spark://$localhost:$port" } @@ -505,20 +518,21 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { * In all responses, the success parameter is always true. */ private class DummyMaster( + override val rpcEnv: RpcEnv, submitId: String = "fake-driver-id", submitMessage: String = "submitted", killMessage: String = "killed", state: DriverState = FINISHED, exception: Option[Exception] = None) - extends Actor { + extends RpcEndpoint { - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => - sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) + context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage)) case RequestKillDriver(driverId) => - sender ! KillDriverResponse(driverId, success = true, killMessage) + context.reply(KillDriverResponse(self, driverId, success = true, killMessage)) case RequestDriverStatus(driverId) => - sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception)) } } @@ -531,28 +545,28 @@ private class DummyMaster( * Submits are always successful while kills and status requests are successful only * if the driver was submitted in the past. */ -private class SmarterMaster extends Actor { +private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { private var counter: Int = 0 private val submittedDrivers = new mutable.HashMap[String, DriverState] - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => val driverId = s"driver-$counter" submittedDrivers(driverId) = RUNNING counter += 1 - sender ! 
SubmitDriverResponse(success = true, Some(driverId), "submitted") + context.reply(SubmitDriverResponse(self, success = true, Some(driverId), "submitted")) case RequestKillDriver(driverId) => val success = submittedDrivers.contains(driverId) if (success) { submittedDrivers(driverId) = KILLED } - sender ! KillDriverResponse(driverId, success, "killed") + context.reply(KillDriverResponse(self, driverId, success, "killed")) case RequestDriverStatus(driverId) => val found = submittedDrivers.contains(driverId) val state = submittedDrivers.get(driverId) - sender ! DriverStatusResponse(found, state, None, None, None) + context.reply(DriverStatusResponse(found, state, None, None, None)) } } @@ -568,7 +582,7 @@ private class FaultyStandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { @@ -578,7 +592,7 @@ private class FaultyStandaloneRestServer( /** A faulty servlet that produces malformed responses. */ class MalformedSubmitServlet - extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { + extends StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -588,7 +602,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterEndpoint, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -597,7 +611,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. */ - class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 115ac0534a1b..725b8848bc05 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.deploy.rest import java.lang.Boolean -import java.lang.Integer import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils /** * Tests for the REST application submission protocol. 
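The changes above replace the Akka-based test doubles with Spark's RpcEndpoint abstraction: a fake master now extends RpcEndpoint (or ThreadSafeRpcEndpoint), is registered through RpcEnv.setupEndpoint, and answers requests via the RpcCallContext passed to receiveAndReply instead of `sender ! reply`. The sketch below illustrates that pattern in isolation; it is a minimal, hypothetical example (the EchoEndpoint class, the "echo" endpoint name, and the main method are not part of this patch) that assumes only the org.apache.spark.rpc API already used in these suites.

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// Hypothetical endpoint illustrating the Actor -> RpcEndpoint migration pattern:
// replies go through RpcCallContext.reply rather than `sender ! ...`.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(s"echo: $msg")
  }
}

object EchoEndpointExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // RpcEnv.create takes the place of AkkaUtils.createActorSystem in the updated tests;
    // port 0 requests an ephemeral port, as startServer does above.
    val rpcEnv = RpcEnv.create("echo-env", "localhost", 0, conf, new SecurityManager(conf))
    val endpointRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
    try {
      // askWithRetry is the blocking ask helper used elsewhere in these suites.
      val reply = endpointRef.askWithRetry[String]("hello")
      assert(reply == "echo: hello")
    } finally {
      rpcEnv.shutdown()
      rpcEnv.awaitTermination()
    }
  }
}

Note that receiveAndReply may be invoked from multiple dispatcher threads; SmarterMaster extends ThreadSafeRpcEndpoint precisely because it keeps mutable state (the counter and the map of submitted drivers) and needs messages processed one at a time.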
@@ -93,7 +93,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { // optional fields conf.set("spark.jars", "mayonnaise.jar,ketchup.jar") conf.set("spark.files", "fireball.png") - conf.set("spark.driver.memory", "512m") + conf.set("spark.driver.memory", s"${Utils.DEFAULT_DRIVER_MEM_MB}m") conf.set("spark.driver.cores", "180") conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red") conf.set("spark.driver.extraClassPath", "food-coloring.jar") @@ -126,7 +126,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { assert(newMessage.sparkProperties("spark.app.name") === "SparkPie") assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar") assert(newMessage.sparkProperties("spark.files") === "fireball.png") - assert(newMessage.sparkProperties("spark.driver.memory") === "512m") + assert(newMessage.sparkProperties("spark.driver.memory") === s"${Utils.DEFAULT_DRIVER_MEM_MB}m") assert(newMessage.sparkProperties("spark.driver.cores") === "180") assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") @@ -230,7 +230,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { """.stripMargin private val submitDriverRequestJson = - """ + s""" |{ | "action" : "CreateSubmissionRequest", | "appArgs" : [ "two slices", "a hint of cinnamon" ], @@ -246,7 +246,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { | "spark.driver.supervise" : "false", | "spark.app.name" : "SparkPie", | "spark.cores.max" : "10000", - | "spark.driver.memory" : "512m", + | "spark.driver.memory" : "${Utils.DEFAULT_DRIVER_MEM_MB}m", | "spark.files" : "fireball.png", | "spark.driver.cores" : "180", | "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index bed6f3ea6124..98664dc1101e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.worker import java.io.File -import scala.collection.JavaConversions._ - import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -36,6 +34,7 @@ class ExecutorRunnerTest extends SparkFunSuite { ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder( appDesc.command, new SecurityManager(conf), 512, sparkHome, er.substituteVariables) - assert(builder.command().last === appId) + val builderCommand = builder.command() + assert(builderCommand.get(builderCommand.size() - 1) === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 0f4d3b28d09d..faed4bdc6844 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.Command - import org.scalatest.Matchers +import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import 
org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} + class WorkerSuite extends SparkFunSuite with Matchers { + import org.apache.spark.deploy.DeployTestUtils._ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -56,4 +61,126 @@ class WorkerSuite extends SparkFunSuite with Matchers { "-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=y", "-Dspark.ssl.opt2=z") } + + test("test clearing of finishedExecutors (small number of executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 4) + for (i <- 1 until 5) { + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 2) + if (i > 1) { + assert(!worker.finishedExecutors.contains(s"app1/${i - 2}")) + } + assert(worker.executors.size === 4 - i) + } + } + + test("test clearing of finishedExecutors (more executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedExecutors.size < 30) { + worker.finishedExecutors.size + 1 + } else { + 28 + } + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedExecutors.contains(s"app1/$j")) + } + } + assert(worker.executors.size === 49 - i) + assert(worker.finishedExecutors.size === expectedValue) + } + } + + test("test clearing of finishedDrivers (small number of drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + val driverId = s"driverId-$i" + worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.drivers.size === 4) + 
assert(worker.finishedDrivers.size === 1) + for (i <- 1 until 5) { + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (i > 1) { + assert(!worker.finishedDrivers.contains(s"driverId-${i - 2}")) + } + assert(worker.drivers.size === 4 - i) + assert(worker.finishedDrivers.size === 2) + } + } + + test("test clearing of finishedDrivers (more drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + val driverId = s"driverId-$i" + worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.finishedDrivers.size === 1) + assert(worker.drivers.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedDrivers.size < 30) { + worker.finishedDrivers.size + 1 + } else { + 28 + } + } + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedDrivers.contains(s"driverId-$j")) + } + } + assert(worker.drivers.size === 49 - i) + assert(worker.finishedDrivers.size === expectedValue) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index ac18f04a1147..e9034e39a715 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.worker -import akka.actor.AddressFromURIString import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -26,13 +25,11 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) + val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected( - RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } @@ -40,13 +37,12 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" - val otherAkkaAddress = AddressFromURIString(otherAkkaURL) + val targetWorkerUrl = 
rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val otherRpcAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + workerWatcher.onDisconnected(otherRpcAddress) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 63947df3d43a..8a199459c1dd 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -36,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. */ -class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ @@ -85,7 +85,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl */ test("Correctness of WholeTextFileRecordReader.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, false) @@ -109,7 +109,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl test("Correctness of WholeTextFileRecordReader with GzipCodec.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, true) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 9e4d34fb7d38..44eb5a046912 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -60,7 +60,9 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { + // scalastyle:off println pw.println(RandomUtils.nextInt(0, numBuckets)) + // scalastyle:on println } pw.close() @@ -284,6 +286,10 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Option[Long]): Long = { val taskMetrics = new 
ArrayBuffer[Long]() + + // Avoid receiving earlier taskEnd events + sc.listenerBus.waitUntilEmpty(500) + sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { collector(taskEnd).foreach(taskMetrics += _) diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala deleted file mode 100644 index 5e364cc0edeb..000000000000 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.nio._ - -import scala.concurrent.duration._ -import scala.concurrent.{Await, TimeoutException} -import scala.language.postfixOps - -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils - -/** - * Test the ConnectionManager with various security settings. 
- */ -class ConnectionManagerSuite extends SparkFunSuite { - - test("security default off") { - val conf = new SparkConf - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var receivedMessage = false - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - receivedMessage = true - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) - - assert(receivedMessage == true) - - manager.stop() - } - - test("security on same password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - conf.set("spark.app.id", "app-id") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - val managerServer = new ConnectionManager(0, conf, securityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val count = 10 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - }) - - assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.app.id", "app-id") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = conf.clone.set("spark.authenticate.secret", "bad") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - // Expect managerServer to close connection, which we'll report as an error: - intercept[IOException] { - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - } - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, 
id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "good") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 1).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - try { - val g = Await.result(f, 1 second) - assert(false) - } catch { - case i: IOException => - assert(true) - case e: TimeoutException => { - // we should timeout here since the client can't do the negotiation - assert(true) - } - } - }) - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - manager.stop() - managerServer.stop() - } - - test("security auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 10).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - try { - val g = Await.result(f, 1 second) - } catch { - case e: Exception => { - assert(false) - } - } - }) - assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("Ack error message") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - val managerServer = new ConnectionManager(0, conf, securityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception("Custom exception text") - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - val exception = intercept[IOException] { - Await.result(future, 1 second) - } - assert(Utils.exceptionString(exception).contains("Custom exception text")) - - manager.stop() - managerServer.stop() - - } - - test("sendMessageReliably timeout") { - val clientConf = new SparkConf - 
clientConf.set("spark.authenticate", "false") - val ackTimeoutS = 30 - clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeoutS}s") - - val clientSecurityManager = new SecurityManager(clientConf) - val manager = new ConnectionManager(0, clientConf, clientSecurityManager) - - val serverConf = new SparkConf - serverConf.set("spark.authenticate", "false") - val serverSecurityManager = new SecurityManager(serverConf) - val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // sleep 60 sec > ack timeout for simulating server slow down or hang up - Thread.sleep(ackTimeoutS * 3 * 1000) - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - // Future should throw IOException in 30 sec. - // Otherwise TimeoutExcepton is thrown from Await.result. - // We expect TimeoutException is not thrown. - intercept[IOException] { - Await.result(future, (ackTimeoutS * 2) second) - } - - manager.stop() - managerServer.stop() - } - -} - diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 08215a2bafc0..05013fbc49b8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -22,11 +22,12 @@ import java.sql._ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.util.Utils class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { - Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver") val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") try { diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala new file mode 100644 index 000000000000..5103eb74b245 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SparkException, SparkContext, LocalSparkContext, SparkFunSuite} + +import org.mockito.Mockito.spy +import org.apache.spark.storage.{RDDBlockId, StorageLevel} + +/** + * Fine-grained tests for local checkpointing. + * For end-to-end tests, see CheckpointSuite. 
+ */ +class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { + + override def beforeEach(): Unit = { + sc = new SparkContext("local[2]", "test") + } + + test("transform storage level") { + val transform = LocalRDDCheckpointData.transformStorageLevel _ + assert(transform(StorageLevel.NONE) === StorageLevel.DISK_ONLY) + assert(transform(StorageLevel.MEMORY_ONLY) === StorageLevel.MEMORY_AND_DISK) + assert(transform(StorageLevel.MEMORY_ONLY_SER) === StorageLevel.MEMORY_AND_DISK_SER) + assert(transform(StorageLevel.MEMORY_ONLY_2) === StorageLevel.MEMORY_AND_DISK_2) + assert(transform(StorageLevel.MEMORY_ONLY_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2) + assert(transform(StorageLevel.DISK_ONLY) === StorageLevel.DISK_ONLY) + assert(transform(StorageLevel.DISK_ONLY_2) === StorageLevel.DISK_ONLY_2) + assert(transform(StorageLevel.MEMORY_AND_DISK) === StorageLevel.MEMORY_AND_DISK) + assert(transform(StorageLevel.MEMORY_AND_DISK_SER) === StorageLevel.MEMORY_AND_DISK_SER) + assert(transform(StorageLevel.MEMORY_AND_DISK_2) === StorageLevel.MEMORY_AND_DISK_2) + assert(transform(StorageLevel.MEMORY_AND_DISK_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2) + // Off-heap is not supported and Spark should fail fast + intercept[SparkException] { + transform(StorageLevel.OFF_HEAP) + } + } + + test("basic lineage truncation") { + val numPartitions = 4 + val parallelRdd = sc.parallelize(1 to 100, numPartitions) + val mappedRdd = parallelRdd.map { i => i + 1 } + val filteredRdd = mappedRdd.filter { i => i % 2 == 0 } + val expectedPartitionIndices = (0 until numPartitions).toArray + assert(filteredRdd.checkpointData.isEmpty) + assert(filteredRdd.getStorageLevel === StorageLevel.NONE) + assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices) + assert(filteredRdd.dependencies.size === 1) + assert(filteredRdd.dependencies.head.rdd === mappedRdd) + assert(mappedRdd.dependencies.size === 1) + assert(mappedRdd.dependencies.head.rdd === parallelRdd) + assert(parallelRdd.dependencies.size === 0) + + // Mark the RDD for local checkpointing + filteredRdd.localCheckpoint() + assert(filteredRdd.checkpointData.isDefined) + assert(!filteredRdd.checkpointData.get.isCheckpointed) + assert(!filteredRdd.checkpointData.get.checkpointRDD.isDefined) + assert(filteredRdd.getStorageLevel === LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + + // After an action, the lineage is truncated + val result = filteredRdd.collect() + assert(filteredRdd.checkpointData.get.isCheckpointed) + assert(filteredRdd.checkpointData.get.checkpointRDD.isDefined) + val checkpointRdd = filteredRdd.checkpointData.flatMap(_.checkpointRDD).get + assert(filteredRdd.dependencies.size === 1) + assert(filteredRdd.dependencies.head.rdd === checkpointRdd) + assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices) + assert(checkpointRdd.partitions.map(_.index) === expectedPartitionIndices) + + // Recomputation should yield the same result + assert(filteredRdd.collect() === result) + assert(filteredRdd.collect() === result) + } + + test("basic lineage truncation - caching before checkpointing") { + testBasicLineageTruncationWithCaching( + newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK) + } + + test("basic lineage truncation - caching after checkpointing") { + testBasicLineageTruncationWithCaching( + newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK) + } + + test("indirect lineage truncation") { + testIndirectLineageTruncation( + 
newRdd.localCheckpoint(), + LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + } + + test("indirect lineage truncation - caching before checkpointing") { + testIndirectLineageTruncation( + newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK) + } + + test("indirect lineage truncation - caching after checkpointing") { + testIndirectLineageTruncation( + newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK) + } + + test("checkpoint without draining iterator") { + testWithoutDrainingIterator( + newSortedRdd.localCheckpoint(), + LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL, + 50) + } + + test("checkpoint without draining iterator - caching before checkpointing") { + testWithoutDrainingIterator( + newSortedRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK, + 50) + } + + test("checkpoint without draining iterator - caching after checkpointing") { + testWithoutDrainingIterator( + newSortedRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK, + 50) + } + + test("checkpoint blocks exist") { + testCheckpointBlocksExist( + newRdd.localCheckpoint(), + LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL) + } + + test("checkpoint blocks exist - caching before checkpointing") { + testCheckpointBlocksExist( + newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(), + StorageLevel.MEMORY_AND_DISK) + } + + test("checkpoint blocks exist - caching after checkpointing") { + testCheckpointBlocksExist( + newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY), + StorageLevel.MEMORY_AND_DISK) + } + + test("missing checkpoint block fails with informative message") { + val rdd = newRdd.localCheckpoint() + val numPartitions = rdd.partitions.size + val partitionIndices = rdd.partitions.map(_.index) + val bmm = sc.env.blockManager.master + + // After an action, the blocks should be found somewhere in the cache + rdd.collect() + partitionIndices.foreach { i => + assert(bmm.contains(RDDBlockId(rdd.id, i))) + } + + // Remove one of the blocks to simulate executor failure + // Collecting the RDD should now fail with an informative exception + val blockId = RDDBlockId(rdd.id, numPartitions - 1) + bmm.removeBlock(blockId) + try { + rdd.collect() + fail("Collect should have failed if local checkpoint block is removed...") + } catch { + case se: SparkException => + assert(se.getMessage.contains(s"Checkpoint block $blockId not found")) + assert(se.getMessage.contains("rdd.checkpoint()")) // suggest an alternative + assert(se.getMessage.contains("fault-tolerant")) // justify the alternative + } + } + + /** + * Helper method to create a simple RDD. + */ + private def newRdd: RDD[Int] = { + sc.parallelize(1 to 100, 4) + .map { i => i + 1 } + .filter { i => i % 2 == 0 } + } + + /** + * Helper method to create a simple sorted RDD. + */ + private def newSortedRdd: RDD[Int] = newRdd.sortBy(identity) + + /** + * Helper method to test basic lineage truncation with caching. 
+ * + * @param rdd an RDD that is both marked for caching and local checkpointing + */ + private def testBasicLineageTruncationWithCaching[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel): Unit = { + require(targetStorageLevel !== StorageLevel.NONE) + require(rdd.getStorageLevel !== StorageLevel.NONE) + require(rdd.isLocallyCheckpointed) + val result = rdd.collect() + assert(rdd.getStorageLevel === targetStorageLevel) + assert(rdd.checkpointData.isDefined) + assert(rdd.checkpointData.get.isCheckpointed) + assert(rdd.checkpointData.get.checkpointRDD.isDefined) + assert(rdd.dependencies.head.rdd === rdd.checkpointData.get.checkpointRDD.get) + assert(rdd.collect() === result) + assert(rdd.collect() === result) + } + + /** + * Helper method to test indirect lineage truncation. + * + * Indirect lineage truncation here means the action is called on one of the + * checkpointed RDD's descendants, but not on the checkpointed RDD itself. + * + * @param rdd a locally checkpointed RDD + */ + private def testIndirectLineageTruncation[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel): Unit = { + require(targetStorageLevel !== StorageLevel.NONE) + require(rdd.isLocallyCheckpointed) + val rdd1 = rdd.map { i => i + "1" } + val rdd2 = rdd1.map { i => i + "2" } + val rdd3 = rdd2.map { i => i + "3" } + val rddDependencies = rdd.dependencies + val rdd1Dependencies = rdd1.dependencies + val rdd2Dependencies = rdd2.dependencies + val rdd3Dependencies = rdd3.dependencies + assert(rdd1Dependencies.size === 1) + assert(rdd1Dependencies.head.rdd === rdd) + assert(rdd2Dependencies.size === 1) + assert(rdd2Dependencies.head.rdd === rdd1) + assert(rdd3Dependencies.size === 1) + assert(rdd3Dependencies.head.rdd === rdd2) + + // Only the locally checkpointed RDD should have special storage level + assert(rdd.getStorageLevel === targetStorageLevel) + assert(rdd1.getStorageLevel === StorageLevel.NONE) + assert(rdd2.getStorageLevel === StorageLevel.NONE) + assert(rdd3.getStorageLevel === StorageLevel.NONE) + + // After an action, only the dependencies of the checkpointed RDD changes + val result = rdd3.collect() + assert(rdd.dependencies !== rddDependencies) + assert(rdd1.dependencies === rdd1Dependencies) + assert(rdd2.dependencies === rdd2Dependencies) + assert(rdd3.dependencies === rdd3Dependencies) + assert(rdd3.collect() === result) + assert(rdd3.collect() === result) + } + + /** + * Helper method to test checkpointing without fully draining the iterator. + * + * Not all RDD actions fully consume the iterator. As a result, a subset of the partitions + * may not be cached. However, since we want to truncate the lineage safely, we explicitly + * ensure that *all* partitions are fully cached. This method asserts this behavior. 
+ * + * @param rdd a locally checkpointed RDD + */ + private def testWithoutDrainingIterator[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel, + targetCount: Int): Unit = { + require(targetCount > 0) + require(targetStorageLevel !== StorageLevel.NONE) + require(rdd.isLocallyCheckpointed) + + // This does not drain the iterator, but checkpointing should still work + val first = rdd.first() + assert(rdd.count() === targetCount) + assert(rdd.count() === targetCount) + assert(rdd.first() === first) + assert(rdd.first() === first) + + // Test the same thing by calling actions on a descendant instead + val rdd1 = rdd.repartition(10) + val rdd2 = rdd1.repartition(100) + val rdd3 = rdd2.repartition(1000) + val first2 = rdd3.first() + assert(rdd3.count() === targetCount) + assert(rdd3.count() === targetCount) + assert(rdd3.first() === first2) + assert(rdd3.first() === first2) + assert(rdd.getStorageLevel === targetStorageLevel) + assert(rdd1.getStorageLevel === StorageLevel.NONE) + assert(rdd2.getStorageLevel === StorageLevel.NONE) + assert(rdd3.getStorageLevel === StorageLevel.NONE) + } + + /** + * Helper method to test whether the checkpoint blocks are found in the cache. + * + * @param rdd a locally checkpointed RDD + */ + private def testCheckpointBlocksExist[T]( + rdd: RDD[T], + targetStorageLevel: StorageLevel): Unit = { + val bmm = sc.env.blockManager.master + val partitionIndices = rdd.partitions.map(_.index) + + // The blocks should not exist before the action + partitionIndices.foreach { i => + assert(!bmm.contains(RDDBlockId(rdd.id, i))) + } + + // After an action, the blocks should be found in the cache with the expected level + rdd.collect() + partitionIndices.foreach { i => + val blockId = RDDBlockId(rdd.id, i) + val status = bmm.getBlockStatus(blockId) + assert(status.nonEmpty) + assert(status.values.head.storageLevel === targetStorageLevel) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala new file mode 100644 index 000000000000..e281e817e493 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import scala.collection.mutable + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext} + +class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext { + + test("prepare called before parent partition is computed") { + sc = new SparkContext("local", "test") + + // Have the parent partition push a number to the list + val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter => + TestObject.things.append(20) + iter + } + + // Push a different number during the prepare phase + val preparePartition = () => { TestObject.things.append(10) } + + // Push yet another number during the execution phase + val executePartition = ( + taskContext: TaskContext, + partitionIndex: Int, + notUsed: Unit, + parentIterator: Iterator[Int]) => { + TestObject.things.append(30) + TestObject.things.iterator + } + + // Verify that the numbers are pushed in the order expected + val rdd = new MapPartitionsWithPreparationRDD[Int, Int, Unit]( + parent, preparePartition, executePartition) + val result = rdd.collect() + assert(result === Array(10, 20, 30)) + + TestObject.things.clear() + // Zip two of these RDDs, both should be prepared before the parent is executed + val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit]( + parent, preparePartition, executePartition) + val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect() + assert(result2 === Array(10, 10, 20, 30, 20, 30)) + } + +} + +private object TestObject { + val things = new mutable.ListBuffer[Int] +} diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index dfa102f432a0..1321ec84735b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -282,6 +282,29 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { )) } + // See SPARK-9326 + test("cogroup with empty RDD") { + import scala.reflect.classTag + val intPairCT = classTag[(Int, Int)] + + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.emptyRDD[(Int, Int)](intPairCT) + + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + + // See SPARK-9326 + test("cogroup with groupByed RDD having 0 partitions") { + import scala.reflect.classTag + val intCT = classTag[Int] + + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.emptyRDD[Int](intCT).groupBy((x) => 5) + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 32f04d54eff9..5f73ec867596 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null) + val tContext = TaskContext.empty() val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff 
--git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index f65349e3e358..16a92f54f936 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -38,6 +38,13 @@ class RDDOperationScopeSuite extends SparkFunSuite with BeforeAndAfter { sc.stop() } + test("equals and hashCode") { + val opScope1 = new RDDOperationScope("scope1", id = "1") + val opScope2 = new RDDOperationScope("scope1", id = "1") + assert(opScope1 === opScope2) + assert(opScope1.hashCode() === opScope2.hashCode()) + } + test("getAllScopes") { assert(scope1.getAllScopes === Seq(scope1)) assert(scope2.getAllScopes === Seq(scope1, scope2)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index f6da9f98ad25..5f718ea9f7be 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -679,7 +679,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("runJob on an invalid partition") { intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) + sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2)) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala new file mode 100644 index 000000000000..b3223ec61bf7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.{SparkException, SparkFunSuite} + +class RpcAddressSuite extends SparkFunSuite { + + test("hostPort") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + assert(address.hostPort == "1.2.3.4:1234") + } + + test("fromSparkURL") { + val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234") + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + } + + test("fromSparkURL: a typo url") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("spark://1.2. 3.4:1234") + } + assert("Invalid master URL: spark://1.2. 
3.4:1234" === e.getMessage) + } + + test("fromSparkURL: invalid scheme") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("invalid://1.2.3.4:1234") + } + assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage) + } + + test("toSparkURL") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.toSparkURL == "spark://1.2.3.4:1234") + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 1f0aa759b08d..6ceafe433774 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -155,16 +155,21 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) val conf = new SparkConf() + val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { - val e = intercept[Exception] { - rpcEndpointRef.askWithRetry[String]("hello", 1 millis) + // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause + val e = intercept[SparkException] { + rpcEndpointRef.askWithRetry[String]("hello", new RpcTimeout(1 millis, shortProp)) } - assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + // The SparkException cause should be a RpcTimeoutException with message indicating the + // controlling timeout property + assert(e.getCause.isInstanceOf[RpcTimeoutException]) + assert(e.getCause.getMessage.contains(shortProp)) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -539,6 +544,92 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("construct RpcTimeout with conf property") { + val conf = new SparkConf + + val testProp = "spark.ask.test.timeout" + val testDurationSeconds = 30 + val secondaryProp = "spark.ask.secondary.timeout" + + conf.set(testProp, s"${testDurationSeconds}s") + conf.set(secondaryProp, "100s") + + // Construct RpcTimeout with a single property + val rt1 = RpcTimeout(conf, testProp) + assert( testDurationSeconds === rt1.duration.toSeconds ) + + // Construct RpcTimeout with prioritized list of properties + val rt2 = RpcTimeout(conf, Seq("spark.ask.invalid.timeout", testProp, secondaryProp), "1s") + assert( testDurationSeconds === rt2.duration.toSeconds ) + + // Construct RpcTimeout with default value, + val defaultProp = "spark.ask.default.timeout" + val defaultDurationSeconds = 1 + val rt3 = RpcTimeout(conf, Seq(defaultProp), defaultDurationSeconds.toString + "s") + assert( defaultDurationSeconds === rt3.duration.toSeconds ) + assert( rt3.timeoutProp.contains(defaultProp) ) + + // Try to construct RpcTimeout with an unconfigured property + intercept[NoSuchElementException] { + RpcTimeout(conf, "spark.ask.invalid.timeout") + } + } + + test("ask a message timeout on Future using RpcTimeout") { + case class NeverReply(msg: String) + + val rpcEndpointRef = env.setupEndpoint("ask-future", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.reply(msg) + case _: NeverReply => + } + }) + + val longTimeout = new RpcTimeout(1 second, "spark.rpc.long.timeout") + val shortTimeout = new RpcTimeout(10 millis, 
"spark.rpc.short.timeout") + + // Ask with immediate response, should complete successfully + val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout) + val reply1 = longTimeout.awaitResult(fut1) + assert("hello" === reply1) + + // Ask with a delayed response and wait for response immediately that should timeout + val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout) + val reply2 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut2) + }.getMessage + + // RpcTimeout.awaitResult should have added the property to the TimeoutException message + assert(reply2.contains(shortTimeout.timeoutProp)) + + // Ask with delayed response and allow the Future to timeout before Await.result + val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout) + + // Allow future to complete with failure using plain Await.result, this will return + // once the future is complete to verify addMessageIfTimeout was invoked + val reply3 = + intercept[RpcTimeoutException] { + Await.result(fut3, 200 millis) + }.getMessage + + // When the future timed out, the recover callback should have used + // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message + assert(reply3.contains(shortTimeout.timeoutProp)) + + // Use RpcTimeout.awaitResult to process Future, since it has already failed with + // RpcTimeoutException, the same RpcTimeoutException should be thrown + val reply4 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut3) + }.getMessage + + // Ensure description is not in message twice after addMessageIfTimeout and awaitResult + assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index a33a83db7bc9..4aa75c9230b2 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc.akka import org.apache.spark.rpc._ -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { @@ -47,4 +47,22 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { } } + test("uriOf") { + val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } + + test("uriOf: ssl") { + val conf = SSLSampleConfigs.sparkSSLConfig() + val securityManager = new SecurityManager(conf) + val rpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, securityManager)) + try { + val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } finally { + rpcEnv.shutdown() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala new file mode 100644 index 000000000000..3fe28027c3c2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.{ShuffledRDDPartition, RDD, ShuffledRDD} +import org.apache.spark._ + +object AdaptiveSchedulingSuiteState { + var tasksRun = 0 + + def clear(): Unit = { + tasksRun = 0 + } +} + +/** A special ShuffledRDD where we can pass a ShuffleDependency object to use */ +class CustomShuffledRDD[K, V, C](@transient dep: ShuffleDependency[K, V, C]) + extends RDD[(K, C)](dep.rdd.context, Seq(dep)) { + + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[(K, C)]] + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](dep.partitioner.numPartitions)(i => new ShuffledRDDPartition(i)) + } +} + +class AdaptiveSchedulingSuite extends SparkFunSuite with LocalSparkContext { + test("simple use of submitMapStage") { + try { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.parallelize(1 to 3, 3).map { x => + AdaptiveSchedulingSuiteState.tasksRun += 1 + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep) + sc.submitMapStage(dep).get() + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + assert(shuffled.collect().toSet == Set((1, 1), (2, 2), (3, 3))) + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + } finally { + AdaptiveSchedulingSuiteState.clear() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 34145691153c..eef6aafa624e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -26,7 +26,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo val conf = new SparkConf conf.set("spark.akka.frameSize", "1") conf.set("spark.default.parallelism", "1") - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test", conf) + sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) val larger = sc.parallelize(Seq(buffer)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 833b600746e9..1c55f90ad9b4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -26,11 +26,11 @@ import org.scalatest.concurrent.Timeouts import 
org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite -import org.apache.spark.executor.TaskMetrics class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -101,9 +101,15 @@ class DAGSchedulerSuite /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 val sparkListener = new SparkListener() { + val submittedStageInfos = new HashSet[StageInfo] val successfulStages = new HashSet[Int] val failedStages = new ArrayBuffer[Int] val stageByOrderOfExecution = new ArrayBuffer[Int] + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { + submittedStageInfos += stageSubmitted.stageInfo + } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { val stageInfo = stageCompleted.stageInfo stageByOrderOfExecution += stageInfo.stageId @@ -127,11 +133,11 @@ class DAGSchedulerSuite val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations val blockManagerMaster = new BlockManagerMaster(null, conf, true) { - override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). getOrElse(Seq()) - }.toSeq + }.toIndexedSeq } override def removeExecutor(execId: String) { // don't need to propagate to the driver, which we don't have @@ -146,10 +152,17 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception) = { failure = exception } } + /** A simple helper class for creating custom JobListeners */ + class SimpleListener extends JobListener { + val results = new HashMap[Int, Any] + var failure: Exception = null + override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result) + override def jobFailed(exception: Exception): Unit = { failure = exception } + } + before { - // Enable local execution for this test - val conf = new SparkConf().set("spark.localExecution.enabled", "true") - sc = new SparkContext("local", "DAGSchedulerSuite", conf) + sc = new SparkContext("local", "DAGSchedulerSuite") + sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() failure = null @@ -165,12 +178,7 @@ class DAGSchedulerSuite sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } @@ -229,21 +237,29 @@ class DAGSchedulerSuite } } - /** Sends the rdd to the scheduler for scheduling and returns the job id. */ + /** Submits a job to the scheduler and returns the job id. 
*/ private def submit( rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - allowLocal: Boolean = false, listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) + jobId + } + + /** Submits a map stage to the scheduler and returns the job id. */ + private def submitMapStage( + shuffleDep: ShuffleDependency[_, _, _], + listener: JobListener = jobListener): Int = { + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener)) jobId } /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) + runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. */ @@ -277,37 +293,6 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } - test("local job") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - Array(42 -> 0).iterator - override def getPartitions: Array[Partition] = - Array( new Partition { override def index: Int = 0 } ) - override def getPreferredLocations(split: Partition): List[String] = Nil - override def toString: String = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results === Map(0 -> 42)) - assertDataStructuresEmpty() - } - - test("local job oom") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new java.lang.OutOfMemoryError("test local job oom") - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results.size == 0) - assertDataStructuresEmpty() - } - test("run trivial job w/ dependency") { val baseRdd = new MyRDD(sc, 1, Nil) val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) @@ -317,6 +302,15 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("equals and hashCode AccumulableInfo") { + val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true) + val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + assert(accInfo1 !== accInfo2) + assert(accInfo2 === accInfo3) + assert(accInfo2.hashCode() === accInfo3.hashCode()) + } + test("cache location preferences w/ dependency") { val baseRdd = new MyRDD(sc, 1, Nil).cache() val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) @@ -445,12 +439,7 @@ class DAGSchedulerSuite sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester = new 
DAGSchedulerEventProcessLoopTester(noKillScheduler) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) @@ -476,8 +465,8 @@ class DAGSchedulerSuite complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -503,13 +492,289 @@ class DAGSchedulerSuite // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() } + + // Helper function to validate state when creating tests for task failures + private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) { + assert(stageAttempt.stageId === stageId) + assert(stageAttempt.stageAttemptId == attempt) + } + + + // Helper functions to extract commonly used code in Fetch Failure test cases + private def setupStageAbortTest(sc: SparkContext) { + sc.listenerBus.addListener(new EndListener()) + ended = false + jobResult = null + } + + // Create a new Listener to confirm that the listenerBus sees the JobEnd message + // when we abort the stage. This message will also be consumed by the EventLoggingListener + // so this will propagate up to the user. + var ended = false + var jobResult : JobResult = null + + class EndListener extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobResult = jobEnd.jobResult + ended = true + } + } + + /** + * Common code to get the next stage attempt, confirm it's the one we expect, and complete it + * successfully. + * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + * @param numShufflePartitions - The number of partitions in the next stage + */ + private def completeShuffleMapStageSuccessfully( + stageId: Int, + attemptIdx: Int, + numShufflePartitions: Int): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { + case (task, idx) => + (Success, makeMapStatus("host" + ('A' + idx).toChar, numShufflePartitions)) + }.toSeq) + } + + /** + * Common code to get the next stage attempt, confirm it's the one we expect, and complete it + * with all FetchFailure. 
+ * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + * @param shuffleDep - The shuffle dependency of the stage with a fetch failure + */ + private def completeNextStageWithFetchFailure( + stageId: Int, + attemptIdx: Int, + shuffleDep: ShuffleDependency[_, _, _]): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) => + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null) + }.toSeq) + } + + /** + * Common code to get the next result stage attempt, confirm it's the one we expect, and + * complete it with a success where we return 42. + * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + */ + private def completeNextResultStageWithSuccess(stageId: Int, attemptIdx: Int): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage]) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map(_ => (Success, 42)).toSeq) + } + + /** + * In this test, we simulate a job where many tasks in the same stage fail. We want to show + * that many fetch failures inside a single stage attempt do not trigger an abort + * on their own, but only when there are enough failing stage attempts. + */ + test("Single stage fetch failure should not abort the stage.") { + setupStageAbortTest(sc) + + val parts = 8 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep)) + submit(reduceRdd, (0 until parts).toArray) + + completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = parts) + + completeNextStageWithFetchFailure(1, 0, shuffleDep) + + // Resubmit and confirm that now all is well + scheduler.resubmitFailedStages() + + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Complete stage 0 and then stage 1 with a "42" + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts) + completeNextResultStageWithSuccess(1, 1) + + // Confirm job finished succesfully + sc.listenerBus.waitUntilEmpty(1000) + assert(ended === true) + assert(results === (0 until parts).map { idx => idx -> 42 }.toMap) + assertDataStructuresEmpty() + } + + /** + * In this test we simulate a job failure where the first stage completes successfully and + * the second stage fails due to a fetch failure. Multiple successive fetch failures of a stage + * trigger an overall job abort to avoid endless retries. + */ + test("Multiple consecutive stage fetch failures should lead to job being aborted.") { + setupStageAbortTest(sc) + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + // Complete all the tasks for the current attempt of stage 0 successfully + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + // Now we should have a new taskSet, for a new attempt of stage 1. 
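[Editor's note, not part of the patch] To make the intent of the stage-abort tests in this hunk easier to follow, here is a minimal standalone Scala model of the policy they assert: a job is aborted only after enough consecutive failed attempts of the same stage, and a successful attempt in between resets the count. All names here are illustrative, and the limit of 4 is an assumption inferred from how the surrounding tests use Stage.MAX_CONSECUTIVE_FETCH_FAILURES.

object ConsecutiveFailureModel {
  // Assumed limit, mirroring how the tests above use Stage.MAX_CONSECUTIVE_FETCH_FAILURES.
  val MaxConsecutiveFailures = 4

  final case class StageState(consecutiveFailures: Int = 0) {
    def afterSuccess: StageState = copy(consecutiveFailures = 0)
    def afterFetchFailure: StageState = copy(consecutiveFailures = consecutiveFailures + 1)
    def shouldAbort: Boolean = consecutiveFailures >= MaxConsecutiveFailures
  }

  def main(args: Array[String]): Unit = {
    // Two failures, a success, then two more failures: never four in a row, so no abort.
    val interleaved = StageState()
      .afterFetchFailure.afterFetchFailure
      .afterSuccess
      .afterFetchFailure.afterFetchFailure
    assert(!interleaved.shouldAbort)

    // Four consecutive failures with no success in between does abort.
    val consecutive =
      (1 to MaxConsecutiveFailures).foldLeft(StageState())((s, _) => s.afterFetchFailure)
    assert(consecutive.shouldAbort)
  }
}

The "Failures in different stages" and "Non-consecutive stage failures" tests that follow are the two sides of this model: the counter is kept per stage, and it resets whenever that stage succeeds.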
+ // Fail all these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDep) + + // this will trigger a resubmission of stage 0, since we've lost some of its + // map output, for the next iteration through the loop + scheduler.resubmitFailedStages() + + if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + } else { + // Stage should have been aborted and removed from running stages + assertDataStructuresEmpty() + sc.listenerBus.waitUntilEmpty(1000) + assert(ended) + jobResult match { + case JobFailed(reason) => + assert(reason.getMessage.contains("ResultStage 1 () has failed the maximum")) + case other => fail(s"expected JobFailed, not $other") + } + } + } + } + + /** + * In this test, we create a job with two consecutive shuffles, and simulate 2 failures for each + * shuffle fetch. In total In total, the job has had four failures overall but not four failures + * for a particular stage, and as such should not be aborted. + */ + test("Failures in different stages should not trigger an overall abort") { + setupStageAbortTest(sc) + + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + + // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations, + // stage 2 fails. + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + // Complete all the tasks for the current attempt of stage 0 successfully + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) { + // Now we should have a new taskSet, for a new attempt of stage 1. + // Fail all these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) + } else { + completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1) + + // Fail stage 2 + completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2, + shuffleDepTwo) + } + + // this will trigger a resubmission of stage 0, since we've lost some of its + // map output, for the next iteration through the loop + scheduler.resubmitFailedStages() + } + + completeShuffleMapStageSuccessfully(0, 4, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1) + + // Succeed stage2 with a "42" + completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES/2) + + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty() + } + + /** + * In this test we demonstrate that only consecutive failures trigger a stage abort. A stage may + * fail multiple times, succeed, then fail a few more times (because its run again by downstream + * dependencies). The total number of failed attempts for one stage will go over the limit, + * but that doesn't matter, since they have successes in the middle. 
+ */ + test("Non-consecutive stage failures don't trigger abort") { + setupStageAbortTest(sc) + + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + + // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times. + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + // Make each task in stage 0 success + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + // Now we should have a new taskSet, for a new attempt of stage 1. + // Fail these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) + + scheduler.resubmitFailedStages() + + // Confirm we have not yet aborted + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + } + + // Rerun stage 0 and 1 to step through the task set + completeShuffleMapStageSuccessfully(0, 3, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 3, numShufflePartitions = 1) + + // Fail stage 2 so that stage 1 is resubmitted when we call scheduler.resubmitFailedStages() + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) + + scheduler.resubmitFailedStages() + + // Rerun stage 0 to step through the task set + completeShuffleMapStageSuccessfully(0, 4, numShufflePartitions = 2) + + // Now again, fail stage 1 (up to MAX_FAILURES) but confirm that this doesn't trigger an abort + // since we succeeded in between. + completeNextStageWithFetchFailure(1, 4, shuffleDepOne) + + scheduler.resubmitFailedStages() + + // Confirm we have not yet aborted + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Next, succeed all and confirm output + // Rerun stage 0 + 1 + completeShuffleMapStageSuccessfully(0, 5, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 5, numShufflePartitions = 1) + + // Succeed stage 2 and verify results + completeNextResultStageWithSuccess(2, 1) + + assertDataStructuresEmpty() + sc.listenerBus.waitUntilEmpty(1000) + assert(ended === true) + assert(results === Map(0 -> 42)) + } + test("trivial shuffle with multiple fetch failures") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -520,8 +785,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(CompletionEvent( @@ -547,33 +812,205 @@ class DAGSchedulerSuite assert(sparkListener.failedStages.size == 1) } + /** + * This tests the case where another FetchFailed comes in while the map stage is getting + * re-run. 
+ */ + test("late fetch failures don't cause multiple concurrent attempts for the same map stage") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + // The MapOutputTracker should know about both map output locations. + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === + HashSet("hostA", "hostB")) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.failedStages.contains(1)) + + // Trigger resubmission of the failed map stage. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // Another attempt for the map stage should have been submitted, resulting in 2 total attempts. + assert(countSubmittedMapStageAttempts() === 2) + + // The second ResultTask fails, with a fetch failure for the output from the second mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(1), + FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + + // Another ResubmitFailedStages event should not result in another attempt for the map + // stage being run concurrently. + // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it + // shouldn't effect anything -- our calling it just makes *SURE* it gets called between the + // desired event and our check. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 2) + + } + + /** + * This tests the case where a late FetchFailed comes in after the map stage has finished getting + * retried and a new reduce stage starts running. + */ + test("extremely late fetch failures don't cause multiple concurrent attempts for " + + "the same stage") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + def countSubmittedReduceStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == 1) + } + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == 0) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + // Complete the map stage. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + + // The reduce stage should have been submitted. 
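[Editor's note, not part of the patch] The two "late fetch failure" tests in this hunk pin down the same invariant: a FetchFailed reported by a stage attempt that has already been superseded must not spawn yet another concurrent attempt. A rough standalone sketch of that bookkeeping follows; it is illustrative only and not the actual DAGScheduler data structures.

object StaleAttemptModel {
  // Tracks which attempt of one stage is currently allowed to trigger a resubmission.
  final class StageTracker {
    private var activeAttempt = 0
    private var submittedAttempts = 1 // attempt 0 was submitted when the stage first ran

    // Returns true if this failure report triggers a new attempt.
    def onFetchFailure(reportingAttempt: Int): Boolean =
      if (reportingAttempt == activeAttempt) {
        activeAttempt += 1
        submittedAttempts += 1
        true
      } else {
        false // stale report from an attempt that was already superseded
      }

    def attempts: Int = submittedAttempts
  }

  def main(args: Array[String]): Unit = {
    val stage = new StageTracker
    assert(stage.onFetchFailure(reportingAttempt = 0))  // first failure: resubmit
    assert(!stage.onFetchFailure(reportingAttempt = 0)) // late report from attempt 0: ignored
    assert(stage.attempts == 2)                         // still only two attempts in total
  }
}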
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedReduceStageAttempts() === 1) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + + // Trigger resubmission of the failed map stage and finish the re-started map task. + runEvent(ResubmitFailedStages) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + + // Because the map stage finished, another attempt for the reduce stage should have been + // submitted, resulting in 2 total attempts for each the map and the reduce stage. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 2) + assert(countSubmittedReduceStageAttempts() === 2) + + // A late FetchFailed arrives from the second task in the original reduce stage. + runEvent(CompletionEvent( + taskSets(1).tasks(1), + FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + + // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because + // the FetchFailed should have been ignored + runEvent(ResubmitFailedStages) + + // The FetchFailed from the original reduce stage should be ignored. + assert(countSubmittedMapStageAttempts() === 2) + } + test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) + // pretend we were told hostA went away val oldEpoch = mapOutputTracker.getEpoch runEvent(ExecutorLost("exec-hostA")) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) + + // now start completing some tasks in the shuffle map stage, under different hosts + // and epochs, and make sure scheduler updates its state correctly val taskSet = taskSets(0) + val shuffleStage = scheduler.stageIdToStage(taskSet.stageId).asInstanceOf[ShuffleMapStage] + assert(shuffleStage.numAvailableOutputs === 0) + // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 0) + + // should work because it's a non-failed host (so the available map outputs will increase) + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostB", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 1) + // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - // should work because it's a new epoch + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs 
=== 1) + + // should work because it's a new epoch, which will increase the number of available map + // outputs, and also finish the stage taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + runEvent(CompletionEvent( + taskSet.tasks(1), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 2) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + + // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -668,8 +1105,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -713,7 +1150,7 @@ class DAGSchedulerSuite submit(finalRdd, Array(0)) cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - // complete stage 2 + // complete stage 0 complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) @@ -721,7 +1158,7 @@ class DAGSchedulerSuite complete(taskSets(1), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - // pretend stage 0 failed because hostA went down + // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: @@ -748,40 +1185,52 @@ class DAGSchedulerSuite // Run this on executors sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } - // Run this within a local thread - sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) - - // Make sure we can still run local commands as well as cluster commands. 
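[Editor's note, not part of the patch] The epoch assertions just above are easier to see in isolation. The following standalone sketch (not Spark code, all names are made up) models the rule the "ignore late map task completions" test asserts: losing an executor bumps the epoch, completions from that host carrying an older epoch are dropped, and completions from healthy hosts or with a newer epoch are kept.

object EpochFilterModel {
  final class OutputTracker {
    private var epoch = 0L
    private val hostFailedAtEpoch = scala.collection.mutable.Map.empty[String, Long]
    var registeredOutputs = 0

    def executorLost(host: String): Unit = { epoch += 1; hostFailedAtEpoch(host) = epoch }
    def currentEpoch: Long = epoch

    // A completion is ignored only if its host has since failed and the completion carries
    // an epoch older than that failure.
    def registerMapOutput(host: String, taskEpoch: Long): Unit =
      if (hostFailedAtEpoch.get(host).forall(taskEpoch >= _)) registeredOutputs += 1
  }

  def main(args: Array[String]): Unit = {
    val tracker = new OutputTracker
    val oldEpoch = tracker.currentEpoch
    tracker.executorLost("hostA")

    tracker.registerMapOutput("hostA", oldEpoch)             // too old: ignored
    tracker.registerMapOutput("hostB", oldEpoch)             // non-failed host: accepted
    tracker.registerMapOutput("hostA", tracker.currentEpoch) // new epoch: accepted
    assert(tracker.registeredOutputs == 2)
  }
}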
+ // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") { - val e1 = intercept[SparkDriverExecutionException] { + val e = intercept[SparkDriverExecutionException] { val rdd = sc.parallelize(1 to 10, 2) sc.runJob[Int, Int]( rdd, (context: TaskContext, iter: Iterator[Int]) => iter.size, - Seq(0), - allowLocal = true, + Seq(0, 1), (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) } - assert(e1.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) + assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - val e2 = intercept[SparkDriverExecutionException] { - val rdd = sc.parallelize(1 to 10, 2) - sc.runJob[Int, Int]( - rdd, - (context: TaskContext, iter: Iterator[Int]) => iter.size, - Seq(0, 1), - allowLocal = false, - (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) + // Make sure we can still run commands + assert(sc.parallelize(1 to 10, 2).count() === 10) + } + + test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[DAGSchedulerSuiteDummyException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPartitions: Array[Partition] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.reduceByKey(_ + _, 1).count() } - assert(e2.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands + assert(sc.parallelize(1 to 10, 2).count() === 10) + } + + test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[SparkException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPreferredLocations(split: Partition): Seq[String] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.count() + } + assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName)) + + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("accumulator not calculated for resubmitted result stage") { @@ -809,15 +1258,15 @@ class DAGSchedulerSuite submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) assertLocations(reduceTaskSet, Seq(Seq("hostA"))) complete(reduceTaskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("reduce task locality preferences should only include machines with largest map outputs") { @@ -841,7 +1290,268 @@ class DAGSchedulerSuite assertLocations(reduceTaskSet, Seq(hosts)) complete(reduceTaskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() + } + + test("stages with both narrow and shuffle dependencies use narrow ones for locality") { + // Create an RDD that has both a shuffle dependency and a narrow dependency (e.g. 
for a join) + val rdd1 = new MyRDD(sc, 1, Nil) + val rdd2 = new MyRDD(sc, 1, Nil, locations = Seq(Seq("hostB"))) + val shuffleDep = new ShuffleDependency(rdd1, null) + val narrowDep = new OneToOneDependency(rdd2) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep, narrowDep)) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"))) + + // Reducer should run where RDD 2 has preferences, even though though it also has a shuffle dep + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(Seq("hostB"))) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty() + } + + test("Spark exceptions should include call site in stack trace") { + val e = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map { _ => throw new RuntimeException("uh-oh!") }.count() + } + + // Does not include message, ONLY stack trace. + val stackTraceString = e.getStackTraceString + + // should actually include the RDD operation that invoked the method: + assert(stackTraceString.contains("org.apache.spark.rdd.RDD.count")) + + // should include the FunSuite setup: + assert(stackTraceString.contains("org.scalatest.FunSuite")) + } + + test("simple map stage submission") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + assert(results.size === 0) // No results yet + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage; it should directly do the reduce + submit(reduceRdd, Array(0)) + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with reduce stage also depending on the data") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit the map stage by itself + submitMapStage(shuffleDep) + + // Submit a reduce job that depends on this map stage + submit(reduceRdd, Array(0)) + + // Complete tasks for the map stage + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + + // Complete tasks for the reduce stage + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with fetch failure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + 
(Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage, but where one reduce will fail a fetch + submit(reduceRdd, Array(0, 1)) + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch + // from, then TaskSet 3 will run the reduce stage + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + + // Run another reduce job without a failure; this should just work + submit(reduceRdd, Array(0, 1)) + complete(taskSets(4), Seq( + (Success, 44), + (Success, 45))) + assert(results === Map(0 -> 44, 1 -> 45)) + results.clear() + assertDataStructuresEmpty() + + // Resubmit the map stage; this should also just work + submitMapStage(shuffleDep) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + } + + /** + * In this test, we have three RDDs with shuffle dependencies, and we submit map stage jobs + * that are waiting on each one, as well as a reduce job on the last one. We test that all of + * these jobs complete even if there are some fetch failures in both shuffles. + */ + test("map stage submission with multiple shared stages and failures") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1)) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + val rdd3 = new MyRDD(sc, 2, List(dep2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + val listener3 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + submit(rdd3, Array(0, 1), listener = listener3) + + // Complete the first stage + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.size)), + (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting the second stage, show a fetch failure + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + assert(listener2.results.size === 0) // Second stage listener should not have a result yet + + // Stage 0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(listener2.results.size === 0) // Second stage listener should still not have a result + + // Stage 1 should now be running as task set 3; make its first task succeed + assert(taskSets(3).stageId === 1) + complete(taskSets(3), Seq( + (Success, makeMapStatus("hostB", rdd2.partitions.size)), + (Success, 
makeMapStatus("hostD", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(listener2.results.size === 1) + + // Finally, the reduce job should be running as task set 4; make it see a fetch failure, + // then make it run again and succeed + assert(taskSets(4).stageId === 2) + complete(taskSets(4), Seq( + (Success, 52), + (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 + assert(taskSets(5).stageId === 1) + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + complete(taskSets(6), Seq( + (Success, 53))) + assert(listener3.results === Map(0 -> 52, 1 -> 53)) + assertDataStructuresEmpty() + } + + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from that executor. We want to make sure the stage is not reported + * as done until all tasks have completed. + */ + test("map stage submission with executor failure late map task completions") { + val shuffleMapRdd = new MyRDD(sc, 3, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + + submitMapStage(shuffleDep) + + val oldTaskSet = taskSets(0) + runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Pretend host A was lost + val oldEpoch = mapOutputTracker.getEpoch + runEvent(ExecutorLost("exec-hostA")) + val newEpoch = mapOutputTracker.getEpoch + assert(newEpoch > oldEpoch) + + // Suppose we also get a completed event from task 1 on the same host; this should be ignored + runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // A completion from another task should work because it's a non-failed host + runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Now complete tasks in the second task set + val newTaskSet = taskSets(1) + assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA + runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 1) // Map stage job should now finally be complete + assertDataStructuresEmpty() + + // Also test that a reduce stage using this shuffled data can immediately run + val reduceRDD = new MyRDD(sc, 2, List(shuffleDep)) + results.clear() + submit(reduceRDD, Array(0, 1)) + complete(taskSets(2), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index f681f21b6205..5cb2d4225d28 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -180,7 +180,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // into SPARK-6688. val conf = getLoggingConf(testDirPath, compressionCodec) .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") - val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,2,1024]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0a7cb69416a0..f7e16af9d3a9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,9 +19,11 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { +class FakeTask( + stageId: Int, + prefLocs: Seq[TaskLocation] = Nil) + extends Task[Int](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Int = 0 - override def preferredLocations: Seq[TaskLocation] = prefLocs } @@ -31,12 +33,16 @@ object FakeTask { * locations for each task (given as varargs) if this sequence is not empty. */ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createTaskSet(numTasks, 0, prefLocs: _*) + } + + def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, 0, 0, 0, null) + new TaskSet(tasks, 0, stageAttemptId, 0, null) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 9b92f8de5675..f33324792495 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. */ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala new file mode 100644 index 000000000000..1ae5b030f083 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.{Span, Seconds} + +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.util.Utils + +/** + * Integration tests for the OutputCommitCoordinator. + * + * See also: [[OutputCommitCoordinatorSuite]] for unit tests that use mocks. + */ +class OutputCommitCoordinatorIntegrationSuite + extends SparkFunSuite + with LocalSparkContext + with Timeouts { + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .set("master", "local[2,4]") + .set("spark.speculation", "true") + .set("spark.hadoop.mapred.output.committer.class", + classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) + sc = new SparkContext("local[2, 4]", "test", conf) + } + + test("exception thrown in OutputCommitter.commitTask()") { + // Regression test for SPARK-10381 + failAfter(Span(60, Seconds)) { + val tempDir = Utils.createTempDir() + try { + sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + val ctx = TaskContext.get() + if (ctx.attemptNumber < 1) { + throw new java.io.FileNotFoundException("Intentional exception") + } + super.commitTask(context) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index a9036da9cc93..6d08d7c5b7d2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -63,6 +63,9 @@ import scala.language.postfixOps * was not in SparkHadoopWriter, the tests would still pass because only one of the * increments would be captured even though the commit in both tasks was executed * erroneously. + * + * See also: [[OutputCommitCoordinatorIntegrationSuite]] for integration tests that do + * not use mocks. 
*/ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { @@ -134,14 +137,14 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only one of two duplicate commit tasks should commit") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } test("If commit fails, if task is retried it should not be locked, and will succeed.") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } @@ -164,27 +167,28 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 - val partition: Long = 2 - val authorizedCommitter: Long = 3 - val nonAuthorizedCommitter: Long = 100 + val partition: Int = 2 + val authorizedCommitter: Int = 3 + val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage) - assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + + assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) // New tasks should still not be able to commit because the authorized committer has not failed assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) // A new task should now be allowed to become the authorized committer assert( - outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) // There can only be one authorized committer assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index ff3fa95ec32a..103fc19369c9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -52,8 +52,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) + // scalastyle:off println 
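[Editor's note, not part of the patch] The OutputCommitCoordinator assertions a few hunks above reduce to a small arbitration rule: the first attempt to request commit rights for a (stage, partition) wins, later attempts are refused, and the slot is released only if the authorized attempt itself fails. A standalone sketch of that rule, illustrative rather than Spark's implementation:

object CommitArbiterModel {
  final class CommitArbiter {
    private val authorized = scala.collection.mutable.Map.empty[(Int, Int), Int]

    def canCommit(stage: Int, partition: Int, attemptNumber: Int): Boolean =
      authorized.get((stage, partition)) match {
        case Some(winner) => winner == attemptNumber
        case None => authorized((stage, partition)) = attemptNumber; true
      }

    def taskFailed(stage: Int, partition: Int, attemptNumber: Int): Unit =
      if (authorized.get((stage, partition)).contains(attemptNumber)) {
        authorized.remove((stage, partition))
      }
  }

  def main(args: Array[String]): Unit = {
    val arbiter = new CommitArbiter
    assert(arbiter.canCommit(stage = 1, partition = 2, attemptNumber = 3))    // first wins
    assert(!arbiter.canCommit(stage = 1, partition = 2, attemptNumber = 100)) // others refused
    arbiter.taskFailed(1, 2, attemptNumber = 100)  // non-authorized failure changes nothing
    assert(!arbiter.canCommit(1, 2, attemptNumber = 101))
    arbiter.taskFailed(1, 2, attemptNumber = 3)    // authorized committer fails: slot freed
    assert(arbiter.canCommit(1, 2, attemptNumber = 102))
  }
}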
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) + // scalastyle:on println writer.close() val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath) @@ -100,7 +102,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { fileSystem.mkdirs(logDirPath) val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName) - val sc = new SparkContext("local-cluster[2,1,512]", "Test replay", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "Test replay", conf) // Run a few jobs sc.parallelize(1 to 100, 1).count() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 651295b7344c..a9652d7e7d0b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.scalatest.Matchers @@ -188,7 +188,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) - sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true) + sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -365,10 +365,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + classOf[BasicJobCounter].getName) sc = new SparkContext(conf) - sc.listenerBus.listeners.collect { case x: BasicJobCounter => x}.size should be (1) - sc.listenerBus.listeners.collect { - case x: ListenerThatAcceptsSparkConf => x - }.size should be (1) + sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) + sc.listenerBus.listeners.asScala + .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d97fba00976d..d1e23ed527ff 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -34,7 +34,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext val WAIT_TIMEOUT_MILLIS = 10000 before { - sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite") + sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite") } test("SparkListener sends executor added message") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 7c1adc1aef1b..450ab7b9fe92 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -24,11 +24,27 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.RDD -import 
org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.metrics.source.JvmSource class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { + test("provide metrics sources") { + val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile + val conf = new SparkConf(loadDefaults = false) + .set("spark.metrics.conf", filePath) + sc = new SparkContext("local", "test", conf) + val rdd = sc.makeRDD(1 to 1) + val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => { + tc.getMetricsSources("jvm").count { + case source: JvmSource => true + case _ => false + } + }).sum + assert(result > 0) + } + test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false sc = new SparkContext("local", "test") @@ -41,16 +57,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() + val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) val task = new ResultTask[String, String]( - 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { - task.run(0, 0) + task.run(0, 0, null) } assert(TaskContextSuite.completed === true) } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index a6d5232feb8d..c2edd4c317d6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. - val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
- val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L assert(taskDescriptions.map(_.executorId) === Seq("executor0")) } + test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + val dagScheduler = new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + taskScheduler.setDAGScheduler(dagScheduler) + val attempt1 = FakeTask.createTaskSet(1, 0) + val attempt2 = FakeTask.createTaskSet(1, 1) + taskScheduler.submitTasks(attempt1) + intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) } + + // OK to submit multiple if previous attempts are all zombie + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true + taskScheduler.submitTasks(attempt2) + val attempt3 = FakeTask.createTaskSet(1, 2) + intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) } + taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId) + .get.isZombie = true + taskScheduler.submitTasks(attempt3) + } + + test("don't schedule more tasks after a taskset is zombie") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
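The invariant these new TaskSchedulerImplSuite tests pin down (SPARK-8103) is that a stage may have at most one non-zombie TaskSetManager at a time; a second attempt is only accepted once every earlier attempt has been marked zombie. A minimal, standalone sketch of that bookkeeping in plain Scala (ToyScheduler and ToyTaskSet are invented names for illustration, not Spark classes):

    import scala.collection.mutable

    final case class ToyTaskSet(stageId: Int, stageAttemptId: Int)

    final class ToyScheduler {
      // stageId -> (stageAttemptId -> isZombie)
      private val attempts = mutable.Map.empty[Int, mutable.Map[Int, Boolean]]

      def submit(ts: ToyTaskSet): Unit = {
        val byAttempt = attempts.getOrElseUpdate(ts.stageId, mutable.Map.empty)
        // The real scheduler throws IllegalStateException; the point here is the check itself.
        require(byAttempt.values.forall(identity),
          s"stage ${ts.stageId} already has a non-zombie attempt")
        byAttempt(ts.stageAttemptId) = false   // a new attempt starts out alive
      }

      def markZombie(stageId: Int, attemptId: Int): Unit =
        attempts(stageId)(attemptId) = true
    }

    val sched = new ToyScheduler
    sched.submit(ToyTaskSet(0, 0))
    // sched.submit(ToyTaskSet(0, 1))          // rejected: attempt 0 is still alive
    sched.markZombie(0, 0)
    sched.submit(ToyTaskSet(0, 1))             // accepted once attempt 0 is a zombie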
+ new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 1 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + // if we schedule another attempt for the same stage, it should get scheduled + val attempt2 = FakeTask.createTaskSet(10, 1) + + // submit attempt 2, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt2) + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions3.length) + val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + + test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 10 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get + mgr1.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + // submit attempt 2 + val attempt2 = FakeTask.createTaskSet(10, 1) + taskScheduler.submitTasks(attempt2) + + // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were + // already submitted, and then they finish) + taskScheduler.taskSetFinished(mgr1) + + // now with another resource offer, we should still schedule all the tasks in attempt2 + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions3.length) + + taskDescriptions3.foreach { task => + val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 0060f3396dcd..f0eadf240943 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.scheduler import java.util.Random -import scala.collection.mutable.ArrayBuffer +import scala.collection.Map import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: mutable.Map[Long, Any], + accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { taskScheduler.endedTasks(taskInfo.index) = reason @@ -47,7 +48,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) override def executorLost(execId: String) {} - override def taskSetFailed(taskSet: TaskSet, reason: String) { + override def taskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } } @@ -135,7 +139,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) @@ -330,7 +334,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Now mark host2 as dead sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2") + manager.executorLost("exec2", "host2", SlaveLost()) // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) @@ -500,13 +504,36 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") - manager.executorLost("execC", "host2") + manager.executorLost("execC", "host2", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) sched.removeExecutor("execD") - manager.executorLost("execD", "host1") + manager.executorLost("execD", "host1", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } + test("Executors are added but exit normally while running tasks") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execB")), + Seq(TaskLocation("host2", "execC")), + Seq()) + val manager = new TaskSetManager(sched, taskSet, 1, new ManualClock) + sched.addExecutor("execA", "host1") + manager.executorAdded() + sched.addExecutor("execC", "host2") + manager.executorAdded() + assert(manager.resourceOffer("exec1", "host1", ANY).isDefined) + sched.removeExecutor("execA") + manager.executorLost("execA", "host1", ExecutorExited(143, true, "Normal termination")) + assert(!sched.taskSetsFailed.contains(taskSet.id)) + assert(manager.resourceOffer("execC", "host2", ANY).isDefined) + sched.removeExecutor("execC") + 
manager.executorLost("execC", "host2", ExecutorExited(1, false, "Abnormal termination")) + assert(sched.taskSetsFailed.contains(taskSet.id)) + } + test("test RACK_LOCAL tasks") { // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") @@ -717,8 +744,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") - manager.executorLost("execA", "host1") - manager.executorLost("execB.2", "host2") + manager.executorLost("execA", "host1", SlaveLost()) + manager.executorLost("execB.2", "host2", SlaveLost()) clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala new file mode 100644 index 000000000000..525ee0d3bdc5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util +import java.util.Collections + +import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.Protos._ +import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.scalatest.mock.MockitoSugar +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite} + +class CoarseMesosSchedulerBackendSuite extends SparkFunSuite + with LocalSparkContext + with MockitoSugar + with BeforeAndAfter { + + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder() + .setValue(offerId).build()) + .setFrameworkId(FrameworkID.newBuilder() + .setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setHostname(s"host${slaveId}") + .build() + } + + private def createSchedulerBackend( + taskScheduler: TaskSchedulerImpl, + driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + val securityManager = mock[SecurityManager] + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { + override protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = driver + markRegistered() + } + backend.start() + backend + } + + var sparkConf: SparkConf = _ + + before { + sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + + sc = new SparkContext(sparkConf) + } + + test("mesos supports killing and limiting executors") { + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc) + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) + + val taskID0 = TaskID.newBuilder().setValue("0").build() + + backend.resourceOffers(driver, mesosOffers) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + // simulate the allocation manager down-scaling executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("s1/0"))) + verify(driver, times(1)).killTask(taskID0) + + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) + backend.resourceOffers(driver, mesosOffers2) + + verify(driver, times(1)) + .declineOffer(OfferID.newBuilder().setValue("o2").build()) + + // Verify we didn't launch any new executor + 
assert(backend.slaveIdsWithExecutors.size === 1) + + backend.doRequestTotalExecutors(2) + backend.resourceOffers(driver, mesosOffers2) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + assert(backend.slaveIdsWithExecutors.size === 2) + backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) + assert(backend.slaveIdsWithExecutors.size === 1) + } + + test("mesos supports killing and relaunching tasks with executors") { + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc) + 1024 + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + val offer1 = createOffer("o1", "s1", minMem, minCpu) + mesosOffers.add(offer1) + + val offer2 = createOffer("o2", "s1", minMem, 1); + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer1.getId)), + anyObject(), + anyObject[Filters]) + + // Simulate task killed, executor no longer running + val status = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue("0").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setState(TaskState.TASK_KILLED) + .build + + backend.statusUpdate(driver, status) + assert(!backend.slaveIdsWithExecutors.contains("s1")) + + mesosOffers.clear() + mesosOffers.add(offer2) + backend.resourceOffers(driver, mesosOffers) + assert(backend.slaveIdsWithExecutors.contains("s1")) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer2.getId)), + anyObject(), + anyObject[Filters]) + + verify(driver, times(1)).reviveOffers() + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala deleted file mode 100644 index e72285d03d3e..000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} - -class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { - test("MesosMemoryUtils should always override memoryOverhead when it's set") { - val sparkConf = new SparkConf - - val sc = mock[SparkContext] - when(sc.conf).thenReturn(sparkConf) - - // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 - when(sc.executorMemory).thenReturn(512) - assert(MemoryUtils.calculateTotalMemory(sc) === 896) - - // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 - when(sc.executorMemory).thenReturn(4096) - assert(MemoryUtils.calculateTotalMemory(sc) === 4505) - - // set memoryOverhead - sparkConf.set("spark.mesos.executor.memoryOverhead", "100") - assert(MemoryUtils.calculateTotalMemory(sc) === 4196) - sparkConf.set("spark.mesos.executor.memoryOverhead", "400") - assert(MemoryUtils.calculateTotalMemory(sc) === 4496) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 68df46a41ddc..c4dc56003120 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer -import java.util +import java.util.Arrays +import java.util.Collection import java.util.Collections +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -40,6 +42,38 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + test("Use configured mesosExecutor.cores for ExecutorInfo") { + val mesosExecutorCores = 3 + val conf = new SparkConf + conf.set("spark.mesos.mesosExecutor.cores", mesosExecutorCores.toString) + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + + val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val resources = Arrays.asList( + mesosSchedulerBackend.createResource("cpus", 4), + mesosSchedulerBackend.createResource("mem", 1024)) + // uri is null. 
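The createResource helper called just above is not shown in this diff; judging from the createOffer builder earlier in the patch, it presumably produces a scalar Mesos resource via the protobuf builders. A hedged sketch of what such a helper could look like (a guess at the shape, not the actual MesosSchedulerBackend method, and it assumes the org.apache.mesos Protos classes are on the classpath):

    import org.apache.mesos.Protos.{Resource, Value}

    // Build a scalar Mesos resource such as "cpus" = 4 or "mem" = 1024 (MB).
    def scalarResource(name: String, amount: Double, role: String = "*"): Resource =
      Resource.newBuilder()
        .setName(name)
        .setType(Value.Type.SCALAR)
        .setScalar(Value.Scalar.newBuilder().setValue(amount))
        .setRole(role)
        .build()

    val cpus = scalarResource("cpus", 4)
    val mem  = scalarResource("mem", 1024)

The "uri is null" comment above refers to the createExecutorInfo call that follows it.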
+ val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") + val executorResources = executorInfo.getResourcesList + val cpus = executorResources.asScala.find(_.getName.equals("cpus")).get.getScalar.getValue + + assert(cpus === mesosExecutorCores) + } + test("check spark-class location correctly") { val conf = new SparkConf conf.set("spark.mesos.executor.home" , "/mesos-home") @@ -60,14 +94,17 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val resources = Arrays.asList( + mesosSchedulerBackend.createResource("cpus", 4), + mesosSchedulerBackend.createResource("mem", 1024)) // uri is null. - val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") + val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") // uri exists. conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") - val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") + val (executorInfo1, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } @@ -93,7 +130,8 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val execInfo = backend.createExecutorInfo("mockExecutor") + val (execInfo, _) = backend.createExecutorInfo( + Arrays.asList(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) @@ -149,7 +187,9 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(sc.conf).thenReturn(new SparkConf) when(sc.listenerBus).thenReturn(listenerBus) - val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] @@ -157,8 +197,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi mesosOffers.add(createOffer(2, minMem - 1, minCpu)) mesosOffers.add(createOffer(3, minMem, minCpu)) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, @@ -174,7 +212,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), @@ -194,7 +232,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi ) verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId) verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId) - assert(capture.getValue.size() == 1) + assert(capture.getValue.size() === 1) 
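The verification above leans on Mockito's ArgumentCaptor: capture() stands in for the launchTasks argument at verify time, and getValue exposes what the code under test actually passed, so the captured TaskInfo collection can be asserted on afterwards. For reference, the same pattern in isolation, using a plain java.util.List as the mocked collaborator (Mockito on the classpath is assumed):

    import java.util.{List => JList}
    import org.mockito.ArgumentCaptor
    import org.mockito.Mockito.{mock, times, verify}

    val list = mock(classOf[JList[String]])
    list.add("hello")

    // Record the argument that add(...) was really invoked with, then inspect it.
    val captor = ArgumentCaptor.forClass(classOf[String])
    verify(list, times(1)).add(captor.capture())
    assert(captor.getValue == "hello")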
val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) val cpus = taskInfo.getResourcesList.get(0) @@ -214,4 +252,96 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi backend.resourceOffers(driver, mesosOffers2) verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId) } + + test("can handle multiple roles") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/path")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(new SparkConf) + when(sc.listenerBus).thenReturn(listenerBus) + + val id = 1 + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setRole("prod") + .setScalar(Scalar.newBuilder().setValue(500)) + builder.addResourcesBuilder() + .setName("cpus") + .setRole("prod") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(1)) + builder.addResourcesBuilder() + .setName("mem") + .setRole("dev") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(600)) + builder.addResourcesBuilder() + .setName("cpus") + .setRole("dev") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(2)) + val offer = builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) + .setHostname(s"host${id.toString}").build() + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(offer) + + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](1) + expectedWorkerOffers.append(new WorkerOffer( + mesosOffers.get(0).getSlaveId.getValue, + mesosOffers.get(0).getHostname, + 2 // Deducting 1 for executor + )) + + val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + when(taskScheduler.CPUS_PER_TASK).thenReturn(1) + + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) + when( + driver.launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + ).thenReturn(Status.valueOf(1)) + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + + assert(capture.getValue.size() === 1) + val taskInfo = capture.getValue.iterator().next() + assert(taskInfo.getName.equals("n1")) + assert(taskInfo.getResourcesCount === 1) + val cpusDev = taskInfo.getResourcesList.get(0) + assert(cpusDev.getName.equals("cpus")) + assert(cpusDev.getScalar.getValue.equals(1.0)) + assert(cpusDev.getRole.equals("dev")) + val executorResources = taskInfo.getExecutor.getResourcesList.asScala + assert(executorResources.exists { r => + r.getName.equals("mem") && r.getScalar.getValue.equals(484.0) && r.getRole.equals("prod") + }) + assert(executorResources.exists { r => + r.getName.equals("cpus") && r.getScalar.getValue.equals(1.0) && 
r.getRole.equals("prod") + }) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala new file mode 100644 index 000000000000..2eb43b731338 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.language.reflectiveCalls + +import org.apache.mesos.Protos.Value +import org.mockito.Mockito._ +import org.scalatest._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { + + // scalastyle:off structural.type + // this is the documented way of generating fixtures in scalatest + def fixture: Object {val sc: SparkContext; val sparkConf: SparkConf} = new { + val sparkConf = new SparkConf + val sc = mock[SparkContext] + when(sc.conf).thenReturn(sparkConf) + } + val utils = new MesosSchedulerUtils { } + // scalastyle:on structural.type + + test("use at-least minimum overhead") { + val f = fixture + when(f.sc.executorMemory).thenReturn(512) + utils.calculateTotalMemory(f.sc) shouldBe 896 + } + + test("use overhead if it is greater than minimum value") { + val f = fixture + when(f.sc.executorMemory).thenReturn(4096) + utils.calculateTotalMemory(f.sc) shouldBe 4505 + } + + test("use spark.mesos.executor.memoryOverhead (if set)") { + val f = fixture + when(f.sc.executorMemory).thenReturn(1024) + f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") + utils.calculateTotalMemory(f.sc) shouldBe 1536 + } + + test("parse a non-empty constraint string correctly") { + val expectedMap = Map( + "tachyon" -> Set("true"), + "zone" -> Set("us-east-1a", "us-east-1b") + ) + utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap) + } + + test("parse an empty constraint string correctly") { + utils.parseConstraintString("") shouldBe Map() + } + + test("throw an exception when the input is malformed") { + an[IllegalArgumentException] should be thrownBy + utils.parseConstraintString("tachyon;zone:us-east") + } + + test("empty values for attributes' constraints matches all values") { + val constraintsStr = "tachyon:" + val parsedConstraints = utils.parseConstraintString(constraintsStr) + + parsedConstraints shouldBe Map("tachyon" -> Set()) + + val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build() + val noTachyonOffer = Map("zone" -> zoneSet) + val tachyonTrueOffer = Map("tachyon" -> 
Value.Text.newBuilder().setValue("true").build()) + val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build()) + + utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false + utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true + } + + test("subset match is performed for set attributes") { + val supersetConstraint = Map( + "tachyon" -> Value.Text.newBuilder().setValue("true").build(), + "zone" -> Value.Set.newBuilder() + .addItem("us-east-1a") + .addItem("us-east-1b") + .addItem("us-east-1c") + .build()) + + val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c" + val parsedConstraints = utils.parseConstraintString(zoneConstraintStr) + + utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true + } + + test("less than equal match is performed on scalar attributes") { + val offerAttribs = Map("gpus" -> Value.Scalar.newBuilder().setValue(3).build()) + + val ltConstraint = utils.parseConstraintString("gpus:2") + val eqConstraint = utils.parseConstraintString("gpus:3") + val gtConstraint = utils.parseConstraintString("gpus:4") + + utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + } + + test("contains match is performed for range attributes") { + val offerAttribs = Map("ports" -> Value.Range.newBuilder().setBegin(7000).setEnd(8000).build()) + val ltConstraint = utils.parseConstraintString("ports:6000") + val eqConstraint = utils.parseConstraintString("ports:7500") + val gtConstraint = utils.parseConstraintString("ports:8002") + val multiConstraint = utils.parseConstraintString("ports:5000,7500,8300") + + utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(multiConstraint, offerAttribs) shouldBe true + } + + test("equality match is performed for text attributes") { + val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + + val trueConstraint = utils.parseConstraintString("tachyon:true") + val falseConstraint = utils.parseConstraintString("tachyon:false") + + utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala new file mode 100644 index 000000000000..bc9f3708ed69 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.io.{Output, Input} +import org.apache.avro.{SchemaBuilder, Schema} +import org.apache.avro.generic.GenericData.Record + +import org.apache.spark.{SparkFunSuite, SharedSparkContext} + +class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + val schema : Schema = SchemaBuilder + .record("testRecord").fields() + .requiredString("data") + .endRecord() + val record = new Record(schema) + record.put("data", "test data") + + test("schema compression and decompression") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema)))) + } + + test("record serialization and deserialization") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + + val outputStream = new ByteArrayOutputStream() + val output = new Output(outputStream) + genericSer.serializeDatum(record, output) + output.flush() + output.close() + + val input = new Input(new ByteArrayInputStream(outputStream.toByteArray)) + assert(genericSer.deserializeDatum(input) === record) + } + + test("uses schema fingerprint to decrease message size") { + val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema) + + val output = new Output(new ByteArrayOutputStream()) + + val beginningNormalPosition = output.total() + genericSerFull.serializeDatum(record, output) + output.flush() + val normalLength = output.total - beginningNormalPosition + + conf.registerAvroSchemas(schema) + val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema) + val beginningFingerprintPosition = output.total() + genericSerFinger.serializeDatum(record, output) + val fingerprintLength = output.total - beginningFingerprintPosition + + assert(fingerprintLength < normalLength) + } + + test("caches previously seen schemas") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + val compressedSchema = genericSer.compress(schema) + val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + + assert(compressedSchema.eq(genericSer.compress(schema))) + assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 329a2b6dad83..20f45670bc2b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -25,4 +25,22 @@ class JavaSerializerSuite extends SparkFunSuite { val instance = serializer.newInstance() instance.deserialize[JavaSerializer](instance.serialize(serializer)) } + + test("Deserialize object containing a primitive Class as attribute") { + val serializer = new JavaSerializer(new SparkConf()) + val instance = 
serializer.newInstance() + instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) + } +} + +private class ContainsPrimitiveClass extends Serializable { + val intClass = classOf[Int] + val longClass = classOf[Long] + val shortClass = classOf[Short] + val charClass = classOf[Char] + val doubleClass = classOf[Double] + val floatClass = classOf[Float] + val booleanClass = classOf[Boolean] + val byteClass = classOf[Byte] + val voidClass = classOf[Void] } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 63a8480c9b57..935a091f14f9 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -35,7 +35,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) - val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val original = Thread.currentThread.getContextClassLoader val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) SparkEnv.get.serializer.setDefaultClassLoader(loader) @@ -59,7 +59,9 @@ object KryoDistributedTest { class AppJarRegistrator extends KryoRegistrator { override def registerClasses(k: Kryo) { val classLoader = Thread.currentThread.getContextClassLoader + // scalastyle:off classforname k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + // scalastyle:on classforname } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 23a1fdb0f500..e428414cf6e8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -149,6 +150,36 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { mutable.HashMap(1->"one", 2->"two", 3->"three"))) } + test("Bug: SPARK-10251") { + val ser = new KryoSerializer(conf.clone.set("spark.kryo.registrationRequired", "true")) + .newInstance() + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check((1, 3)) + check(Array((1, 3))) + check(List((1, 3))) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(1 -> 1) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1->"one", 2->"two", 3->"three"))) + } + test("ranges") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { 
@@ -173,7 +204,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { test("asJavaIterable") { // Serialize a collection wrapped by asJavaIterable val ser = new KryoSerializer(conf).newInstance() - val a = ser.serialize(scala.collection.convert.WrapAsJava.asJavaIterable(Seq(12345))) + val a = ser.serialize(Seq(12345).asJava) val b = ser.deserialize[java.lang.Iterable[Int]](a) assert(b.iterator().next() === 12345) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala new file mode 100644 index 000000000000..4d5f599fb12a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import org.apache.spark._ + +case class KeyClass() + +case class ValueClass() + +case class CombinerClass() + +class ShuffleDependencySuite extends SparkFunSuite with LocalSparkContext { + + val conf = new SparkConf(loadDefaults = false) + + test("key, value, and combiner classes correct in shuffle dependency without aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .groupByKey() + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!dep.mapSideCombine, "Test requires that no map-side aggregator is defined") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + } + + test("key, value, and combiner classes available in shuffle dependency with aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .aggregateByKey(CombinerClass())({ case (a, b) => a }, { case (a, b) => a }) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.mapSideCombine && dep.aggregator.isDefined, "Test requires map-side aggregation") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + assert(dep.combinerClassName == Some(classOf[CombinerClass].getName)) + } + + test("combineByKey null combiner class tag handled correctly") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .combineByKey((v: ValueClass) => v, + (c: AnyRef, v: ValueClass) => c, + (c1: AnyRef, c2: AnyRef) => c1) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == 
classOf[ValueClass].getName) + assert(dep.combinerClassName == None) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 96778c9ebafb..6d45b1a101be 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,27 +17,40 @@ package org.apache.spark.shuffle +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.CountDownLatch -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { + + val nextTaskAttemptId = new AtomicInteger() + /** Launch a thread with the given body block and return it. */ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { override def run() { - body + try { + val taskAttemptId = nextTaskAttemptId.getAndIncrement + val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) + when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) + TaskContext.setTaskContext(mockTaskContext) + body + } finally { + TaskContext.unset() + } } } thread.start() thread } - test("single thread requesting memory") { - val manager = new ShuffleMemoryManager(1000L) + test("single task requesting memory") { + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) assert(manager.tryToAcquire(100L) === 100L) assert(manager.tryToAcquire(400L) === 400L) @@ -50,7 +63,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(manager.tryToAcquire(300L) === 300L) assert(manager.tryToAcquire(300L) === 200L) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() assert(manager.tryToAcquire(1000L) === 1000L) assert(manager.tryToAcquire(100L) === 0L) } @@ -59,7 +72,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // Two threads request 500 bytes first, wait for each other to get it, and then request // 500 more; we should immediately return 0 as both are now at 1 / N - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Result1 = -1L @@ -107,11 +120,11 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } - test("threads cannot grow past 1 / N") { - // Two threads request 250 bytes first, wait for each other to get it, and then request + test("tasks cannot grow past 1 / N") { + // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Result1 = -1L @@ -158,12 +171,12 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(state.t2Result2 === 250L) } - test("threads can block to get at least 1 / 2N memory") { + test("tasks can block to get at least 1 / 2N memory") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. 
It sleeps // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Requested = false @@ -224,11 +237,11 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("releaseMemoryForThisThread") { + test("releaseMemoryForThisTask") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Requested = false @@ -251,9 +264,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise + // sure the other task blocks for some time otherwise Thread.sleep(300) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() } val t2 = startThread("t2") { @@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { t2.join() } - // Both threads should've been able to acquire their memory; the second one will have waited + // Both tasks should've been able to acquire their memory; the second one will have waited // until the first one acquired 1000 bytes and then released all of it state.synchronized { assert(state.t1Result === 1000L, "t1 could not allocate memory") @@ -293,8 +306,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("threads should not be granted a negative size") { - val manager = new ShuffleMemoryManager(1000L) + test("tasks should not be granted a negative size") { + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) manager.tryToAcquire(700L) val latch = new CountDownLatch(1) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala new file mode 100644 index 000000000000..05b3afef5b83 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.hash + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. + * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). + */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size(): Long = underlyingBuffer.size() + override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a HashShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mappers (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. 
+ val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + val mapOutputTracker = mock(classOf[MapOutputTracker]) + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn { + // Test a scenario where all data is local, to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) + } + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + } + + // Create a mocked shuffle handle to pass into HashShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new HashShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + TaskContext.empty(), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. 
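// Editor's aside: the sketch below is illustrative and not part of this patch. It reuses the
// RecordingManagedBuffer wrapper and the ByteBuffer/NioManagedBuffer imports from this file to
// show the counting behaviour the assertions that follow rely on: every delegated retain() and
// release() bumps a counter, so balanced usage leaves the two counters equal.
val sketchBuffer = new RecordingManagedBuffer(
  new NioManagedBuffer(ByteBuffer.wrap(Array[Byte](1, 2, 3))))
sketchBuffer.retain()
sketchBuffer.release()
assert(sketchBuffer.callsToRetain === 1 && sketchBuffer.callsToRelease === 1)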
+ buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 542f8f45125a..cc7342f1ecd7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte any[SerializerInstance], anyInt(), any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[BlockObjectWriter] { - override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + )).thenAnswer(new Answer[DiskBlockObjectWriter] { + override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments new DiskBlockObjectWriter( args(0).asInstanceOf[BlockId], diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 0f5ba46f69c2..eb5af70d57ae 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,10 +26,10 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -38,7 +38,7 @@ import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { - private val conf = new SparkConf(false) + private val conf = new SparkConf(false).set("spark.app.id", "test") var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) @@ -59,7 +59,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bcee901f5dd5..34bb4952e724 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -30,10 +30,10 @@ import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import 
org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -44,9 +44,10 @@ import org.apache.spark.util._ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - private val conf = new SparkConf(false) + private val conf = new SparkConf(false).set("spark.app.id", "test") var store: BlockManager = null var store2: BlockManager = null + var store3: BlockManager = null var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") @@ -65,7 +66,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") @@ -99,6 +100,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store2.stop() store2 = null } + if (store3 != null) { + store3.stop() + store3 = null + } rpcEnv.shutdown() rpcEnv.awaitTermination() rpcEnv = null @@ -443,6 +448,38 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } + test("SPARK-9591: getRemoteBytes from another location when Exception throw") { + val origTimeoutOpt = conf.getOption("spark.network.timeout") + try { + conf.set("spark.network.timeout", "2s") + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store3 = makeBlockManager(8000, "executor3") + val list1 = List(new Array[Byte](4000)) + store2.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + var list1Get = store.getRemoteBytes("list1") + assert(list1Get.isDefined, "list1Get expected to be fetched") + // the block manager on executor2 exits + store2.stop() + store2 = null + list1Get = store.getRemoteBytes("list1") + // the `list1` block can still be fetched from the remaining location + assert(list1Get.isDefined, "list1Get expected to be fetched") + store3.stop() + store3 = null + // an exception is thrown because there are no locations left + intercept[BlockFetchException] { + list1Get = store.getRemoteBytes("list1") + } + } finally { + origTimeoutOpt match { + case Some(t) => conf.set("spark.network.timeout", t) + case None => conf.remove("spark.network.timeout") + } + } + } + test("in-memory LRU storage") { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) @@ -782,7 +819,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error.
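// Editor's aside (illustrative, not part of this patch): the SPARK-9591 test above temporarily
// overrides "spark.network.timeout" and restores the previous value in a finally block. The same
// save/override/restore pattern can be wrapped in a small helper; withConfValue is a hypothetical
// name used only to make the shape explicit:
//
//   def withConfValue(conf: SparkConf, key: String, value: String)(body: => Unit): Unit = {
//     val saved = conf.getOption(key)
//     conf.set(key, value)
//     try body finally {
//       saved match {
//         case Some(v) => conf.set(key, v)
//         case None => conf.remove(key)
//       }
//     }
//   }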
- val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -796,7 +833,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Make sure get a1 doesn't hang and returns None. failAfter(1 second) { - assert(store.getSingle("a1") == None, "a1 should not be in store") + assert(store.getSingle("a1").isEmpty, "a1 should not be in store") } } @@ -1004,32 +1041,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Reserve - memoryStore.reserveUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 100) - memoryStore.reserveUnrollMemoryForThisThread(200) - assert(memoryStore.currentUnrollMemoryForThisThread === 300) - memoryStore.reserveUnrollMemoryForThisThread(500) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) - memoryStore.reserveUnrollMemoryForThisThread(1000000) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted + memoryStore.reserveUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + memoryStore.reserveUnrollMemoryForThisTask(200) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + memoryStore.reserveUnrollMemoryForThisTask(500) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 700) - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 600) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisThread(4400) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) - memoryStore.reserveUnrollMemoryForThisThread(20000) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted + memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again - memoryStore.releaseUnrollMemoryForThisThread(1000) - assert(memoryStore.currentUnrollMemoryForThisThread === 4000) - memoryStore.releaseUnrollMemoryForThisThread() // release all - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) } /** @@ -1060,24 +1097,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) val 
memoryStore = store.memoryStore val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with all the space in the world. This should succeed and return an array. var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) - memoryStore.releasePendingUnrollMemoryForThisThread() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) droppedBlocks.clear() - memoryStore.releasePendingUnrollMemoryForThisThread() + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. @@ -1085,7 +1122,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) droppedBlocks.clear() @@ -1099,7 +1136,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with plenty of space. This should succeed and cache both blocks. val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) @@ -1110,7 +1147,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result2.size > 0) assert(result1.data.isLeft) // unroll did not drop this block to disk assert(result2.data.isLeft) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Re-put these two blocks so block manager knows about them too. Otherwise, block manager // would not know how to drop them from memory later. 
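// Editor's aside (illustrative, not part of this patch): the per-task unroll accounting exercised
// in the reserve/release hunk above is cumulative and capped. Reservations add to a running total
// until the store cannot grant more (an oversized request is simply not granted), releases
// subtract from it, and a no-argument release drops everything:
//
//   memoryStore.reserveUnrollMemoryForThisTask(100)      // running total: 100
//   memoryStore.reserveUnrollMemoryForThisTask(500)      // running total: 600
//   memoryStore.reserveUnrollMemoryForThisTask(1000000)  // not granted; total stays at 600
//   memoryStore.releaseUnrollMemoryForThisTask(100)      // running total: 500
//   memoryStore.releaseUnrollMemoryForThisTask()         // release all; total back to 0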
@@ -1126,7 +1163,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b1")) assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") store.putIterator("b3", smallIterator, memOnly) @@ -1138,7 +1175,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } /** @@ -1153,7 +1190,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) store.putIterator("b1", smallIterator, memAndDisk) store.putIterator("b2", smallIterator, memAndDisk) @@ -1170,7 +1207,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b3")) memoryStore.remove("b3") store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk // directly in addition to kicking out b2 in the process. 
Memory store should contain only @@ -1186,7 +1223,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(diskStore.contains("b2")) assert(!diskStore.contains("b3")) assert(diskStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } test("multiple unrolls by the same thread") { @@ -1195,32 +1232,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // All unroll memory used is released because unrollSafely returned an array memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll memory is not released because unrollSafely returned an iterator // that still depends on the underlying vector used in the process memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB3 > 0) // The unroll memory owned by this thread builds on top of its value after the previous unrolls memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) // ... but only to a certain extent (until we run out of free space to grant new unroll memory) memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala new file mode 100644 index 000000000000..d7ffde1e7864 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler._ + +class BlockStatusListenerSuite extends SparkFunSuite { + + test("basic functions") { + val blockManagerId = BlockManagerId("0", "localhost", 10000) + val listener = new BlockStatusListener() + + // Add a block manager and a new block status + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId, 0)) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + // The new block status should be added to the listener + val expectedBlock = BlockUIData( + StreamBlockId(0, 100), + "localhost:10000", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + val expectedExecutorStreamBlockStatus = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus) + + // Add the second block manager + val blockManagerId2 = BlockManagerId("1", "localhost", 10001) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId2, 0)) + // Add a new replication of the same block id from the second manager + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + val expectedBlock2 = BlockUIData( + StreamBlockId(0, 100), + "localhost:10001", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + // Each block manager should contain one block + val expectedExecutorStreamBlockStatus2 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq(expectedBlock2)) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus2) + + // Remove a replication of the same block + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.NONE, // StorageLevel.NONE means removing it + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 0))) + // Only the first block manager contains a block + val expectedExecutorStreamBlockStatus3 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus3) + + // Remove the second block manager at first but add a new block status + // from this removed block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId2)) + 
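// Editor's aside (not part of this patch): blockManagerId2 has just been removed, so the listener
// no longer tracks that executor; the update posted immediately below should therefore be dropped
// rather than re-registering the block, which is exactly what the following assertion checks.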
listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + // The second block manager is removed so we should not see the new block + val expectedExecutorStreamBlockStatus4 = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus4) + + // Remove the last block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId)) + // No block manager now so we should drop all block managers + assert(listener.allExecutorStreamBlockStatus.isEmpty) + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala similarity index 98% rename from core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala rename to core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 7bdea724fea5..66af6e1a7974 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { +class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { var tempDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2a7fe67ad858..828153bdbfc4 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,23 +17,26 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.Semaphore -import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.serializer.TestSerializer +import org.apache.spark.shuffle.FetchFailedException -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { + +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults.
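// Editor's aside (illustrative, not part of this patch): the PrivateMethodTester mixin added above
// is what later lets a test reach the private `delegate` member of the wrapping stream. The
// general ScalaTest pattern, shown here with a hypothetical class, is:
//
//   class Wrapped(in: InputStream) {
//     private def delegate: InputStream = in
//   }
//   val delegateAccess = PrivateMethod[InputStream]('delegate)
//   val inner: InputStream = new Wrapped(mock(classOf[InputStream])) invokePrivate delegateAccess()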
@@ -57,7 +60,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer } - private val conf = new SparkConf + // Create a mock managed buffer for testing + def createMockManagedBuffer(): ManagedBuffer = { + val mockManagedBuffer = mock(classOf[ManagedBuffer]) + when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + mockManagedBuffer + } test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) @@ -66,9 +74,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } @@ -76,9 +84,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) val transfer = createMockTransfer(remoteBlocks) @@ -88,11 +95,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null), + TaskContext.empty(), transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // 3 local blocks fetched in initialization @@ -100,15 +106,22 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, subIterator) = iterator.next() - assert(subIterator.isSuccess, - s"iterator should have 5 elements defined but actually has $i elements") + val (blockId, inputStream) = iterator.next() - // Make sure we release the buffer once the iterator is exhausted. + // Make sure we release buffers when a wrapped input stream is closed. 
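// Editor's aside (illustrative, not part of this patch): the behaviour verified below, where the
// wrapped stream releases its ManagedBuffer on the first close() and treats later close() calls
// as no-ops, comes down to a guarded-close wrapper along these lines:
//
//   class ReleasingStream(in: InputStream, buf: ManagedBuffer) extends InputStream {
//     private[this] var closed = false
//     override def read(): Int = in.read()
//     override def close(): Unit = if (!closed) {
//       closed = true
//       in.close()
//       buf.release()
//     }
//   }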
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream + val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() - subIterator.get.foreach(_ => Unit) // exhaust the iterator + val delegateAccess = PrivateMethod[InputStream]('delegate) + + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + wrappedInputStream.close() + verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } // 3 local blocks, and 2 remote blocks @@ -125,10 +138,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) @@ -153,21 +165,20 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) - // Exhaust the first block, and then it should be released. 
- iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() + iterator.next()._2.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2.get + val subIter = iterator.next()._2 // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -216,21 +227,21 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // Continue only after the mock calls onBlockFetchFailure sem.acquire() - // The first block should be defined, and the last two are not defined (due to failure) - assert(iterator.next()._2.isSuccess) - assert(iterator.next()._2.isFailure) - assert(iterator.next()._2.isFailure) + // The first block should be returned without an exception, and the last two should throw + // FetchFailedExceptions (due to failure) + iterator.next() + intercept[FetchFailedException] { iterator.next() } + intercept[FetchFailedException] { iterator.next() } } } diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala new file mode 100644 index 000000000000..cc76c141c53c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.Node + +import org.apache.spark.SparkFunSuite + +class PagedDataSourceSuite extends SparkFunSuite { + + test("basic") { + val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource1.pageData(1) === PageData(3, (1 to 2))) + + val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource2.pageData(2) === PageData(3, (3 to 4))) + + val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource3.pageData(3) === PageData(3, Seq(5))) + + val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e1 = intercept[IndexOutOfBoundsException] { + dataSource4.pageData(4) + } + assert(e1.getMessage === "Page 4 is out of range. 
Please select a page number between 1 and 3.") + + val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e2 = intercept[IndexOutOfBoundsException] { + dataSource5.pageData(0) + } + assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.") + + } +} + +class PagedTableSuite extends SparkFunSuite { + test("pageNavigation") { + // Create a fake PagedTable to test pageNavigation + val pagedTable = new PagedTable[Int] { + override def tableId: String = "" + + override def tableCssClass: String = "" + + override def dataSource: PagedDataSource[Int] = null + + override def pageLink(page: Int): String = page.toString + + override def headers: Seq[Node] = Nil + + override def row(t: Int): Seq[Node] = Nil + + override def goButtonJavascriptFunction: (String, String) = ("", "") + } + + assert(pagedTable.pageNavigation(1, 10, 1) === Nil) + assert( + (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">")) + assert( + (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2")) + + assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) === + (1 to 10).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>")) + + assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString)) + assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">")) + + assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">")) + } +} + +private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int) + extends PagedDataSource[T](pageSize) { + + override protected def dataSize: Int = seq.size + + override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to) +} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala new file mode 100644 index 000000000000..86699e7f5695 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener + +class StagePageSuite extends SparkFunSuite with LocalSparkContext { + + test("peak execution memory only displayed if unsafe is enabled") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf(false).set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + val targetString = "peak execution memory" + assert(html.contains(targetString)) + // Disable unsafe and make sure it's not there + val conf2 = new SparkConf(false).set(unsafeConf, "false") + val html2 = renderStagePage(conf2).toString().toLowerCase + assert(!html2.contains(targetString)) + // Avoid setting anything; it should be displayed by default + val conf3 = new SparkConf(false) + val html3 = renderStagePage(conf3).toString().toLowerCase + assert(html3.contains(targetString)) + } + + test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf(false).set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + // verify min/25/50/75/max show task value not cumulative values + assert(html.contains("10.0 b" * 5)) + } + + /** + * Render a stage page started with the given conf and return the HTML. + * This also runs a dummy stage to populate the page with useful content. + */ + private def renderStagePage(conf: SparkConf): Seq[Node] = { + val jobListener = new JobProgressListener(conf) + val graphListener = new RDDOperationGraphListener(conf) + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.progressListener).thenReturn(jobListener) + when(tab.operationGraphListener).thenReturn(graphListener) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness + (1 to 2).foreach { + taskId => + val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) + val peakExecutionMemory = 10 + taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, + Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + } + jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 3aa672f8b713..22e30ecaf053 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala 
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} import javax.servlet.http.{HttpServletResponse, HttpServletRequest} -import scala.collection.JavaConversions._ +import scala.io.Source import scala.xml.Node import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler @@ -340,15 +340,15 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B // The completed jobs table should have two rows. The first row will be the most recent job: val firstRow = find(cssSelector("tbody tr")).get.underlying val firstRowColumns = firstRow.findElements(By.tagName("td")) - firstRowColumns(0).getText should be ("1") - firstRowColumns(4).getText should be ("1/1 (2 skipped)") - firstRowColumns(5).getText should be ("8/8 (16 skipped)") + firstRowColumns.get(0).getText should be ("1") + firstRowColumns.get(4).getText should be ("1/1 (2 skipped)") + firstRowColumns.get(5).getText should be ("8/8 (16 skipped)") // The second row is the first run of the job, where nothing was skipped: val secondRow = findAll(cssSelector("tbody tr")).toSeq(1).underlying val secondRowColumns = secondRow.findElements(By.tagName("td")) - secondRowColumns(0).getText should be ("0") - secondRowColumns(4).getText should be ("3/3") - secondRowColumns(5).getText should be ("24/24") + secondRowColumns.get(0).getText should be ("0") + secondRowColumns.get(4).getText should be ("3/3") + secondRowColumns.get(5).getText should be ("24/24") } } } @@ -501,8 +501,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expJobInfo(idx)._1) description should include (expJobInfo(idx)._2) @@ -546,8 +546,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expStageInfo(idx)._1) description should include (expStageInfo(idx)._2) @@ -603,6 +603,44 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } + test("job stages should have expected dotfile under DAG visualization") { + withSpark(newSparkContext()) { sc => + // Create a multi-stage job + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + rdd.count() + + val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString + assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + + "label=&quot;Stage 0&quot;;\n subgraph ")) + assert(stage0.contains("{\n label=&quot;parallelize&quot;;\n " + + "0 [label=&quot;ParallelCollectionRDD [0]&quot;];\n }")) + assert(stage0.contains("{\n label=&quot;map&quot;;\n " + + "1 [label=&quot;MapPartitionsRDD [1]&quot;];\n }")) + assert(stage0.contains("{\n label=&quot;groupBy&quot;;\n " + + "2 [label=&quot;MapPartitionsRDD [2]&quot;];\n }")) + + val stage1 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString + assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + + "label=&quot;Stage 1&quot;;\n subgraph ")) + assert(stage1.contains("{\n label=&quot;groupBy&quot;;\n " + + "3
[label="ShuffledRDD [3]"];\n }")) + assert(stage1.contains("{\n label="map";\n " + + "4 [label="MapPartitionsRDD [4]"];\n }")) + assert(stage1.contains("{\n label="groupBy";\n " + + "5 [label="MapPartitionsRDD [5]"];\n }")) + + val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString + assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + + "label="Stage 2";\n subgraph ")) + assert(stage2.contains("{\n label="groupBy";\n " + + "6 [label="ShuffledRDD [6]"];\n }")) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 56f7b9cf1f35..b140387d309f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), + ExceptionFailure("Exception", "description", null, null, None, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala new file mode 100644 index 000000000000..3dab15a9d469 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.storage + +import scala.xml.Utility + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage._ + +class StoragePageSuite extends SparkFunSuite { + + val storageTab = mock(classOf[StorageTab]) + when(storageTab.basePath).thenReturn("http://localhost:4040") + val storagePage = new StoragePage(storageTab) + + test("rddTable") { + val rdd1 = new RDDInfo(1, + "rdd1", + 10, + StorageLevel.MEMORY_ONLY, + Seq.empty) + rdd1.memSize = 100 + rdd1.numCachedPartitions = 10 + + val rdd2 = new RDDInfo(2, + "rdd2", + 10, + StorageLevel.DISK_ONLY, + Seq.empty) + rdd2.diskSize = 200 + rdd2.numCachedPartitions = 5 + + val rdd3 = new RDDInfo(3, + "rdd3", + 10, + StorageLevel.MEMORY_AND_DISK_SER, + Seq.empty) + rdd3.memSize = 400 + rdd3.diskSize = 500 + rdd3.numCachedPartitions = 10 + + val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + + val headers = Seq( + "RDD Name", + "Storage Level", + "Cached Partitions", + "Fraction Cached", + "Size in Memory", + "Size in ExternalBlockStore", + "Size on Disk") + assert((xmlNodes \\ "th").map(_.text) === headers) + + assert((xmlNodes \\ "tr").size === 3) + assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B", "0.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=1")) + + assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "0.0 B", "200.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=2")) + + assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "0.0 B", + "500.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=3")) + } + + test("empty rddTable") { + assert(storagePage.rddTable(Seq.empty).isEmpty) + } + + test("streamBlockStorageLevelDescriptionAndSize") { + val memoryBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) + + val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory Serialized", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) + + val diskBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) + + val externalBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 100) + assert(("External", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(externalBlock)) + } + + test("receiverBlockTables") { + val blocksForExecutor0 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10000", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + 
externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(1, 1), + "localhost:10000", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + ) + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) + + val blocksForExecutor1 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10001", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(2, 2), + "localhost:10001", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 200), + BlockUIData(StreamBlockId(1, 1), + "localhost:10001", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + ) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) + val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) + + val executorTable = (xmlNodes \\ "table")(0) + val executorHeaders = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + assert((executorTable \\ "th").map(_.text) === executorHeaders) + + assert((executorTable \\ "tr").size === 2) + assert(((executorTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("0", "localhost:10000", "100.0 B", "0.0 B", "100.0 B", "2")) + assert(((executorTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("1", "localhost:10001", "200.0 B", "200.0 B", "0.0 B", "3")) + + val blockTable = (xmlNodes \\ "table")(1) + val blockHeaders = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + assert((blockTable \\ "th").map(_.text) === blockHeaders) + + assert((blockTable \\ "tr").size === 5) + assert(((blockTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("input-0-0", "2", "localhost:10000", "Memory", "100.0 B")) + // Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(0) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(0) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory", "100.0 B")) + + assert(((blockTable \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("input-1-1", "2", "localhost:10000", "Disk", "100.0 B")) + // Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(2) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(2) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(3) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory Serialized", "100.0 B")) + + assert(((blockTable \\ "tr")(4) \\ "td").map(_.text.trim) === + Seq("input-2-2", "1", "localhost:10001", "External", "200.0 B")) + // Check "rowspan=1" for the first 2 columns + assert(((blockTable \\ "tr")(4) \\ "td")(0).attribute("rowspan").map(_.text) === Some("1")) + assert(((blockTable \\ "tr")(4) \\ "td")(1).attribute("rowspan").map(_.text) === Some("1")) + } + + test("empty receiverBlockTables") { + assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) + + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6c40685484ed..61601016e005 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import java.util.concurrent.TimeoutException import akka.actor.ActorNotFound @@ -24,7 +26,7 @@ import akka.actor.ActorNotFound import org.apache.spark._ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} import org.apache.spark.SSLSampleConfigs._ @@ -107,8 +109,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -153,8 +156,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security on and passwords match - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -232,8 +236,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -278,8 +282,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 1053c6caf771..480722a5ac18 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -375,6 +375,7 @@ class TestCreateNullValue { // parameters of the closure constructor. This allows us to test whether // null values are created correctly for each type. 
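// Editor's aside (illustrative, not part of this patch): the scalastyle:off/on markers added in
// the hunk just below fence off intentional console output so that only this region is exempt
// from the println style check; the usual shape is:
//
//   // scalastyle:off println
//   println("deliberate console output")
//   // scalastyle:on println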
val nestedClosure = () => { + // scalastyle:off println if (s.toString == "123") { // Don't really output them to avoid noisy println(bo) println(c) @@ -389,6 +390,7 @@ class TestCreateNullValue { val closure = () => { println(getX) } + // scalastyle:on println ClosureCleaner.clean(closure) } nestedClosure() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 3147c937769d..a829b099025e 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -120,8 +120,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // Accessors for private methods private val _isClosure = PrivateMethod[Boolean]('isClosure) private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses) - private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses) - private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects) + private val _getOuterClassesAndObjects = + PrivateMethod[(List[Class[_]], List[AnyRef])]('getOuterClassesAndObjects) private def isClosure(obj: AnyRef): Boolean = { ClosureCleaner invokePrivate _isClosure(obj) @@ -131,12 +131,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri ClosureCleaner invokePrivate _getInnerClosureClasses(closure) } - private def getOuterClasses(closure: AnyRef): List[Class[_]] = { - ClosureCleaner invokePrivate _getOuterClasses(closure) - } - - private def getOuterObjects(closure: AnyRef): List[AnyRef] = { - ClosureCleaner invokePrivate _getOuterObjects(closure) + private def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = { + ClosureCleaner invokePrivate _getOuterClassesAndObjects(closure) } test("get inner closure classes") { @@ -171,14 +167,11 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => localValue val closure3 = () => someSerializableValue val closure4 = () => someSerializableMethod() - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) - val outerObjects4 = getOuterObjects(closure4) + + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) + val (outerClasses4, outerObjects4) = getOuterClassesAndObjects(closure4) // The classes and objects should have the same size assert(outerClasses1.size === outerObjects1.size) @@ -211,10 +204,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val x = 1 val closure1 = () => 1 val closure2 = () => x - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) // These 
inner closures only reference local variables, and so do not have $outer pointers @@ -227,12 +218,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => y val closure3 = () => localValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) assert(outerClasses3.size === outerObjects3.size) @@ -265,9 +253,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => localValue val closure3 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false) @@ -307,10 +295,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => a val closure3 = () => localValue val closure4 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) + val (outerClasses4, _) = getOuterClassesAndObjects(closure4) // First, find only fields accessed directly, not transitively, by these closures val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index e0ef9c70a5fc..143c1b901df1 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -83,6 +83,9 @@ class JsonProtocolSuite extends SparkFunSuite { val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") + val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq( + (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, + hasHadoopInput = true, hasOutput = true)))) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -102,6 +105,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(applicationEnd, applicationEndJsonString) testEvent(executorAdded, executorAddedJsonString) testEvent(executorRemoved, 
executorRemovedJsonString) + testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) } test("Dependent Classes") { @@ -147,7 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) - testTaskEndReason(ExecutorLostFailure("100")) + testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) // BlockId @@ -159,7 +163,8 @@ class JsonProtocolSuite extends SparkFunSuite { } test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, + None, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) @@ -290,10 +295,10 @@ class JsonProtocolSuite extends SparkFunSuite { test("ExecutorLostFailure backward compatibility") { // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. - val executorLostFailure = ExecutorLostFailure("100") + val executorLostFailure = ExecutorLostFailure("100", true) val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) .removeField({ _._1 == "Executor ID" }) - val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown", true) assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } @@ -440,10 +445,20 @@ class JsonProtocolSuite extends SparkFunSuite { case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) assertEquals(e1.executorInfo, e2.executorInfo) case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) + case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => + assert(e1.execId === e2.execId) + assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => { + val (taskId1, stageId1, stageAttemptId1, metrics1) = a + val (taskId2, stageId2, stageAttemptId2, metrics2) = b + assert(taskId1 === taskId2) + assert(stageId1 === stageId2) + assert(stageAttemptId1 === stageAttemptId2) + assertEquals(metrics1, metrics2) + }) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -484,7 +499,7 @@ class JsonProtocolSuite extends SparkFunSuite { private def assertEquals(info1: TaskInfo, info2: TaskInfo) { assert(info1.taskId === info2.taskId) assert(info1.index === info2.index) - assert(info1.attempt === info2.attempt) + assert(info1.attemptNumber === info2.attemptNumber) assert(info1.launchTime === info2.launchTime) assert(info1.executorId === info2.executorId) assert(info1.host === info2.host) @@ -562,8 +577,10 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => - case (ExecutorLostFailure(execId1), ExecutorLostFailure(execId2)) => + case (ExecutorLostFailure(execId1, isNormalExit1), + ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === 
execId2) + assert(isNormalExit1 === isNormalExit2) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } @@ -1598,4 +1615,55 @@ class JsonProtocolSuite extends SparkFunSuite { | "Removed Reason": "test reason" |} """ + + private val executorMetricsUpdateJsonString = + s""" + |{ + | "Event": "SparkListenerExecutorMetricsUpdate", + | "Executor ID": "exec3", + | "Metrics Updated": [ + | { + | "Task ID": 1, + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Input Metrics": { + | "Data Read Method": "Hadoop", + | "Bytes Read": 2100, + | "Records Read": 21 + | }, + | "Output Metrics": { + | "Data Write Method": "Hadoop", + | "Bytes Written": 1200, + | "Records Written": 12 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use ExternalBlockStore": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "ExternalBlockStore Size": 0, + | "Disk Size": 0 + | } + | } + | ] + | } + | }] + |} + """.stripMargin } diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 42125547436c..d3d464e84ffd 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -84,7 +84,9 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { try { sc.makeRDD(1 to 5, 2).mapPartitions { x => val loader = Thread.currentThread().getContextClassLoader + // scalastyle:off classforname Class.forName(className, true, loader).newInstance() + // scalastyle:on classforname Seq().iterator }.count() } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a61ea3918f46..1fb81ad565b4 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols @@ -486,11 +487,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // Test for using the util function to change our log levels. test("log4j log level change") { - Utils.setLogLevel(org.apache.log4j.Level.ALL) - assert(log.isInfoEnabled()) - Utils.setLogLevel(org.apache.log4j.Level.ERROR) - assert(!log.isInfoEnabled()) - assert(log.isErrorEnabled()) + val current = org.apache.log4j.Logger.getRootLogger().getLevel() + try { + Utils.setLogLevel(org.apache.log4j.Level.ALL) + assert(log.isInfoEnabled()) + Utils.setLogLevel(org.apache.log4j.Level.ERROR) + assert(!log.isInfoEnabled()) + assert(log.isErrorEnabled()) + } finally { + // Best effort at undoing changes this test made. 
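
The UtilsSuite change above wraps the log-level assertions in try/finally so the root logger is put back even when an assertion fails. A small sketch of the same save-and-restore idea as a reusable helper, assuming the log4j 1.x API used elsewhere in this patch; `withRootLogLevel` is an illustrative name, not a Spark utility:

```scala
import org.apache.log4j.{Level, Logger}

object LogLevelSketch {
  // Run `body` with the root logger temporarily set to `level`, restoring the
  // previous level afterwards even if `body` throws, just as the test above
  // does inline with try/finally.
  def withRootLogLevel[T](level: Level)(body: => T): T = {
    val root = Logger.getRootLogger
    val previous = root.getLevel
    root.setLevel(level)
    try {
      body
    } finally {
      root.setLevel(previous)
    }
  }
}

// Example use:
//   withRootLogLevel(Level.ERROR) { assert(!Logger.getRootLogger.isInfoEnabled) }
```
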
+ Utils.setLogLevel(current) + } } test("deleteRecursively") { @@ -673,4 +680,58 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, parentDir)) assert(!Utils.isInDirectory(nullFile, childFile3)) } + + test("circular buffer") { + val buffer = new CircularBuffer(25) + val stream = new java.io.PrintStream(buffer, true, "UTF-8") + + // scalastyle:off println + stream.println("test circular test circular test circular test circular test circular") + // scalastyle:on println + assert(buffer.toString === "t circular test circular\n") + } + + test("nanSafeCompareDoubles") { + def shouldMatchDefaultOrder(a: Double, b: Double): Unit = { + assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b)) + assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a)) + } + shouldMatchDefaultOrder(0d, 0d) + shouldMatchDefaultOrder(0d, 1d) + shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1) + assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1) + } + + test("nanSafeCompareFloats") { + def shouldMatchDefaultOrder(a: Float, b: Float): Unit = { + assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b)) + assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a)) + } + shouldMatchDefaultOrder(0f, 0f) + shouldMatchDefaultOrder(1f, 1f) + shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1) + assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1) + } + + test("isDynamicAllocationEnabled") { + val conf = new SparkConf() + assert(Utils.isDynamicAllocationEnabled(conf) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.dynamicAllocation.enabled", "false")) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.dynamicAllocation.enabled", "true")) === true) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "1")) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "0")) === true) + } + } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 79eba61a8725..12e9bafcc92c 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -244,7 +244,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private def testSimpleSpilling(codec: Option[String] = None): Unit = { val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 
times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -292,7 +292,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[String] val collisionPairs = Seq( @@ -341,7 +341,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes @@ -366,7 +366,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] (1 to 100000).foreach { i => map.insert(i, i) } @@ -383,7 +383,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] map.insertAll((1 to 100000).iterator.map(i => (i, i))) @@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + test("external aggregation updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false) + .set("spark.shuffle.memoryFraction", "0.001") + .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter + sc = new SparkContext("local", "test", conf) + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { + sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { + sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 9cefa612f549..bdb0f4d507a7 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -176,7 +176,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def testSpillingInLocalCluster(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", 
"test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -254,7 +254,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // reduceByKey - should spill ~4 times per executor val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -554,7 +554,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -611,7 +611,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) @@ -634,7 +634,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i @@ -658,7 +658,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -692,10 +692,10 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { sortWithoutBreakingSortingContracts(createSparkConf(true, false)) } - def sortWithoutBreakingSortingContracts(conf: SparkConf) { + private def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // Using wrongOrdering to show integer overflow introduced exception. 
val rand = new Random(100L) @@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } sorter2.stop() - } + } + + test("sorting updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false, kryo = false) + .set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") { + sc.parallelize(1 to 1000, 2).repartition(100).count() + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index 6d2459d48d32..3b67f6206495 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -17,15 +17,20 @@ package org.apache.spark.util.collection -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ +import org.mockito.Mockito.RETURNS_SMART_NULLS +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.Matchers._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.{FileSegment, BlockObjectWriter} +import org.apache.spark.storage.DiskBlockObjectWriter class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { @@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { val struct = SomeStruct("something", 5) buffer.insert(4, 10, struct) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) stream.readObject[AnyRef]() should be (10) stream.readObject[AnyRef]() should be (struct) } @@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { buffer.insert(5, 3, struct3) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) @@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) val iter = stream.asIterator iter.next() should be (2) iter.next() should be (struct2) @@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { iter.next() should be (struct1) assert(!iter.hasNext) } -} - -case class SomeStruct(val str: String, val num: Int) - -class SimpleBlockObjectWriter extends BlockObjectWriter(null) { - val baos = new ByteArrayOutputStream() - override def write(bytes: 
Array[Byte], offs: Int, len: Int): Unit = { - baos.write(bytes, offs, len) + def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { + val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) + val baos = new ByteArrayOutputStream() + when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + val args = invocationOnMock.getArguments + val bytes = args(0).asInstanceOf[Array[Byte]] + val offset = args(1).asInstanceOf[Int] + val length = args(2).asInstanceOf[Int] + baos.write(bytes, offset, length) + } + }) + (writer, baos) } - - def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray) - - override def open(): BlockObjectWriter = this - override def close(): Unit = { } - override def isOpen: Boolean = true - override def commitAndClose(): Unit = { } - override def revertPartialWritesAndClose(): Unit = { } - override def fileSegment(): FileSegment = null - override def write(key: Any, value: Any): Unit = { } - override def recordWritten(): Unit = { } - override def write(b: Int): Unit = { } } + +case class SomeStruct(str: String, num: Int) diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 5a5919fca246..4f382414a8dd 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -103,7 +103,9 @@ private object SizeTrackerSuite { */ def main(args: Array[String]): Unit = { if (args.size < 1) { + // scalastyle:off println println("Usage: SizeTrackerSuite [num elements]") + // scalastyle:on println System.exit(1) } val numElements = args(0).toInt @@ -180,11 +182,13 @@ private object SizeTrackerSuite { baseTimes: Seq[Long], sampledTimes: Seq[Long], unsampledTimes: Seq[Long]): Unit = { + // scalastyle:off println println(s"Average times for $testName (ms):") println(" Base - " + averageTime(baseTimes)) println(" SizeTracker (sampled) - " + averageTime(sampledTimes)) println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes)) println() + // scalastyle:on println } def time(f: => Unit): Long = { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index b2f5d9009ee5..fefa5165db19 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends SparkFunSuite { +class SorterSuite extends SparkFunSuite with Logging { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) @@ -74,7 +74,7 @@ class SorterSuite extends SparkFunSuite { /** Runs an experiment several times. 
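
The PartitionedSerializedPairBufferSuite change above replaces the hand-written SimpleBlockObjectWriter with a Mockito mock whose `write(bytes, offset, length)` stub copies the bytes into a ByteArrayOutputStream, so the test can deserialize whatever was "written". A stand-alone sketch of that stubbing pattern against a made-up `ByteSink` trait, using the Mockito 1.x `Matchers.any` API as the diff does:

```scala
import java.io.ByteArrayOutputStream

import org.mockito.Matchers.any
import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

// Hypothetical stand-in for DiskBlockObjectWriter's byte-oriented write method.
trait ByteSink {
  def write(bytes: Array[Byte], offset: Int, length: Int): Unit
}

object MockWriterSketch {
  // Returns a mock sink plus the buffer that captures everything "written" to it.
  def createCapturingSink(): (ByteSink, ByteArrayOutputStream) = {
    val sink = mock(classOf[ByteSink], RETURNS_SMART_NULLS)
    val baos = new ByteArrayOutputStream()
    when(sink.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
      override def answer(invocation: InvocationOnMock): Unit = {
        val args = invocation.getArguments
        baos.write(
          args(0).asInstanceOf[Array[Byte]],
          args(1).asInstanceOf[Int],
          args(2).asInstanceOf[Int])
      }
    })
    (sink, baos)
  }
}
```
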
*/ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { if (skip) { - println(s"Skipped experiment $name.") + logInfo(s"Skipped experiment $name.") return } @@ -86,11 +86,11 @@ class SorterSuite extends SparkFunSuite { while (i < 10) { val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) next10 += time - println(s"$name: Took $time ms") + logInfo(s"$name: Took $time ms") i += 1 } - println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + logInfo(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") } /** diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala new file mode 100644 index 000000000000..0326ed70b5ed --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort + +import com.google.common.primitives.UnsignedBytes +import org.scalatest.prop.PropertyChecks +import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String + +class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { + + test("String prefix comparator") { + + def testPrefixComparison(s1: String, s2: String): Unit = { + val utf8string1 = UTF8String.fromString(s1) + val utf8string2 = UTF8String.fromString(s2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) + val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + + val cmp = UnsignedBytes.lexicographicalComparator().compare( + utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) + + assert( + (prefixComparisonResult == 0 && cmp == 0) || + (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || + (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) + } + + // scalastyle:off + val regressionTests = Table( + ("s1", "s2"), + ("abc", "世界"), + ("你好", "世界"), + ("你好123", "你好122") + ) + // scalastyle:on + + forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + } + + test("Binary prefix comparator") { + + def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compare(y(i)) + if (res != 0) return res + } + x.length - y.length + } + + def testPrefixComparison(x: Array[Byte], y: Array[Byte]): Unit = { + val s1Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(x) + val s2Prefix = PrefixComparators.BinaryPrefixComparator.computePrefix(y) + val prefixComparisonResult = + PrefixComparators.BINARY.compare(s1Prefix, s2Prefix) + assert( + (prefixComparisonResult == 0) || + (prefixComparisonResult < 0 && compareBinary(x, y) < 0) || + (prefixComparisonResult > 0 && compareBinary(x, y) > 0)) + } + + // scalastyle:off + val regressionTests = Table( + ("s1", "s2"), + ("abc", "世界"), + ("你好", "世界"), + ("你好123", "你好122") + ) + // scalastyle:on + + forAll (regressionTests) { (s1: String, s2: String) => + testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) + } + forAll { (s1: String, s2: String) => + testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) + } + } + + test("double prefix comparator handles NaNs properly") { + val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) + val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) + } + +} diff --git a/data/mllib/pic_data.txt b/data/mllib/pic_data.txt new file mode 100644 index 000000000000..fcfef8cd1913 --- /dev/null +++ b/data/mllib/pic_data.txt @@ -0,0 +1,19 @@ +0 1 1.0 +0 2 1.0 +0 3 1.0 +1 2 1.0 +1 3 1.0 +2 3 1.0 +3 4 0.1 +4 5 1.0 +4 15 1.0 +5 6 1.0 +6 7 1.0 +7 8 1.0 +8 9 1.0 +9 10 1.0 +10 11 1.0 +11 12 1.0 +12 13 1.0 +13 14 1.0 +14 15 1.0 diff --git a/data/mllib/sample_naive_bayes_data.txt 
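
The new PrefixComparatorsSuite above property-checks one invariant: comparing 8-byte prefixes must never contradict a full comparison of the underlying strings or byte arrays, with ties allowed and resolved by the full comparator. A hedged illustration of that idea (not Spark's actual PrefixComparators code): pack the first eight bytes into a big-endian Long and compare it as unsigned.

```scala
object BinaryPrefixSketch {
  // Pack up to the first eight bytes into a big-endian Long, zero-padded on the
  // right. Treated as unsigned, prefix order can only agree with, or tie with,
  // the full lexicographic order over unsigned bytes.
  def computePrefix(bytes: Array[Byte]): Long = {
    var prefix = 0L
    var i = 0
    while (i < 8) {
      val b = if (i < bytes.length) bytes(i) & 0xffL else 0L
      prefix = (prefix << 8) | b
      i += 1
    }
    prefix
  }

  // Unsigned comparison of two packed prefixes. A result of 0 only means
  // "undecided": the caller must fall back to comparing the full keys.
  def comparePrefixes(a: Long, b: Long): Int = {
    // flipping the sign bit makes signed comparison behave like unsigned comparison
    java.lang.Long.compare(a ^ Long.MinValue, b ^ Long.MinValue)
  }
}
```

The suite's last test additionally pins down that different NaN bit patterns canonicalize to a single double prefix; this sketch does not model that case.
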
b/data/mllib/sample_naive_bayes_data.txt index 981da382d6ac..bd22bea3a59d 100644 --- a/data/mllib/sample_naive_bayes_data.txt +++ b/data/mllib/sample_naive_bayes_data.txt @@ -1,6 +1,12 @@ 0,1 0 0 0,2 0 0 +0,3 0 0 +0,4 0 0 1,0 1 0 1,0 2 0 +1,0 3 0 +1,0 4 0 2,0 0 1 2,0 0 2 +2,0 0 3 +2,0 0 4 \ No newline at end of file diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index fc03fec9866a..61d91c70e970 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -59,3 +60,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala index 0be8e64fbfab..9f7ae75d0b47 100644 --- a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -37,3 +38,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala index 24c7f8d66729..2f0b6ef9a567 100644 --- a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala +++ b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import org.apache.spark.{SparkContext, SparkConf} @@ -51,3 +52,4 @@ object GraphXApp { println("Test succeeded") } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index 5111bc0adb77..4a980ec071ae 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -55,3 +56,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala index 9f8506650147..adc25b57d6aa 100644 --- a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -31,3 +32,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index cc86ef45858c..69c1154dc095 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -57,3 +58,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala index 58a662bd9b2e..d6a074687f4a 100644 --- a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala +++ b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -61,3 +62,4 @@ object SparkStreamingExample { ssc.stop() } } +// scalastyle:on println diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh new file mode 100755 index 000000000000..d7975dfb6475 --- /dev/null +++ b/dev/change-scala-version.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +VALID_VERSIONS=( 2.10 2.11 ) + +usage() { + echo "Usage: $(basename $0) [-h|--help] +where : + -h| --help Display this help text + valid version values : ${VALID_VERSIONS[*]} +" 1>&2 + exit 1 +} + +if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then + usage +fi + +TO_VERSION=$1 + +check_scala_version() { + for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done + echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2 + exit 1 +} + +check_scala_version "$TO_VERSION" + +if [ $TO_VERSION = "2.11" ]; then + FROM_VERSION="2.10" +else + FROM_VERSION="2.11" +fi + +sed_i() { + sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2" +} + +export -f sed_i + +BASEDIR=$(dirname $0)/.. +find "$BASEDIR" -name 'pom.xml' -not -path '*target*' -print \ + -exec bash -c "sed_i 's/\(artifactId.*\)_'$FROM_VERSION'/\1_'$TO_VERSION'/g' {}" \; + +# Also update in parent POM +# Match any scala binary version to ensure idempotency +sed_i '1,/[0-9]*\.[0-9]*[0-9]*\.[0-9]*'$TO_VERSION' in parent POM -sed -i -e '0,/2.112.10 in parent POM -sed -i -e '0,/2.102.11${cur_ver}<\/version>$" - new="\1${rel_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - git commit -a -m "Preparing Spark release $GIT_TAG" - echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" - git tag $GIT_TAG - - old="^\( \{2,4\}\)${rel_ver}<\/version>$" - new="\1${next_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/$old/$new/" {} - find . 
-name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - git commit -a -m "Preparing development version $next_ver" - git push origin $GIT_TAG - git push origin HEAD:$GIT_BRANCH - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-publish ]]; then - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git - pushd spark - git checkout --force $GIT_TAG - - # Substitute in case published version is different than released - old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" - new="\1${PUBLISH_VERSION}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - # Using Nexus API documented here: - # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" - - rm -rf $SPARK_REPO - - build/mvn -DskipTests -Pyarn -Phive \ - -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-version-to-2.11.sh - - build/mvn -DskipTests -Pyarn -Phive \ - -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-version-to-2.10.sh - - pushd $SPARK_REPO - - # Remove any extra files generated during install - find . -type f |grep -v \.jar |grep -v \.pom | xargs rm - - echo "Creating hash and signature files" - for file in $(find . -type f) - do - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi - shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 - done - - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done - - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" - - popd - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-package ]]; then - # Source and binary tarballs - echo "Packaging release tarballs" - git clone https://git-wip-us.apache.org/repos/asf/spark.git - cd spark - git checkout --force $GIT_TAG - release_hash=`git rev-parse HEAD` - - rm .gitignore - rm -rf .git - cd .. 
- - cp -r spark spark-$RELEASE_VERSION - tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha - rm -rf spark-$RELEASE_VERSION - - # Updated for each binary build - make_binary_release() { - NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-version-to-2.11.sh - fi - - export ZINC_PORT=$ZINC_PORT - echo "Creating distribution: $NAME ($FLAGS)" - ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ - ../binary-release-$NAME.log - cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha - } - - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & - wait - rm -rf spark-$RELEASE_VERSION-bin-*/ - - # Copy data - echo "Copying release tarballs" - rc_folder=spark-$RELEASE_VERSION-$RC_NAME - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder - scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - - # Docs - cd spark - sbt/sbt clean - cd docs - # Compile docs with Java 7 to use nicer format - JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build - echo "Copying release documentation" - rc_docs_folder=${rc_folder}-docs - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder - rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - - echo "Release $RELEASE_VERSION completed:" - echo "Git tag:\t $GIT_TAG" - echo "Release commit:\t $release_hash" - echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" - echo "Doc location:\t 
http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" -fi diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py index 8aaa250bd7e2..db9c680a4bad 100755 --- a/dev/create-release/generate-contributors.py +++ b/dev/create-release/generate-contributors.py @@ -178,13 +178,16 @@ def populate(issue_type, components): author_info[author][issue_type].add(component) # Find issues and components associated with this commit for issue in issues: - jira_issue = jira_client.issue(issue) - jira_type = jira_issue.fields.issuetype.name - jira_type = translate_issue_type(jira_type, issue, warnings) - jira_components = [translate_component(c.name, _hash, warnings)\ - for c in jira_issue.fields.components] - all_components = set(jira_components + commit_components) - populate(jira_type, all_components) + try: + jira_issue = jira_client.issue(issue) + jira_type = jira_issue.fields.issuetype.name + jira_type = translate_issue_type(jira_type, issue, warnings) + jira_components = [translate_component(c.name, _hash, warnings)\ + for c in jira_issue.fields.components] + all_components = set(jira_components + commit_components) + populate(jira_type, all_components) + except Exception as e: + print "Unexpected error:", e # For docs without an associated JIRA, manually add it ourselves if is_docs(title) and not issues: populate("documentation", commit_components) @@ -223,7 +226,8 @@ def populate(issue_type, components): # E.g. andrewor14/SPARK-3425/SPARK-1157/SPARK-6672 if author in invalid_authors and invalid_authors[author]: author = author + "/" + "/".join(invalid_authors[author]) - line = " * %s -- %s" % (author, contribution) + #line = " * %s -- %s" % (author, contribution) + line = author contributors_file.write(line + "\n") contributors_file.close() print "Contributors list is successfully written to %s!" % contributors_file_name diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 5f2671a6e505..3563fe3cc3c0 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -129,3 +129,39 @@ yongtang - Yong Tang ypcat - Pei-Lun Lee zhichao-li - Zhichao Li zzcclp - Zhichao Zhang +979969786 - Yuming Wang +Rosstin - Rosstin Murphy +ameyc - Amey Chaugule +animeshbaranawal - Animesh Baranawal +cafreeman - Chris Freeman +lee19 - Lee +lockwobr - Brian Lockwood +navis - Navis Ryu +pparkkin - Paavo Parkkinen +HyukjinKwon - Hyukjin Kwon +JDrit - Joseph Batchik +JuhongPark - Juhong Park +KaiXinXiaoLei - KaiXinXIaoLei +NamelessAnalyst - NamelessAnalyst +alyaxey - Alex Slusarenko +baishuo - Shuo Bai +fe2s - Oleksiy Dyagilev +felixcheung - Felix Cheung +feynmanliang - Feynman Liang +josepablocam - Jose Cambronero +kai-zeng - Kai Zeng +mosessky - mosessky +msannell - Michael Sannella +nishkamravi2 - Nishkam Ravi +noel-smith - Noel Smith +petz2000 - Patrick Baier +qiansl127 - Shilei Qian +rahulpalamuttam - Rahul Palamuttam +rowan000 - Rowan Chattaway +sarutak - Kousuke Saruta +sethah - Seth Hendrickson +small-wang - Wang Wei +stanzhai - Stan Zhai +tien-dungle - Tien-Dung Le +xuchenCN - Xu Chen +zhangjiajin - Zhang JiaJin diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh new file mode 100755 index 000000000000..9dac43ce5442 --- /dev/null +++ b/dev/create-release/release-build.sh @@ -0,0 +1,322 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: release-build.sh +Creates build deliverables from a Spark commit. + +Top level targets are + package: Create binary packages and copy them to people.apache + docs: Build docs and copy them to people.apache + publish-snapshot: Publish snapshot release to Apache snapshots + publish-release: Publish a release to Apache release repo + +All other inputs are environment variables + +GIT_REF - Release tag or commit to build from +SPARK_VERSION - Release identifier used when publishing +SPARK_PACKAGE_VERSION - Release identifier in top level package directory +REMOTE_PARENT_DIR - Parent in which to create doc or release builds. +REMOTE_PARENT_MAX_LENGTH - If set, parent directory will be cleaned to only + have this number of subdirectories (by deleting old ones). WARNING: This deletes data. + +ASF_USERNAME - Username of ASF committer account +ASF_PASSWORD - Password of ASF committer account +ASF_RSA_KEY - RSA private key file for ASF committer account + +GPG_KEY - GPG key used to sign release artifacts +GPG_PASSPHRASE - Passphrase for GPG key +EOF + exit 1 +} + +set -e + +if [ $# -eq 0 ]; then + exit_with_usage +fi + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do + if [ -z "${!env}" ]; then + echo "ERROR: $env must be set to run this script" + exit_with_usage + fi +done + +# Commit ref to checkout when building +GIT_REF=${GIT_REF:-master} + +# Destination directory parent on remote server +REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} + +SSH="ssh -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" +GPG="gpg --no-tty --batch" +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads +BASE_DIR=$(pwd) + +MVN="build/mvn --force" +PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2" +PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" + +rm -rf spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git +cd spark +git checkout $GIT_REF +git_hash=`git rev-parse --short HEAD` +echo "Checked out Spark git hash $git_hash" + +if [ -z "$SPARK_VERSION" ]; then + SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ + | grep -v INFO | grep -v WARNING | grep -v Download) +fi + +if [ -z "$SPARK_PACKAGE_VERSION" ]; then + SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" +fi + +DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" +USER_HOST="$ASF_USERNAME@people.apache.org" + +git clean -d -f -x +rm .gitignore +rm -rf .git +cd .. 
+ +if [ -n "$REMOTE_PARENT_MAX_LENGTH" ]; then + old_dirs=$($SSH $USER_HOST ls -t $REMOTE_PARENT_DIR | tail -n +$REMOTE_PARENT_MAX_LENGTH) + for old_dir in $old_dirs; do + echo "Removing directory: $old_dir" + $SSH $USER_HOST rm -r $REMOTE_PARENT_DIR/$old_dir + done +fi + +if [[ "$1" == "package" ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + cp -r spark spark-$SPARK_VERSION + tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ + --detach-sig spark-$SPARK_VERSION.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ + spark-$SPARK_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha + rm -rf spark-$SPARK_VERSION + + # Updated for each binary build + make_binary_release() { + NAME=$1 + FLAGS=$2 + ZINC_PORT=$3 + cp -r spark spark-$SPARK_VERSION-bin-$NAME + + cd spark-$SPARK_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-scala-version.sh 2.11 + fi + + export ZINC_PORT=$ZINC_PORT + echo "Creating distribution: $NAME ($FLAGS)" + ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ + ../binary-release-$NAME.log + cd .. + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . + + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.sha + } + + # TODO: Check exit codes of children here: + # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. 
+  make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" &
+  make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" &
+  make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" &
+  make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" &
+  make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" &
+  make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3035" &
+  make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" &
+  make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" &
+  wait
+  rm -rf spark-$SPARK_VERSION-bin-*/
+
+  # Copy data
+  dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin"
+  echo "Copying release tarballs to $dest_dir"
+  $SSH $USER_HOST mkdir $dest_dir
+  rsync -e "$SSH" spark-* $USER_HOST:$dest_dir
+  echo "Linking /latest to $dest_dir"
+  $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest"
+  $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest"
+  exit 0
+fi
+
+if [[ "$1" == "docs" ]]; then
+  # Documentation
+  cd spark
+  echo "Building Spark docs"
+  dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-docs"
+  cd docs
+  # Compile docs with Java 7 to use nicer format
+  # TODO: Make configurable to add this: PRODUCTION=1
+  PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build
+  echo "Copying release documentation to $dest_dir"
+  $SSH $USER_HOST mkdir $dest_dir
+  echo "Linking /latest to $dest_dir"
+  $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest"
+  $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest"
+  rsync -e "$SSH" -r _site/* $USER_HOST:$dest_dir
+  cd ..
+  exit 0
+fi
+
+if [[ "$1" == "publish-snapshot" ]]; then
+  cd spark
+  # Publish Spark to Maven release repo
+  echo "Deploying Spark SNAPSHOT at '$GIT_REF' ($git_hash)"
+  echo "Publish version is $SPARK_VERSION"
+  if [[ ! $SPARK_VERSION == *"SNAPSHOT"* ]]; then
+    echo "ERROR: Snapshots must have a version containing SNAPSHOT"
+    echo "ERROR: You gave version '$SPARK_VERSION'"
+    exit 1
+  fi
+  # Coerce the requested version
+  $MVN versions:set -DnewVersion=$SPARK_VERSION
+  tmp_settings="tmp-settings.xml"
+  echo "<settings><servers><server>" > $tmp_settings
+  echo "<id>apache.snapshots.https</id><username>$ASF_USERNAME</username>" >> $tmp_settings
+  echo "<password>$ASF_PASSWORD</password>" >> $tmp_settings
+  echo "</server></servers></settings>" >> $tmp_settings
+
+  # Generate a random port for Zinc
+  export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)")
+
+  $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \
+    -Phive-thriftserver deploy
+  ./dev/change-scala-version.sh 2.11
+  $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \
+    -DskipTests $PUBLISH_PROFILES clean deploy
+
+  # Clean-up Zinc nailgun process
+  /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill
+
+  rm $tmp_settings
+  cd ..
+  exit 0
+fi
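The throw-away Maven settings file created by the echo lines above can equally be written from Python; a minimal sketch, assuming ASF_USERNAME and ASF_PASSWORD are exported and that the caller deletes the file afterwards as the script does.

```python
import os

SETTINGS_TEMPLATE = """<settings><servers><server>
<id>apache.snapshots.https</id><username>%(user)s</username>
<password>%(password)s</password>
</server></servers></settings>
"""

def write_tmp_settings(path="tmp-settings.xml"):
    """Write temporary Maven credentials for `mvn deploy` and return the file path."""
    creds = {"user": os.environ["ASF_USERNAME"], "password": os.environ["ASF_PASSWORD"]}
    with open(path, "w") as f:
        f.write(SETTINGS_TEMPLATE % creds)
    return path
```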
+
+if [[ "$1" == "publish-release" ]]; then
+  cd spark
+  # Publish Spark to Maven release repo
+  echo "Publishing Spark checkout at '$GIT_REF' ($git_hash)"
+  echo "Publish version is $SPARK_VERSION"
+  # Coerce the requested version
+  $MVN versions:set -DnewVersion=$SPARK_VERSION
+
+  # Using Nexus API documented here:
+  # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API
+  echo "Creating Nexus staging repository"
+  repo_request="<promoteRequest><data><description>Apache Spark $SPARK_VERSION (commit $git_hash)</description></data></promoteRequest>"
+  out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+    -H "Content-Type:application/xml" -v \
+    $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start)
+  staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/")
+  echo "Created Nexus staging repository: $staged_repo_id"
+
+  tmp_repo=$(mktemp -d spark-repo-XXXXX)
+
+  # Generate a random port for Zinc
+  export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)")
+
+  $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \
+    -Phive-thriftserver clean install
+
+  ./dev/change-scala-version.sh 2.11
+
+  $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \
+    -DskipTests $PUBLISH_PROFILES clean install
+
+  # Clean-up Zinc nailgun process
+  /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill
+
+  ./dev/change-version-to-2.10.sh
+
+  pushd $tmp_repo/org/apache/spark
+
+  # Remove any extra files generated during install
+  find . -type f |grep -v \.jar |grep -v \.pom | xargs rm
+
+  echo "Creating hash and signature files"
+  for file in $(find . -type f)
+  do
+    echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \
+      --detach-sig --armour $file;
+    if [ $(command -v md5) ]; then
+      # Available on OS X; -q to keep only hash
+      md5 -q $file > $file.md5
+    else
+      # Available on Linux; cut to keep only hash
+      md5sum $file | cut -f1 -d' ' > $file.md5
+    fi
+    sha1sum $file | cut -f1 -d' ' > $file.sha1
+  done
+
+  nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id
+  echo "Uploading files to $nexus_upload"
+  for file in $(find . -type f)
+  do
+    # strip leading ./
+    file_short=$(echo $file | sed -e "s/\.\///")
+    dest_url="$nexus_upload/org/apache/spark/$file_short"
+    echo "  Uploading $file_short"
+    curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url
+  done
+
+  echo "Closing nexus staging repository"
+  repo_request="<promoteRequest><data><stagedRepositoryId>$staged_repo_id</stagedRepositoryId><description>Apache Spark $SPARK_VERSION (commit $git_hash)</description></data></promoteRequest>"
+  out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+    -H "Content-Type:application/xml" -v \
+    $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish)
+  echo "Closed Nexus staging repository: $staged_repo_id"
+  popd
+  rm -rf $tmp_repo
+  cd ..
+  exit 0
+fi
+
+cd ..
+rm -rf spark
+echo "ERROR: expects to be called with 'package', 'docs', 'publish-release' or 'publish-snapshot'"
diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh
new file mode 100755
index 000000000000..b0a3374becc6
--- /dev/null
+++ b/dev/create-release/release-tag.sh
@@ -0,0 +1,79 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: tag-release.sh +Tags a Spark release on a particular branch. + +Inputs are specified with the following environment variables: +ASF_USERNAME - Apache Username +ASF_PASSWORD - Apache Password +GIT_NAME - Name to use with git +GIT_EMAIL - E-mail address to use with git +GIT_BRANCH - Git branch on which to make release +RELEASE_VERSION - Version used in pom files for release +RELEASE_TAG - Name of release tag +NEXT_VERSION - Development version after release +EOF + exit 1 +} + +set -e + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GIT_EMAIL GIT_NAME GIT_BRANCH; do + if [ -z "${!env}" ]; then + echo "$env must be set to run this script" + exit 1 + fi +done + +ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" +MVN="build/mvn --force" + +rm -rf spark +git clone https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO -b $GIT_BRANCH +cd spark + +git config user.name "$GIT_NAME" +git config user.email $GIT_EMAIL + +# Create release version +$MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing Spark release $RELEASE_TAG" +echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" +git tag $RELEASE_TAG + +# TODO: It would be nice to do some verifications here +# i.e. check whether ec2 scripts have the new version + +# Create next version +$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing development version $NEXT_VERSION" + +# Push changes +git push origin $RELEASE_TAG +git push origin HEAD:$GIT_BRANCH + +cd .. +rm -rf spark diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 51ab25a6a5bd..7f152b7f5355 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -24,7 +24,11 @@ try: from jira.client import JIRA - from jira.exceptions import JIRAError + # Old versions have JIRAError in exceptions package, new (0.5+) in utils. 
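A standalone illustration of the version-dependent import fallback described in the comment above; the `JIRAError = None` branch and the message are illustrative additions, while the hunk below applies the same try/except inside releaseutils.py.

```python
try:
    # jira-python < 0.5 exposes JIRAError from the exceptions module ...
    from jira.exceptions import JIRAError
except ImportError:
    try:
        # ... while 0.5+ moved it to jira.utils.
        from jira.utils import JIRAError
    except ImportError:
        JIRAError = None  # jira-python is not installed at all

if JIRAError is None:
    print("This tool requires the jira-python library; install it with 'sudo pip install jira'")
```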
+ try: + from jira.exceptions import JIRAError + except ImportError: + from jira.utils import JIRAError except ImportError: print "This tool requires the jira-python library" print "Install using 'sudo pip install jira'" diff --git a/dev/lint-python b/dev/lint-python index f50d149dc4d4..575dbb0ae321 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,13 +19,16 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" -PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" +PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" +PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" +PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pep8 at runtime so that we don't rely on it being installed on the build server. @@ -46,11 +49,36 @@ if [ ! -e "$PEP8_SCRIPT_PATH" ]; then fi fi +# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should +# be set to the directory. +# dev/pylint should be appended to the PATH variable as well. +# Jenkins by default installs the pylint3 version, so for now this just checks the code quality +# of python3. +export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" +export "PYLINT_HOME=$PYTHONPATH" +export "PATH=$PYTHONPATH:$PATH" + +# if [ ! -d "$PYLINT_HOME" ]; then +# mkdir "$PYLINT_HOME" +# # Redirect the annoying pylint installation output. +# easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" +# easy_install_status="$?" +# +# if [ "$easy_install_status" -ne 0 ]; then +# echo "Unable to install pylint locally in \"$PYTHONPATH\"." +# cat "$PYLINT_INSTALL_INFO" +# exit "$easy_install_status" +# fi +# +# rm "$PYLINT_INSTALL_INFO" +# +# fi + # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then @@ -60,13 +88,27 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "Python lint checks failed." - cat "$PYTHON_LINT_REPORT_PATH" + echo "PEP8 checks failed." + cat "$PEP8_REPORT_PATH" else - echo "Python lint checks passed." + echo "PEP8 checks passed." fi -# rm "$PEP8_SCRIPT_PATH" -rm "$PYTHON_LINT_REPORT_PATH" +rm "$PEP8_REPORT_PATH" + +# for to_be_checked in "$PATHS_TO_CHECK" +# do +# pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +# done + +# if [ "${PIPESTATUS[0]}" -ne 0 ]; then +# lint_status=1 +# echo "Pylint checks failed." +# cat "$PYLINT_REPORT_PATH" +# else +# echo "Pylint checks passed." +# fi + +# rm "$PYLINT_REPORT_PATH" exit "$lint_status" diff --git a/dev/lint-r b/dev/lint-r index 7d5f4cd31153..bfda0bca15eb 100755 --- a/dev/lint-r +++ b/dev/lint-r @@ -28,3 +28,14 @@ if ! 
type "Rscript" > /dev/null; then fi `which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" + +NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME" | awk '{print $1}'` +if [ "$NUM_LINES" = "0" ] ; then + lint_status=0 + echo "lintr checks passed." +else + lint_status=1 + echo "lintr checks failed." +fi + +exit "$lint_status" diff --git a/dev/lint-r.R b/dev/lint-r.R index dcb1a184291e..999eef571b82 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -15,15 +15,23 @@ # limitations under the License. # -# Installs lintr from Github. +argv <- commandArgs(TRUE) +SPARK_ROOT_DIR <- as.character(argv[1]) +LOCAL_LIB_LOC <- file.path(SPARK_ROOT_DIR, "R", "lib") + +# Checks if SparkR is installed in a local directory. +if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} + +# Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { devtools::install_github("jimhester/lintr") } -library(lintr) - -argv <- commandArgs(TRUE) -SPARK_ROOT_DIR <- as.character(argv[1]) +library(lintr) +library(methods) +library(testthat) path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index cd83b352c1bf..b9bdec3d7086 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -47,6 +47,12 @@ JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") +# OAuth key used for issuing requests against the GitHub API. If this is not defined, then requests +# will be unauthenticated. You should only need to configure this if you find yourself regularly +# exceeding your IP's unauthenticated request rate limit. You can create an OAuth key at +# https://github.com/settings/tokens. This script only requires the "public_repo" scope. +GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY") + GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" @@ -58,9 +64,17 @@ def get_json(url): try: - return json.load(urllib2.urlopen(url)) + request = urllib2.Request(url) + if GITHUB_OAUTH_KEY: + request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY) + return json.load(urllib2.urlopen(request)) except urllib2.HTTPError as e: - print "Unable to fetch URL, exiting: %s" % url + if "X-RateLimit-Remaining" in e.headers and e.headers["X-RateLimit-Remaining"] == '0': + print "Exceeded the GitHub API rate limit; see the instructions in " + \ + "dev/merge_spark_pr.py to configure an OAuth token for making authenticated " + \ + "GitHub requests." 
+ else: + print "Unable to fetch URL, exiting: %s" % url sys.exit(-1) @@ -116,7 +130,12 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = distinct_authors[0] + primary_author = raw_input( + "Enter primary author in the format of \"name \" [%s]: " % + distinct_authors[0]) + if primary_author == "": + primary_author = distinct_authors[0] + commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -140,11 +159,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): merge_message_flags += ["-m", message] # The string "Closes #%s" string is required for GitHub to correctly close the PR - merge_message_flags += [ - "-m", - "Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)] - for c in commits: - merge_message_flags += ["-m", c] + merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) @@ -267,7 +282,7 @@ def get_version_json(version_str): resolve = filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id))[0] resolution = filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions())[0] asf_jira.transition_issue( - jira_id, resolve["id"], fixVersions = jira_fix_versions, + jira_id, resolve["id"], fixVersions = jira_fix_versions, comment = comment, resolution = {'id': resolution.raw['id']}) print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) @@ -286,7 +301,7 @@ def standardize_jira_ref(text): """ Standardize the [SPARK-XXXXX] [MODULE] prefix Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" - + >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") @@ -308,11 +323,11 @@ def standardize_jira_ref(text): """ jira_refs = [] components = [] - + # If the string is compliant, no need to process any further if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): return text - + # Extract JIRA ref(s): pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE) for ref in pattern.findall(text): @@ -334,18 +349,18 @@ def standardize_jira_ref(text): # Assemble full text (JIRA ref(s), module(s), remaining text) clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() - + # Replace multiple spaces with a single space, e.g. 
if no jira refs and/or components were included clean_text = re.sub(r'\s+', ' ', clean_text.strip()) - + return clean_text def main(): global original_head - + os.chdir(SPARK_HOME) original_head = run_cmd("git rev-parse HEAD")[:8] - + branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) # Assumes branch names can be sorted lexicographically @@ -431,6 +446,8 @@ def main(): if __name__ == "__main__": import doctest - doctest.testmod() - + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) + main() diff --git a/dev/run-tests b/dev/run-tests index a00d9f0c2763..257d1e8d50bb 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,4 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -exec python -u ./dev/run-tests.py +exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index f4b238e1b78a..1f16790522e7 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -21,9 +21,10 @@ readonly BLOCK_GENERAL=10 readonly BLOCK_RAT=11 readonly BLOCK_SCALA_STYLE=12 readonly BLOCK_PYTHON_STYLE=13 -readonly BLOCK_DOCUMENTATION=14 -readonly BLOCK_BUILD=15 -readonly BLOCK_MIMA=16 -readonly BLOCK_SPARK_UNIT_TESTS=17 -readonly BLOCK_PYSPARK_UNIT_TESTS=18 -readonly BLOCK_SPARKR_UNIT_TESTS=19 +readonly BLOCK_R_STYLE=14 +readonly BLOCK_DOCUMENTATION=15 +readonly BLOCK_BUILD=16 +readonly BLOCK_MIMA=17 +readonly BLOCK_SPARK_UNIT_TESTS=18 +readonly BLOCK_PYSPARK_UNIT_TESTS=19 +readonly BLOCK_SPARKR_UNIT_TESTS=20 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index c4d39d95d589..3be78575e70f 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -48,8 +48,8 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" # format: http://linux.die.net/man/1/timeout -# must be less than the timeout configured on Jenkins (currently 180m) -TESTS_TIMEOUT="175m" +# must be less than the timeout configured on Jenkins (currently 300m) +TESTS_TIMEOUT="250m" # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. @@ -164,8 +164,9 @@ pr_message="" current_pr_head="`git rev-parse HEAD`" echo "HEAD: `git rev-parse HEAD`" -echo "GHPRB: $ghprbActualCommit" -echo "SHA1: $sha1" +echo "\$ghprbActualCommit: $ghprbActualCommit" +echo "\$sha1: $sha1" +echo "\$ghprbPullTitle: $ghprbPullTitle" # Run pull request tests for t in "${PR_TESTS[@]}"; do @@ -189,6 +190,19 @@ done { # Marks this build is a pull request build. export AMP_JENKINS_PRB=true + if [[ $ghprbPullTitle == *"test-maven"* ]]; then + export AMPLAB_JENKINS_BUILD_TOOL="maven" + fi + if [[ $ghprbPullTitle == *"test-hadoop1.0"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop1.0" + elif [[ $ghprbPullTitle == *"test-hadoop2.0"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.0" + elif [[ $ghprbPullTitle == *"test-hadoop2.2"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.2" + elif [[ $ghprbPullTitle == *"test-hadoop2.3"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.3" + fi + timeout "${TESTS_TIMEOUT}" ./dev/run-tests test_result="$?" 
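The PR-title parsing added above can be summarized as a small, testable helper. A sketch: the generic `test-hadoopX.Y` regex is an assumption (the shell version only recognizes hadoop1.0, 2.0, 2.2 and 2.3), and the sample title is made up.

```python
import re

def build_overrides(pr_title):
    """Derive the Jenkins build-tool/profile environment overrides from a PR title."""
    env = {}
    if "test-maven" in pr_title:
        env["AMPLAB_JENKINS_BUILD_TOOL"] = "maven"
    profile = re.search(r"test-(hadoop[0-9]+\.[0-9]+)", pr_title)
    if profile:
        env["AMPLAB_JENKINS_BUILD_PROFILE"] = profile.group(1)
    return env

# Hypothetical title; prints both overrides.
print(build_overrides("[SPARK-XXXXX] [BUILD] Some change [test-maven] [test-hadoop2.3]"))
```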
@@ -210,6 +224,8 @@ done failing_test="Scala style tests" elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_R_STYLE" ]; then + failing_test="R style tests" elif [ "$test_result" -eq "$BLOCK_DOCUMENTATION" ]; then failing_test="to generate documentation" elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then diff --git a/dev/run-tests.py b/dev/run-tests.py index 2cccfed75ede..d8b22e1665e7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -17,297 +17,25 @@ # limitations under the License. # +from __future__ import print_function import itertools +from optparse import OptionParser import os +import random import re import sys -import shutil import subprocess from collections import namedtuple -SPARK_HOME = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") -USER_HOME = os.environ.get("HOME") - +from sparktestsupport import SPARK_HOME, USER_HOME +from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which +import sparktestsupport.modules as modules # ------------------------------------------------------------------------------------------------- -# Test module definitions and functions for traversing module dependency graph +# Functions for traversing module dependency graph # ------------------------------------------------------------------------------------------------- -all_modules = [] - - -class Module(object): - """ - A module is the basic abstraction in our test runner script. Each module consists of a set of - source files, a set of test commands, and a set of dependencies on other modules. We use modules - to define a dependency graph that lets determine which tests to run based on which files have - changed. - """ - - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), - sbt_test_goals=(), should_run_python_tests=False, should_run_r_tests=False): - """ - Define a new module. - - :param name: A short module name, for display in logging and error messages. - :param dependencies: A set of dependencies for this module. This should only include direct - dependencies; transitive dependencies are resolved automatically. - :param source_file_regexes: a set of regexes that match source files belonging to this - module. These regexes are applied by attempting to match at the beginning of the - filename strings. - :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in - order to build and test this module (e.g. '-PprofileName'). - :param sbt_test_goals: A set of SBT test goals for testing this module. - :param should_run_python_tests: If true, changes in this module will trigger Python tests. - For now, this has the effect of causing _all_ Python tests to be run, although in the - future this should be changed to run only a subset of the Python tests that depend - on this module. - :param should_run_r_tests: If true, changes in this module will trigger all R tests. 
- """ - self.name = name - self.dependencies = dependencies - self.source_file_prefixes = source_file_regexes - self.sbt_test_goals = sbt_test_goals - self.build_profile_flags = build_profile_flags - self.should_run_python_tests = should_run_python_tests - self.should_run_r_tests = should_run_r_tests - - self.dependent_modules = set() - for dep in dependencies: - dep.dependent_modules.add(self) - all_modules.append(self) - - def contains_file(self, filename): - return any(re.match(p, filename) for p in self.source_file_prefixes) - - -sql = Module( - name="sql", - dependencies=[], - source_file_regexes=[ - "sql/(?!hive-thriftserver)", - "bin/spark-sql", - ], - build_profile_flags=[ - "-Phive", - ], - sbt_test_goals=[ - "catalyst/test", - "sql/test", - "hive/test", - ] -) - - -hive_thriftserver = Module( - name="hive-thriftserver", - dependencies=[sql], - source_file_regexes=[ - "sql/hive-thriftserver", - "sbin/start-thriftserver.sh", - ], - build_profile_flags=[ - "-Phive-thriftserver", - ], - sbt_test_goals=[ - "hive-thriftserver/test", - ] -) - - -graphx = Module( - name="graphx", - dependencies=[], - source_file_regexes=[ - "graphx/", - ], - sbt_test_goals=[ - "graphx/test" - ] -) - - -streaming = Module( - name="streaming", - dependencies=[], - source_file_regexes=[ - "streaming", - ], - sbt_test_goals=[ - "streaming/test", - ] -) - - -streaming_kinesis_asl = Module( - name="kinesis-asl", - dependencies=[streaming], - source_file_regexes=[ - "extras/kinesis-asl/", - ], - build_profile_flags=[ - "-Pkinesis-asl", - ], - sbt_test_goals=[ - "kinesis-asl/test", - ] -) - - -streaming_zeromq = Module( - name="streaming-zeromq", - dependencies=[streaming], - source_file_regexes=[ - "external/zeromq", - ], - sbt_test_goals=[ - "streaming-zeromq/test", - ] -) - - -streaming_twitter = Module( - name="streaming-twitter", - dependencies=[streaming], - source_file_regexes=[ - "external/twitter", - ], - sbt_test_goals=[ - "streaming-twitter/test", - ] -) - - -streaming_mqqt = Module( - name="streaming-mqqt", - dependencies=[streaming], - source_file_regexes=[ - "external/mqqt", - ], - sbt_test_goals=[ - "streaming-mqqt/test", - ] -) - - -streaming_kafka = Module( - name="streaming-kafka", - dependencies=[streaming], - source_file_regexes=[ - "external/kafka", - "external/kafka-assembly", - ], - sbt_test_goals=[ - "streaming-kafka/test", - ] -) - - -streaming_flume_sink = Module( - name="streaming-flume-sink", - dependencies=[streaming], - source_file_regexes=[ - "external/flume-sink", - ], - sbt_test_goals=[ - "streaming-flume-sink/test", - ] -) - - -streaming_flume = Module( - name="streaming_flume", - dependencies=[streaming], - source_file_regexes=[ - "external/flume", - ], - sbt_test_goals=[ - "streaming-flume/test", - ] -) - - -mllib = Module( - name="mllib", - dependencies=[streaming, sql], - source_file_regexes=[ - "data/mllib/", - "mllib/", - ], - sbt_test_goals=[ - "mllib/test", - ] -) - - -examples = Module( - name="examples", - dependencies=[graphx, mllib, streaming, sql], - source_file_regexes=[ - "examples/", - ], - sbt_test_goals=[ - "examples/test", - ] -) - - -pyspark = Module( - name="pyspark", - dependencies=[mllib, streaming, streaming_kafka, sql], - source_file_regexes=[ - "python/" - ], - should_run_python_tests=True -) - - -sparkr = Module( - name="sparkr", - dependencies=[sql, mllib], - source_file_regexes=[ - "R/", - ], - should_run_r_tests=True -) - - -docs = Module( - name="docs", - dependencies=[], - source_file_regexes=[ - "docs/", - ] -) - - -ec2 = Module( - name="ec2", - 
dependencies=[], - source_file_regexes=[ - "ec2/", - ] -) - - -# The root module is a dummy module which is used to run all of the tests. -# No other modules should directly depend on this module. -root = Module( - name="root", - dependencies=[], - source_file_regexes=[], - # In order to run all of the tests, enable every test profile: - build_profile_flags= - list(set(itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), - sbt_test_goals=[ - "test", - ], - should_run_python_tests=True, - should_run_r_tests=True -) - - def determine_modules_for_files(filenames): """ Given a list of filenames, return the set of modules that contain those files. @@ -315,19 +43,19 @@ def determine_modules_for_files(filenames): file to belong to the 'root' module. >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) - ['pyspark', 'sql'] + ['pyspark-core', 'sql'] >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] ['root'] """ changed_modules = set() for filename in filenames: matched_at_least_one_module = False - for module in all_modules: + for module in modules.all_modules: if module.contains_file(filename): changed_modules.add(module) matched_at_least_one_module = True if not matched_at_least_one_module: - changed_modules.add(root) + changed_modules.add(modules.root) return changed_modules @@ -352,28 +80,38 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe run_cmd(['git', 'fetch', 'origin', str(target_branch+':'+target_branch)]) else: diff_target = target_ref - raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target]) + raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target], + universal_newlines=True) # Remove any empty strings return [f for f in raw_output.split('\n') if f] +def setup_test_environ(environ): + print("[info] Setup the following environment variables for tests: ") + for (k, v) in environ.items(): + print("%s=%s" % (k, v)) + os.environ[k] = v + + def determine_modules_to_test(changed_modules): """ Given a set of modules that have changed, compute the transitive closure of those modules' dependent modules in order to determine the set of modules that should be tested. - >>> sorted(x.name for x in determine_modules_to_test([root])) + >>> sorted(x.name for x in determine_modules_to_test([modules.root])) ['root'] - >>> sorted(x.name for x in determine_modules_to_test([graphx])) + >>> sorted(x.name for x in determine_modules_to_test([modules.graphx])) ['examples', 'graphx'] - >>> sorted(x.name for x in determine_modules_to_test([sql])) - ['examples', 'hive-thriftserver', 'mllib', 'pyspark', 'sparkr', 'sql'] + >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) + >>> x # doctest: +NORMALIZE_WHITESPACE + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. No module depends on root, so if it appears then it will be # in changed_modules. 
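The transitive closure that determine_modules_to_test computes (the root-module short-circuit mentioned above is handled in the hunk below) can be sketched in isolation; `Mod` is a simplified stand-in for the Module class and the three sample modules are illustrative.

```python
class Mod(object):
    """Simplified stand-in keeping only the reverse dependency links."""
    def __init__(self, name, dependencies=()):
        self.name = name
        self.dependent_modules = set()
        for dep in dependencies:
            dep.dependent_modules.add(self)

def modules_to_test(changed):
    """Union of the changed modules and, recursively, everything that depends on them."""
    result = set(changed)
    for mod in changed:
        result |= modules_to_test(mod.dependent_modules)
    return result

sql = Mod("sql")
mllib = Mod("mllib", [sql])
pyspark_mllib = Mod("pyspark-mllib", [mllib])
print(sorted(m.name for m in modules_to_test({sql})))  # ['mllib', 'pyspark-mllib', 'sql']
```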
- if root in changed_modules: - return [root] + if modules.root in changed_modules: + return [modules.root] modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) @@ -398,60 +136,6 @@ def get_error_codes(err_code_file): ERROR_CODES = get_error_codes(os.path.join(SPARK_HOME, "dev/run-tests-codes.sh")) -def exit_from_command_with_retcode(cmd, retcode): - print "[error] running", ' '.join(cmd), "; received return code", retcode - sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) - - -def rm_r(path): - """Given an arbitrary path properly remove it with the correct python - construct if it exists - - from: http://stackoverflow.com/a/9559881""" - - if os.path.isdir(path): - shutil.rmtree(path) - elif os.path.exists(path): - os.remove(path) - - -def run_cmd(cmd): - """Given a command as a list of arguments will attempt to execute the - command from the determined SPARK_HOME directory and, on failure, print - an error message""" - - if not isinstance(cmd, list): - cmd = cmd.split() - try: - subprocess.check_call(cmd) - except subprocess.CalledProcessError as e: - exit_from_command_with_retcode(e.cmd, e.returncode) - - -def is_exe(path): - """Check if a given path is an executable file - - from: http://stackoverflow.com/a/377028""" - - return os.path.isfile(path) and os.access(path, os.X_OK) - - -def which(program): - """Find and return the given program by its absolute path or 'None' - - from: http://stackoverflow.com/a/377028""" - - fpath = os.path.split(program)[0] - - if fpath: - if is_exe(program): - return program - else: - for path in os.environ.get("PATH").split(os.pathsep): - path = path.strip('"') - exe_file = os.path.join(path, program) - if is_exe(exe_file): - return exe_file - return None - - def determine_java_executable(): """Will return the path of the java executable that will be used by Spark's tests or `None`""" @@ -476,8 +160,14 @@ def determine_java_version(java_exe): with accessors '.major', '.minor', '.patch', '.update'""" raw_output = subprocess.check_output([java_exe, "-version"], - stderr=subprocess.STDOUT) - raw_version_str = raw_output.split('\n')[0] # eg 'java version "1.8.0_25"' + stderr=subprocess.STDOUT, + universal_newlines=True) + + raw_output_lines = raw_output.split('\n') + + # find raw version string, eg 'java version "1.8.0_25"' + raw_version_str = next(x for x in raw_output_lines if " version " in x) + version_str = raw_version_str.split()[-1].strip('"') # eg '1.8.0_25' version, update = version_str.split('_') # eg ['1.8.0', '25'] @@ -499,10 +189,10 @@ def set_title_and_block(title, err_block): os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block] line_str = '=' * 72 - print - print line_str - print title - print line_str + print('') + print(line_str) + print(title) + print(line_str) def run_apache_rat_checks(): @@ -520,6 +210,18 @@ def run_python_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) +def run_sparkr_style_checks(): + set_title_and_block("Running R style checks", "BLOCK_R_STYLE") + + if which("R"): + # R style check should be executed after `install-dev.sh`. + # Since warnings about `no visible global function definition` appear + # without the installation. SEE ALSO: SPARK-9121. 
+ run_cmd([os.path.join(SPARK_HOME, "dev", "lint-r")]) + else: + print("Ignoring SparkR style check as R was not found in PATH") + + def build_spark_documentation(): set_title_and_block("Building Spark Documentation", "BLOCK_DOCUMENTATION") os.environ["PRODUCTION"] = "1 jekyll build" @@ -529,8 +231,8 @@ def build_spark_documentation(): jekyll_bin = which("jekyll") if not jekyll_bin: - print "[error] Cannot find a version of `jekyll` on the system; please", - print "install one and retry to build documentation." + print("[error] Cannot find a version of `jekyll` on the system; please", + " install one and retry to build documentation.") sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) else: run_cmd([jekyll_bin, "build"]) @@ -538,11 +240,32 @@ def build_spark_documentation(): os.chdir(SPARK_HOME) +def get_zinc_port(): + """ + Get a randomized port on which to start Zinc + """ + return random.randrange(3030, 4030) + + +def kill_zinc_on_port(zinc_port): + """ + Kill the Zinc process running on the given port, if one exists. + """ + cmd = ("/usr/sbin/lsof -P |grep %s | grep LISTEN " + "| awk '{ print $2; }' | xargs kill") % zinc_port + subprocess.check_call(cmd, shell=True) + + def exec_maven(mvn_args=()): """Will call Maven in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" - run_cmd([os.path.join(SPARK_HOME, "build", "mvn")] + mvn_args) + zinc_port = get_zinc_port() + os.environ["ZINC_PORT"] = "%s" % zinc_port + zinc_flag = "-DzincPort=%s" % zinc_port + flags = [os.path.join(SPARK_HOME, "build", "mvn"), "--force", zinc_flag] + run_cmd(flags + mvn_args) + kill_zinc_on_port(zinc_port) def exec_sbt(sbt_args=()): @@ -566,7 +289,7 @@ def exec_sbt(sbt_args=()): echo_proc.wait() for line in iter(sbt_proc.stdout.readline, ''): if not sbt_output_filter.match(line): - print line, + print(line, end='') retcode = sbt_proc.wait() if retcode > 0: @@ -580,48 +303,53 @@ def get_hadoop_profiles(hadoop_version): """ sbt_maven_hadoop_profiles = { - "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.0.4"], + "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.2.1"], "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], } if hadoop_version in sbt_maven_hadoop_profiles: return sbt_maven_hadoop_profiles[hadoop_version] else: - print "[error] Could not find", hadoop_version, "in the list. Valid options", - print "are", sbt_maven_hadoop_profiles.keys() + print("[error] Could not find", hadoop_version, "in the list. 
Valid options", + " are", sbt_maven_hadoop_profiles.keys()) sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) def build_spark_maven(hadoop_version): # Enable all of the profiles for the build: - build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals - print "[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Building Spark (w/Hive 1.2.1) using Maven with these arguments: ", + " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) def build_spark_sbt(hadoop_version): # Enable all of the profiles for the build: - build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "assembly/assembly", - "streaming-kafka-assembly/assembly"] + "streaming-kafka-assembly/assembly", + "streaming-flume-assembly/assembly", + "streaming-mqtt-assembly/assembly", + "streaming-mqtt/test:assembly", + "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals - print "[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Building Spark (w/Hive 1.2.1) using SBT with these arguments: ", + " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) def build_apache_spark(build_tool, hadoop_version): - """Will build Spark against Hive v0.13.1 given the passed in build tool (either `sbt` or + """Will build Spark against Hive v1.2.1 given the passed in build tool (either `sbt` or `maven`). 
Defaults to using `sbt`.""" set_title_and_block("Building Spark", "BLOCK_BUILD") @@ -643,8 +371,8 @@ def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] profiles_and_goals = test_profiles + mvn_test_goals - print "[info] Running Spark tests using Maven with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Running Spark tests using Maven with these arguments: ", + " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -658,8 +386,8 @@ def run_scala_tests_sbt(test_modules, test_profiles): profiles_and_goals = test_profiles + list(sbt_test_goals) - print "[info] Running Spark tests using SBT with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Running Spark tests using SBT with these arguments: ", + " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) @@ -679,27 +407,48 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): run_scala_tests_sbt(test_modules, test_profiles) -def run_python_tests(): +def run_python_tests(test_modules, parallelism): set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") - run_cmd([os.path.join(SPARK_HOME, "python", "run-tests")]) + command = [os.path.join(SPARK_HOME, "python", "run-tests")] + if test_modules != [modules.root]: + command.append("--modules=%s" % ','.join(m.name for m in test_modules)) + command.append("--parallelism=%i" % parallelism) + run_cmd(command) def run_sparkr_tests(): set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") if which("R"): - run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) else: - print "Ignoring SparkR tests as R was not found in PATH" + print("Ignoring SparkR tests as R was not found in PATH") + + +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts def main(): + opts = parse_opts() # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): - print "[error] Cannot determine your home directory as an absolute path;", - print "ensure the $HOME environment variable is set properly." + print("[error] Cannot determine your home directory as an absolute path;", + " ensure the $HOME environment variable is set properly.") sys.exit(1) os.chdir(SPARK_HOME) @@ -713,14 +462,20 @@ def main(): java_exe = determine_java_executable() if not java_exe: - print "[error] Cannot find a version of `java` on the system; please", - print "install one and retry." + print("[error] Cannot find a version of `java` on the system; please", + " install one and retry.") sys.exit(2) java_version = determine_java_version(java_exe) if java_version.minor < 8: - print "[warn] Java 8 tests will not run because JDK version is < 1.8." 
+ print("[warn] Java 8 tests will not run because JDK version is < 1.8.") + + # install SparkR + if which("R"): + run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) + else: + print("Can't install SparkR as R is was not found in PATH") if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables @@ -736,8 +491,8 @@ def main(): hadoop_version = "hadoop2.3" test_env = "local" - print "[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, - print "under environment", test_env + print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, + "under environment", test_env) changed_modules = None changed_files = None @@ -746,8 +501,18 @@ def main(): changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) if not changed_modules: - changed_modules = [root] - print "[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules) + changed_modules = [modules.root] + print("[info] Found the following changed modules:", + ", ".join(x.name for x in changed_modules)) + + # setup environment variables + # note - the 'root' module doesn't collect environment variables for all modules. Because the + # environment variables should not be set if a module is not changed, even if running the 'root' + # module. So here we should use changed_modules rather than test_modules. + test_environ = {} + for m in changed_modules: + test_environ.update(m.environ) + setup_test_environ(test_environ) test_modules = determine_modules_to_test(changed_modules) @@ -759,6 +524,8 @@ def main(): run_scala_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() + if not changed_files or any(f.endswith(".R") for f in changed_files): + run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment # note - the below commented out until *all* Jenkins workers can get `jekyll` installed @@ -769,13 +536,16 @@ def main(): build_apache_spark(build_tool, hadoop_version) # backwards compatibility checks - detect_binary_inop_with_mima() + if build_tool == "sbt": + # Note: compatiblity tests only supported in sbt for now + detect_binary_inop_with_mima() # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules) - if any(m.should_run_python_tests for m in test_modules): - run_python_tests() + modules_with_python_tests = [m for m in test_modules if m.python_test_goals] + if modules_with_python_tests: + run_python_tests(modules_with_python_tests, opts.parallelism) if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/R/pkg/R/zzz.R b/dev/sparktestsupport/__init__.py similarity index 84% rename from R/pkg/R/zzz.R rename to dev/sparktestsupport/__init__.py index 80d796d46794..12696d98fb98 100644 --- a/R/pkg/R/zzz.R +++ b/dev/sparktestsupport/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. 
# -.onLoad <- function(libname, pkgname) { - sparkR.onLoad(libname, pkgname) -} +import os +SPARK_HOME = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../")) +USER_HOME = os.environ.get("HOME") diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py new file mode 100644 index 000000000000..346452f3174e --- /dev/null +++ b/dev/sparktestsupport/modules.py @@ -0,0 +1,415 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import itertools +import re + +all_modules = [] + + +class Module(object): + """ + A module is the basic abstraction in our test runner script. Each module consists of a set of + source files, a set of test commands, and a set of dependencies on other modules. We use modules + to define a dependency graph that lets determine which tests to run based on which files have + changed. + """ + + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, + sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), + should_run_r_tests=False): + """ + Define a new module. + + :param name: A short module name, for display in logging and error messages. + :param dependencies: A set of dependencies for this module. This should only include direct + dependencies; transitive dependencies are resolved automatically. + :param source_file_regexes: a set of regexes that match source files belonging to this + module. These regexes are applied by attempting to match at the beginning of the + filename strings. + :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in + order to build and test this module (e.g. '-PprofileName'). + :param environ: A dict of environment variables that should be set when files in this + module are changed. + :param sbt_test_goals: A set of SBT test goals for testing this module. + :param python_test_goals: A set of Python test goals for testing this module. + :param blacklisted_python_implementations: A set of Python implementations that are not + supported by this module's Python components. The values in this set should match + strings returned by Python's `platform.python_implementation()`. + :param should_run_r_tests: If true, changes in this module will trigger all R tests. 
+ """ + self.name = name + self.dependencies = dependencies + self.source_file_prefixes = source_file_regexes + self.sbt_test_goals = sbt_test_goals + self.build_profile_flags = build_profile_flags + self.environ = environ + self.python_test_goals = python_test_goals + self.blacklisted_python_implementations = blacklisted_python_implementations + self.should_run_r_tests = should_run_r_tests + + self.dependent_modules = set() + for dep in dependencies: + dep.dependent_modules.add(self) + all_modules.append(self) + + def contains_file(self, filename): + return any(re.match(p, filename) for p in self.source_file_prefixes) + + +sql = Module( + name="sql", + dependencies=[], + source_file_regexes=[ + "sql/(?!hive-thriftserver)", + "bin/spark-sql", + ], + build_profile_flags=[ + "-Phive", + ], + sbt_test_goals=[ + "catalyst/test", + "sql/test", + "hive/test", + ] +) + + +hive_thriftserver = Module( + name="hive-thriftserver", + dependencies=[sql], + source_file_regexes=[ + "sql/hive-thriftserver", + "sbin/start-thriftserver.sh", + ], + build_profile_flags=[ + "-Phive-thriftserver", + ], + sbt_test_goals=[ + "hive-thriftserver/test", + ] +) + + +graphx = Module( + name="graphx", + dependencies=[], + source_file_regexes=[ + "graphx/", + ], + sbt_test_goals=[ + "graphx/test" + ] +) + + +streaming = Module( + name="streaming", + dependencies=[], + source_file_regexes=[ + "streaming", + ], + sbt_test_goals=[ + "streaming/test", + ] +) + + +# Don't set the dependencies because changes in other modules should not trigger Kinesis tests. +# Kinesis tests depends on external Amazon kinesis service. We should run these tests only when +# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't +# fail other PRs. +streaming_kinesis_asl = Module( + name="streaming-kinesis-asl", + dependencies=[], + source_file_regexes=[ + "extras/kinesis-asl/", + "extras/kinesis-asl-assembly/", + ], + build_profile_flags=[ + "-Pkinesis-asl", + ], + environ={ + "ENABLE_KINESIS_TESTS": "1" + }, + sbt_test_goals=[ + "streaming-kinesis-asl/test", + ] +) + + +streaming_zeromq = Module( + name="streaming-zeromq", + dependencies=[streaming], + source_file_regexes=[ + "external/zeromq", + ], + sbt_test_goals=[ + "streaming-zeromq/test", + ] +) + + +streaming_twitter = Module( + name="streaming-twitter", + dependencies=[streaming], + source_file_regexes=[ + "external/twitter", + ], + sbt_test_goals=[ + "streaming-twitter/test", + ] +) + + +streaming_mqtt = Module( + name="streaming-mqtt", + dependencies=[streaming], + source_file_regexes=[ + "external/mqtt", + "external/mqtt-assembly", + ], + sbt_test_goals=[ + "streaming-mqtt/test", + ] +) + + +streaming_kafka = Module( + name="streaming-kafka", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka", + "external/kafka-assembly", + ], + sbt_test_goals=[ + "streaming-kafka/test", + ] +) + + +streaming_flume_sink = Module( + name="streaming-flume-sink", + dependencies=[streaming], + source_file_regexes=[ + "external/flume-sink", + ], + sbt_test_goals=[ + "streaming-flume-sink/test", + ] +) + + +streaming_flume = Module( + name="streaming-flume", + dependencies=[streaming], + source_file_regexes=[ + "external/flume", + ], + sbt_test_goals=[ + "streaming-flume/test", + ] +) + + +streaming_flume_assembly = Module( + name="streaming-flume-assembly", + dependencies=[streaming_flume, streaming_flume_sink], + source_file_regexes=[ + "external/flume-assembly", + ] +) + + +mllib = Module( + name="mllib", + dependencies=[streaming, sql], + 
source_file_regexes=[ + "data/mllib/", + "mllib/", + ], + sbt_test_goals=[ + "mllib/test", + ] +) + + +examples = Module( + name="examples", + dependencies=[graphx, mllib, streaming, sql], + source_file_regexes=[ + "examples/", + ], + sbt_test_goals=[ + "examples/test", + ] +) + + +pyspark_core = Module( + name="pyspark-core", + dependencies=[], + source_file_regexes=[ + "python/(?!pyspark/(ml|mllib|sql|streaming))" + ], + python_test_goals=[ + "pyspark.rdd", + "pyspark.context", + "pyspark.conf", + "pyspark.broadcast", + "pyspark.accumulators", + "pyspark.serializers", + "pyspark.profiler", + "pyspark.shuffle", + "pyspark.tests", + ] +) + + +pyspark_sql = Module( + name="pyspark-sql", + dependencies=[pyspark_core, sql], + source_file_regexes=[ + "python/pyspark/sql" + ], + python_test_goals=[ + "pyspark.sql.types", + "pyspark.sql.context", + "pyspark.sql.column", + "pyspark.sql.dataframe", + "pyspark.sql.group", + "pyspark.sql.functions", + "pyspark.sql.readwriter", + "pyspark.sql.window", + "pyspark.sql.tests", + ] +) + + +pyspark_streaming = Module( + name="pyspark-streaming", + dependencies=[ + pyspark_core, + streaming, + streaming_kafka, + streaming_flume_assembly, + streaming_mqtt, + streaming_kinesis_asl + ], + source_file_regexes=[ + "python/pyspark/streaming" + ], + python_test_goals=[ + "pyspark.streaming.util", + "pyspark.streaming.tests", + ] +) + + +pyspark_mllib = Module( + name="pyspark-mllib", + dependencies=[pyspark_core, pyspark_streaming, pyspark_sql, mllib], + source_file_regexes=[ + "python/pyspark/mllib" + ], + python_test_goals=[ + "pyspark.mllib.classification", + "pyspark.mllib.clustering", + "pyspark.mllib.evaluation", + "pyspark.mllib.feature", + "pyspark.mllib.fpm", + "pyspark.mllib.linalg.__init__", + "pyspark.mllib.linalg.distributed", + "pyspark.mllib.random", + "pyspark.mllib.recommendation", + "pyspark.mllib.regression", + "pyspark.mllib.stat._statistics", + "pyspark.mllib.stat.KernelDensity", + "pyspark.mllib.tree", + "pyspark.mllib.util", + "pyspark.mllib.tests", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + + +pyspark_ml = Module( + name="pyspark-ml", + dependencies=[pyspark_core, pyspark_mllib], + source_file_regexes=[ + "python/pyspark/ml/" + ], + python_test_goals=[ + "pyspark.ml.feature", + "pyspark.ml.classification", + "pyspark.ml.clustering", + "pyspark.ml.recommendation", + "pyspark.ml.regression", + "pyspark.ml.tuning", + "pyspark.ml.tests", + "pyspark.ml.evaluation", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + +sparkr = Module( + name="sparkr", + dependencies=[sql, mllib], + source_file_regexes=[ + "R/", + ], + should_run_r_tests=True +) + + +docs = Module( + name="docs", + dependencies=[], + source_file_regexes=[ + "docs/", + ] +) + + +ec2 = Module( + name="ec2", + dependencies=[], + source_file_regexes=[ + "ec2/", + ] +) + + +# The root module is a dummy module which is used to run all of the tests. +# No other modules should directly depend on this module. 
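To see how the source_file_regexes above route a changed file, and why unmatched files fall through to the root module defined next, here is a minimal sketch; it checks only two of the regexes and ignores the fact that the real code collects the full set of matching modules.

```python
import re

# Regexes copied from the module definitions above.
PYSPARK_SQL = r"python/pyspark/sql"
PYSPARK_CORE = r"python/(?!pyspark/(ml|mllib|sql|streaming))"

def owning_module(filename):
    """Mimic Module.contains_file() for two sample modules, defaulting to root."""
    if re.match(PYSPARK_SQL, filename):
        return "pyspark-sql"
    if re.match(PYSPARK_CORE, filename):
        return "pyspark-core"
    return "root"

print(owning_module("python/pyspark/sql/functions.py"))  # pyspark-sql
print(owning_module("python/pyspark/rdd.py"))            # pyspark-core
print(owning_module("build/mvn"))                        # root
```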
+root = Module( + name="root", + dependencies=[], + source_file_regexes=[], + # In order to run all of the tests, enable every test profile: + build_profile_flags=list(set( + itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), + sbt_test_goals=[ + "test", + ], + python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)), + should_run_r_tests=True +) diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py new file mode 100644 index 000000000000..12bd0bf3a4fe --- /dev/null +++ b/dev/sparktestsupport/shellutils.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import os +import shutil +import subprocess +import sys + + +def exit_from_command_with_retcode(cmd, retcode): + print("[error] running", ' '.join(cmd), "; received return code", retcode) + sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) + + +def rm_r(path): + """ + Given an arbitrary path, properly remove it with the correct Python construct if it exists. + From: http://stackoverflow.com/a/9559881 + """ + + if os.path.isdir(path): + shutil.rmtree(path) + elif os.path.exists(path): + os.remove(path) + + +def run_cmd(cmd): + """ + Given a command as a list of arguments will attempt to execute the command + and, on failure, print an error message and exit. + """ + + if not isinstance(cmd, list): + cmd = cmd.split() + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + exit_from_command_with_retcode(e.cmd, e.returncode) + + +def is_exe(path): + """ + Check if a given path is an executable file. + From: http://stackoverflow.com/a/377028 + """ + + return os.path.isfile(path) and os.access(path, os.X_OK) + + +def which(program): + """ + Find and return the given program by its absolute path or 'None' if the program cannot be found. 
+ From: http://stackoverflow.com/a/377028 + """ + + fpath = os.path.split(program)[0] + + if fpath: + if is_exe(program): + return program + else: + for path in os.environ.get("PATH").split(os.pathsep): + path = path.strip('"') + exe_file = os.path.join(path, program) + if is_exe(exe_file): + return exe_file + return None diff --git a/docker/spark-mesos/Dockerfile b/docker/spark-mesos/Dockerfile index b90aef3655de..fb3f267fe5c7 100644 --- a/docker/spark-mesos/Dockerfile +++ b/docker/spark-mesos/Dockerfile @@ -24,7 +24,7 @@ RUN apt-get update && \ apt-get install -y python libnss3 openjdk-7-jre-headless curl RUN mkdir /opt/spark && \ - curl http://www.apache.org/dyn/closer.cgi/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ + curl http://www.apache.org/dyn/closer.lua/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ | tar -xzC /opt ENV SPARK_HOME /opt/spark ENV MESOS_NATIVE_JAVA_LIBRARY /usr/local/lib/libmesos.so diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index 5956d59130fb..5dbdb8b22a44 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -17,13 +17,13 @@ FROM ubuntu:precise -RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list - # Upgrade package index -RUN apt-get update - # install a few other useful packages plus Open Jdk 7 -RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server +# Remove unneeded /var/lib/apt/lists/* after install to reduce the +# docker image size (by ~30MB) +RUN apt-get update && \ + apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 diff --git a/docs/README.md b/docs/README.md index 5852f972a051..1f4fd3e56ed5 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,17 @@ Read on to learn more about viewing documentation in plain text (i.e., markdown) documentation yourself. Why build it yourself? So that you have the docs that corresponds to whichever version of Spark you currently have checked out of revision control. +## Prerequisites +The Spark documentation build uses a number of tools to build HTML docs and API docs in Scala, +Python and R. To get started you can run the following commands + + $ sudo gem install jekyll + $ sudo gem install jekyll-redirect-from + $ sudo pip install Pygments + $ sudo pip install sphinx + $ Rscript -e 'install.packages(c("knitr", "devtools"), repos="http://cran.stat.ucla.edu/")' + + ## Generating the Documentation HTML We include the Spark documentation as part of the source (as opposed to using a hosted wiki, such as @@ -19,17 +30,12 @@ you have checked out or downloaded. In this directory you will find textfiles formatted using Markdown, with an ".md" suffix. You can read those text files directly if you want. Start with index.md. -The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com). -`Jekyll` and a few dependencies must be installed for this to work. We recommend -installing via the Ruby Gem dependency manager. Since the exact HTML output -varies between versions of Jekyll and its dependencies, we list specific versions here -in some cases: - - $ sudo gem install jekyll - $ sudo gem install jekyll-redirect-from +Execute `jekyll build` from the `docs/` directory to compile the site. Compiling the site with +Jekyll will create a directory called `_site` containing index.html as well as the rest of the +compiled files. 
-Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory -called `_site` containing index.html as well as the rest of the compiled files. + $ cd docs + $ jekyll build You can modify the default Jekyll build as follows: @@ -40,29 +46,6 @@ You can modify the default Jekyll build as follows: # Build the site with extra features used on the live page $ PRODUCTION=1 jekyll build -## Pygments - -We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, -so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. - -To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile -phase, use the following sytax: - - {% highlight scala %} - // Your scala code goes here, you can replace scala with many other - // supported languages too. - {% endhighlight %} - -## Sphinx - -We use Sphinx to generate Python API docs, so you will need to install it by running -`sudo pip install sphinx`. - -## knitr, devtools - -SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate -documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a -R console. ## API Docs (Scaladoc, Sphinx, roxygen2) diff --git a/docs/_config.yml b/docs/_config.yml index c0e031a83ba9..c59cc465ef89 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.5.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.5.0 +SPARK_VERSION: 1.6.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.6.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 6073b3626c45..15ceda11a8a8 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,6 +63,51 @@ puts "cp -r " + source + "/. 
" + dest cp_r(source + "/.", dest) + + # Begin updating JavaDoc files for badge post-processing + puts "Updating JavaDoc files for badge post-processing" + js_script_start = '' + + javadoc_files = Dir["./" + dest + "/**/*.html"] + javadoc_files.each do |javadoc_file| + # Determine file depths to reference js files + slash_count = javadoc_file.count "/" + i = 3 + path_to_js_file = "" + while (i < slash_count) do + path_to_js_file = path_to_js_file + "../" + i += 1 + end + + # Create script elements to reference js files + javadoc_jquery_script = js_script_start + path_to_js_file + "lib/jquery" + js_script_end; + javadoc_api_docs_script = js_script_start + path_to_js_file + "lib/api-javadocs" + js_script_end; + javadoc_script_elements = javadoc_jquery_script + javadoc_api_docs_script + + # Add script elements to JavaDoc files + javadoc_file_content = File.open(javadoc_file, "r") { |f| f.read } + javadoc_file_content = javadoc_file_content.sub("", javadoc_script_elements + "") + File.open(javadoc_file, "w") { |f| f.puts(javadoc_file_content) } + + end + # End updating JavaDoc files for badge post-processing + + puts "Copying jquery.js from Scala API to Java API for page post-processing of badges" + jquery_src_file = "./api/scala/lib/jquery.js" + jquery_dest_file = "./api/java/lib/jquery.js" + mkdir_p("./api/java/lib") + cp(jquery_src_file, jquery_dest_file) + + puts "Copying api_javadocs.js to Java API for page post-processing of badges" + api_javadocs_src_file = "./js/api-javadocs.js" + api_javadocs_dest_file = "./api/java/lib/api-javadocs.js" + cp(api_javadocs_src_file, api_javadocs_dest_file) + + puts "Appending content of api-javadocs.css to JavaDoc stylesheet.css for badge styles" + css = File.readlines("./css/api-javadocs.css") + css_file = dest + "/stylesheet.css" + File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } end # Build Sphinx docs for Python diff --git a/docs/api.md b/docs/api.md index 45df77ac05f7..ae7d51c2aefb 100644 --- a/docs/api.md +++ b/docs/api.md @@ -3,7 +3,7 @@ layout: global title: Spark API Documentation --- -Here you can API docs for Spark and its submodules. +Here you can read API docs for Spark and its submodules. - [Spark Scala API (Scaladoc)](api/scala/index.html) - [Spark Java API (Javadoc)](api/java/index.html) diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index c2fe6b0e286c..347ca4a7af98 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -4,7 +4,7 @@ displayTitle: Bagel Programming Guide title: Bagel --- -**Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** +**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. @@ -157,11 +157,3 @@ trait Message[K] { def targetId: K } {% endhighlight %} - -# Where to Go from Here - -Two example jobs, PageRank and shortest path, are included in `examples/src/main/scala/org/apache/spark/examples/bagel`. You can run them by passing the class name to the `bin/run-example` script included in Spark; e.g.: - - ./bin/run-example org.apache.spark.examples.bagel.WikipediaPageRank - -Each example program prints usage help when run without any arguments. 
diff --git a/docs/building-spark.md b/docs/building-spark.md index 2128fdffecc0..4db32cfd628b 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -7,7 +7,8 @@ redirect_from: "building-with-maven.html" * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven requires Maven 3.0.4 or newer and Java 7+. +Building Spark using Maven requires Maven 3.3.3 or newer and Java 7+. +The Spark build can supply a suitable Maven binary; see below. # Building with `build/mvn` @@ -60,12 +61,13 @@ If you don't run this, you may see errors like the following: You can fix this by setting the `MAVEN_OPTS` variable as discussed before. **Note:** -* *For Java 8 and above this step is not required.* -* *If using `build/mvn` and `MAVEN_OPTS` were not already set, the script will automate this for you.* + +* For Java 8 and above this step is not required. +* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automate this for you. # Specifying the Hadoop Version -Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: +Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the `hadoop.version` property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: @@ -90,7 +92,7 @@ mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package {% endhighlight %} -You can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". Spark only supports YARN versions 2.2.0 and later. +You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. Examples: @@ -124,7 +126,7 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Dskip # Building for Scala 2.11 To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: - dev/change-version-to-2.11.sh + ./dev/change-scala-version.sh 2.11 mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package Spark does not yet support its JDBC component for Scala 2.11. @@ -162,11 +164,9 @@ the `spark-parent` module). Thus, the full flow for running continuous-compilation of the `core` submodule may look more like: -``` - $ mvn install - $ cd core - $ mvn scala:cc -``` + $ mvn install + $ cd core + $ mvn scala:cc # Building Spark with IntelliJ IDEA or Eclipse @@ -192,11 +192,11 @@ then ship it over to the cluster. We are investigating the exact cause for this. # Packaging without Hadoop Dependencies for YARN -The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. 
The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. # Building with SBT -Maven is the official recommendation for packaging Spark, and is the "build of reference". +Maven is the official build tool recommended for packaging Spark, and is the *build of reference*. But SBT is supported for day-to-day development since it can provide much faster iterative compilation. More advanced developers may wish to use SBT. diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 7079de546e2f..faaf154d243f 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -5,18 +5,19 @@ title: Cluster Mode Overview This document gives a short overview of how Spark runs on clusters, to make it easier to understand the components involved. Read through the [application submission guide](submitting-applications.html) -to submit applications to a cluster. +to learn about launching applications on a cluster. # Components -Spark applications run as independent sets of processes on a cluster, coordinated by the SparkContext +Spark applications run as independent sets of processes on a cluster, coordinated by the `SparkContext` object in your main program (called the _driver program_). + Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_ -(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across +(either Spark's own standalone cluster manager, Mesos or YARN), which allocate resources across applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are processes that run computations and store data for your application. Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to -the executors. Finally, SparkContext sends *tasks* for the executors to run. +the executors. Finally, SparkContext sends *tasks* to the executors to run.

    Spark cluster components @@ -33,9 +34,9 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. The driver program must listen for and accept incoming connections from its executors throughout - its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config - section](configuration.html#networking)). As such, the driver program must be network +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network addressable from the worker nodes. 4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the diff --git a/docs/configuration.md b/docs/configuration.md index affcd21514d8..1a701f18881f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -31,7 +31,6 @@ which can help detect bugs that only exist when we run in a distributed context. val conf = new SparkConf() .setMaster("local[2]") .setAppName("CountingSheep") - .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} @@ -84,7 +83,7 @@ Running `./bin/spark-submit --help` will show the entire list of these options. each line consists of a key and a value separated by whitespace. For example: spark.master spark://5.6.7.8:7077 - spark.executor.memory 512m + spark.executor.memory 4g spark.eventLog.enabled true spark.serializer org.apache.spark.serializer.KryoSerializer @@ -137,10 +136,10 @@ of the most common options to set are:

    - + - + @@ -205,7 +203,7 @@ Apart from these, the following properties are also available, and may be useful @@ -384,16 +382,6 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. - - - - - @@ -459,9 +447,12 @@ Apart from these, the following properties are also available, and may be useful @@ -475,6 +466,25 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryFraction. + + + + + + + + + + @@ -559,6 +569,20 @@ Apart from these, the following properties are also available, and may be useful collecting. + + + + + + + + + +
spark.driver.memory - 512m + 1g Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 512m, 2g). + (e.g. 1g, 2g).
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -150,10 +149,9 @@ of the most common options to set are:
spark.executor.memory - 512m + 1g - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). + Amount of memory to use per executor process (e.g. 2g, 8g).
    spark.driver.extraClassPath (none) - Extra classpath entries to append to the classpath of the driver. + Extra classpath entries to prepend to the classpath of the driver.
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -252,7 +250,7 @@ Apart from these, the following properties are also available, and may be useful
    spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily for + Extra classpath entries to prepend to the classpath of executors. This exists primarily for backwards-compatibility with older versions of Spark. Users typically should not need to set this option.
spark.shuffle.blockTransferService netty - Implementation to use for transferring shuffle and cached blocks between executors. There - are two implementations available: netty and nio. Netty-based - block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. -
spark.shuffle.compress true
spark.shuffle.manager sort - Implementation to use for shuffling data. There are two implementations available: - sort and hash. Sort-based shuffle is more memory-efficient and is - the default option starting in 1.2. + Implementation to use for shuffling data. There are three implementations available: + sort, hash and the new (1.5+) tungsten-sort. + Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. + Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly + implementation with a fall back to regular sort based shuffle if its requirements are not + met.
spark.shuffle.service.enabled false + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if + spark.dynamicAllocation.enabled is "true". The external shuffle service + must be set up in order to enable it. See + dynamic allocation + configuration and setup documentation for more information. +
spark.shuffle.service.port 7337 + Port on which the external shuffle service will run. +
    spark.shuffle.sort.bypassMergeThreshold 200
spark.worker.ui.retainedExecutors 1000 + How many finished executors the Spark UI and status APIs remember before garbage collecting. +
spark.worker.ui.retainedDrivers 1000 + How many finished drivers the Spark UI and status APIs remember before garbage collecting. +
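As a hedged illustration of the properties documented above (the values and application name are arbitrary, and note that `spark.driver.memory` must be set via `spark-submit` or `spark-defaults.conf` rather than in `SparkConf` in client mode):

```python
# Illustrative only: set a few of the documented properties from PySpark.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("ConfigExample")                      # arbitrary app name
        .set("spark.executor.memory", "2g")               # per-executor memory
        .set("spark.shuffle.service.enabled", "true")     # external shuffle service
        .set("spark.dynamicAllocation.enabled", "true"))  # requires the shuffle service
sc = SparkContext(conf=conf)
```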
    #### Compression and Serialization @@ -665,7 +689,7 @@ Apart from these, the following properties are also available, and may be useful Initial size of Kryo's serialization buffer. Note that there will be one buffer per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max.mb if needed. + spark.kryoserializer.buffer.max if needed. @@ -874,23 +898,13 @@ Apart from these, the following properties are also available, and may be useful #### Networking - - - - - - + @@ -993,7 +1007,11 @@ Apart from these, the following properties are also available, and may be useful @@ -1007,9 +1025,9 @@ Apart from these, the following properties are also available, and may be useful + @@ -1029,8 +1047,8 @@ Apart from these, the following properties are also available, and may be useful - Duration for an RPC remote endpoint lookup operation to wait before timing out.
Property Name Default Meaning
spark.akka.failure-detector.threshold 300.0 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). This maps to akka's - `akka.remote.transport-failure-detector.threshold`. Tune this in combination of - `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to. -
spark.akka.frameSize - 10 + 128 - Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the driver - (e.g. using collect() on a large dataset). + Maximum message size to allow in "control plane" communication; generally only applies to map + output size information sent between executors and the driver. Increase this if you are running + jobs with many thousands of map and reduce tasks and see messages about the frame size.
    spark.port.maxRetries 16 - Default maximum number of retries when binding to a port before giving up. + Maximum number of retries when binding to a port before giving up. + When a port is given a specific value (non 0), each subsequent retry will + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified + to port + maxRetries.
    spark.rpc.numRetries 3 Number of times to retry before an RPC task gives up. An RPC task will run at most times of this number. -
    spark.rpc.lookupTimeout 120s + Duration for an RPC remote endpoint lookup operation to wait before timing out.
    @@ -1050,15 +1068,6 @@ Apart from these, the following properties are also available, and may be useful infinite (all available cores) on Mesos. - - spark.localExecution.enabled - false - - Enables Spark to run certain jobs, such as first() or take() on the driver, without sending - tasks to the cluster. This can make certain jobs execute very quickly, but may require - shipping a whole partition of data to the driver. - - spark.locality.wait 3s @@ -1103,10 +1112,11 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.minRegisteredResourcesRatio - 0.8 for YARN mode; 0.0 otherwise + 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode, CPU cores in standalone mode) + (resources are executors in yarn mode, CPU cores in standalone mode and Mesos coarsed-grained + mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config @@ -1206,7 +1216,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.cachedExecutorIdleTimeout - 2 * executorIdleTimeout + infinity If dynamic allocation is enabled and an executor which has cached data blocks has been idle for more than this duration, the executor will be removed. For more details, see this @@ -1222,7 +1232,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.maxExecutors - Integer.MAX_VALUE + infinity Upper bound for the number of executors if dynamic allocation is enabled. @@ -1273,7 +1283,8 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. + help debug when things work. Putting a "*" in the list means any user can have the priviledge + of admin. @@ -1314,7 +1325,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). + user that started the Spark job has access to modify it (kill it for example). Putting a "*" in + the list means any user can have access to modify it. @@ -1336,7 +1348,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. + user that started the Spark job has view access. Putting a "*" in the list means any user can + have view access to this Spark job. 
@@ -1424,6 +1437,19 @@ Apart from these, the following properties are also available, and may be useful #### Spark Streaming + + + + + @@ -1538,7 +1564,11 @@ The following variables can be set in `spark-env.sh`: - + + + + + diff --git a/docs/css/api-docs.css b/docs/css/api-docs.css index b2d1d7f86979..7cf222aad24f 100644 --- a/docs/css/api-docs.css +++ b/docs/css/api-docs.css @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /* Dynamically injected style for the API docs */ .developer { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/docs/css/api-javadocs.css similarity index 52% rename from core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala rename to docs/css/api-javadocs.css index b3b281ff465f..832e92609e01 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ b/docs/css/api-javadocs.css @@ -15,22 +15,38 @@ * limitations under the License. */ -package org.apache.spark.network.nio +/* Dynamically injected style for the API docs */ -private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString: String = { - connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId - } +.badge { + font-family: Arial, san-serif; + float: right; + margin: 4px; + /* The following declarations are taken from the ScalaDoc template.css */ + display: inline-block; + padding: 2px 4px; + font-size: 11.844px; + font-weight: bold; + line-height: 14px; + color: #ffffff; + text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25); + white-space: nowrap; + vertical-align: baseline; + background-color: #999999; + padding-right: 9px; + padding-left: 9px; + -webkit-border-radius: 9px; + -moz-border-radius: 9px; + border-radius: 9px; } -private[nio] object ConnectionId { +.developer { + background-color: #44751E; +} + +.experimental { + background-color: #257080; +} - def createConnectionIdFromString(connectionIdString: String): ConnectionId = { - val res = connectionIdString.split("_").map(_.trim()) - if (res.size != 3) { - throw new Exception("Error converting ConnectionId string: " + connectionIdString + - " to a ConnectionId Object") - } - new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) - } +.alphaComponent { + background-color: #bb0000; } diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 3f10cb2dc3d2..c861a763d622 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -768,16 +768,14 @@ class GraphOps[VD, ED] { // Loop until no messages remain or maxIterations is achieved var i = 0 while (activeMessages > 0 && i < maxIterations) { - // Receive the messages: 
----------------------------------------------------------------------- - // Run the vertex program on all vertices that receive messages - val newVerts = g.vertices.innerJoin(messages)(vprog).cache() - // Merge the new vertex values back into the graph - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache() - // Send Messages: ------------------------------------------------------------------------------ - // Vertices that didn't receive a message above don't appear in newVerts and therefore don't - // get to send messages. More precisely the map phase of mapReduceTriplets is only invoked - // on edges in the activeDir of vertices in newVerts - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDir))).cache() + // Receive the messages and update the vertices. + g = g.joinVertices(messages)(vprog).cache() + val oldMessages = messages + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } @@ -800,7 +798,7 @@ import org.apache.spark.graphx._ // Import random graph generation library import org.apache.spark.graphx.util.GraphGenerators // A graph with edge attributes containing distances -val graph: Graph[Int, Double] = +val graph: Graph[Long, Double] = GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble) val sourceId: VertexId = 42 // The ultimate source // Initialize the graph such that all vertices except the root have distance infinity. diff --git a/docs/index.md b/docs/index.md index d85cf12defef..c0dc2b8d7412 100644 --- a/docs/index.md +++ b/docs/index.md @@ -90,7 +90,6 @@ options for deployment: * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing - * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model **API Docs:** diff --git a/docs/js/api-javadocs.js b/docs/js/api-javadocs.js new file mode 100644 index 000000000000..ead13d6e5fa7 --- /dev/null +++ b/docs/js/api-javadocs.js @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* Dynamically injected post-processing code for the API docs */ + +$(document).ready(function() { + addBadges(":: AlphaComponent ::", 'Alpha Component'); + addBadges(":: DeveloperApi ::", 'Developer API'); + addBadges(":: Experimental ::", 'Experimental'); +}); + +function addBadges(tag, html) { + var tags = $(".block:contains(" + tag + ")") + + // Remove identifier tags + tags.each(function(index) { + var oldHTML = $(this).html(); + var newHTML = oldHTML.replace(tag, ""); + $(this).html(newHTML); + }); + + // Add html badge tags + tags.each(function(index) { + if ($(this).parent().is('td.colLast')) { + $(this).parent().prepend(html); + } else if ($(this).parent('li.blockList') + .parent('ul.blockList') + .parent('div.description') + .parent().is('div.contentContainer')) { + var contentContainer = $(this).parent('li.blockList') + .parent('ul.blockList') + .parent('div.description') + .parent('div.contentContainer') + var header = contentContainer.prev('div.header'); + if (header.length > 0) { + header.prepend(html); + } else { + contentContainer.prepend(html); + } + } else if ($(this).parent().is('li.blockList')) { + $(this).parent().prepend(html); + } else { + $(this).prepend(html); + } + }); +} diff --git a/docs/ml-ann.md b/docs/ml-ann.md new file mode 100644 index 000000000000..d5ddd92af1e9 --- /dev/null +++ b/docs/ml-ann.md @@ -0,0 +1,123 @@ +--- +layout: global +title: Multilayer perceptron classifier - ML +displayTitle: ML - Multilayer perceptron classifier +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs +by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +It can be written in matrix form for MLPC with `$K+1$` layers as follows: +`\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]` +Nodes in intermediate layers use sigmoid (logistic) function: +`\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]` +Nodes in the output layer use softmax function: +`\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]` +The number of nodes `$N$` in the output layer corresponds to the number of classes. + +MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. + +**Examples** + +
    + +
    + +{% highlight scala %} +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.Row + +// Load training data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt").toDF() +// Split the data into train and test +val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) +val train = splits(0) +val test = splits(1) +// specify layers for the neural network: +// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) +val layers = Array[Int](4, 5, 4, 3) +// create the trainer and set its parameters +val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100) +// train the model +val model = trainer.fit(train) +// compute precision on the test set +val result = model.transform(test) +val predictionAndLabels = result.select("prediction", "label") +val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision") +println("Precision:" + evaluator.evaluate(predictionAndLabels)) +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +// Load training data +String path = "data/mllib/sample_multiclass_classification_data.txt"; +JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); +DataFrame dataFrame = sqlContext.createDataFrame(data, LabeledPoint.class); +// Split the data into train and test +DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); +DataFrame train = splits[0]; +DataFrame test = splits[1]; +// specify layers for the neural network: +// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) +int[] layers = new int[] {4, 5, 4, 3}; +// create the trainer and set its parameters +MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100); +// train the model +MultilayerPerceptronClassificationModel model = trainer.fit(train); +// compute precision on the test set +DataFrame result = model.transform(test); +DataFrame predictionAndLabels = result.select("prediction", "label"); +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); +System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); +{% endhighlight %} +
    + +
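For reference, a rough PySpark equivalent of the Scala/Java multilayer perceptron examples in this section is sketched below. It assumes `sc` and an active `SQLContext` exist (as the other examples do) and that `MultilayerPerceptronClassifier` is available in the Python API for your Spark version, which may not be the case for older releases.

```python
# Hedged sketch: multilayer perceptron classification in PySpark.
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.util import MLUtils

data = MLUtils.loadLibSVMFile(
    sc, "data/mllib/sample_multiclass_classification_data.txt").toDF()
train, test = data.randomSplit([0.6, 0.4], seed=1234)
# 4 input features, two hidden layers of 5 and 4 nodes, 3 output classes.
trainer = MultilayerPerceptronClassifier(
    layers=[4, 5, 4, 3], blockSize=128, seed=1234, maxIter=100)
model = trainer.fit(train)
result = model.transform(test)
evaluator = MulticlassClassificationEvaluator(metricName="precision")
print("Precision: " + str(evaluator.evaluate(result.select("prediction", "label"))))
```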
    diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md new file mode 100644 index 000000000000..542819e93e6d --- /dev/null +++ b/docs/ml-decision-tree.md @@ -0,0 +1,493 @@ +--- +layout: global +title: Decision Trees - SparkML +displayTitle: ML - Decision Trees +--- + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + + +# Overview + +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks. + +MLlib supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances. + +Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). In this section, we demonstrate the Pipelines API for Decision Trees. + +The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). + +Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described in the [Ensembles guide](ml-ensembles.html). + +# Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +## Input Columns + +
    Property NameDefaultMeaning
    spark.streaming.backpressure.enabledfalse + Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). + This enables the Spark Streaming to control the receiving rate based on the + current batch scheduling delays and processing times so that the system receives + only as fast as the system can process. Internally, this dynamically sets the + maximum receiving rate of receivers. This rate is upper bounded by the values + `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition` + if they are set (see below). +
    spark.streaming.blockInterval 200ms
    PYSPARK_PYTHONPython binary executable to use for PySpark.Python binary executable to use for PySpark in both driver and workers (default is `python`).
    PYSPARK_DRIVER_PYTHONPython binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON).
    SPARK_LOCAL_IP
    + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +## Output Columns + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    + +# Examples + +The below examples demonstrate the Pipelines API for Decision Trees. The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: + +* support for ML Pipelines +* separation of Decision Trees for classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features + + +## Classification + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) +// Automatically identify categorical features, and index them. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a DecisionTree model. +val dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + +// Convert indexed labels back to original labels. +val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + +// Chain indexers and tree in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) + +// Train model. This also runs the indexers. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") +val accuracy = evaluator.evaluate(predictions) +println("Test Error = " + (1.0 - accuracy)) + +val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] +println("Learned classification tree model:\n" + treeModel.toDebugString) +{% endhighlight %} +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.DecisionTreeClassifier; +import org.apache.spark.ml.classification.DecisionTreeClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); +// Automatically identify categorical features, and index them. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a DecisionTree model. +DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + +// Convert indexed labels back to original labels. +IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + +// Chain indexers and tree in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, dt, labelConverter}); + +// Train model. This also runs the indexers. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); +double accuracy = evaluator.evaluate(predictions); +System.out.println("Test Error = " + (1.0 - accuracy)); + +DecisionTreeClassificationModel treeModel = + (DecisionTreeClassificationModel)(model.stages()[2]); +System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); +{% endhighlight %} +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import DecisionTreeClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Index labels, adding metadata to the label column. +# Fit on whole dataset to include all labels in index. +labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) +# Automatically identify categorical features, and index them. +# We specify maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a DecisionTree model. +dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + +# Chain indexers and tree in a Pipeline +pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + +# Train model. This also runs the indexers. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "indexedLabel", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") +accuracy = evaluator.evaluate(predictions) +print "Test Error = %g" % (1.0 - accuracy) + +treeModel = model.stages[2] +print treeModel # summary only +{% endhighlight %} +
    + +
    + + +## Regression + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Automatically identify categorical features, and index them. +// Here, we treat features with > 4 distinct values as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a DecisionTree model. +val dt = new DecisionTreeRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + +// Chain indexer and tree in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(featureIndexer, dt)) + +// Train model. This also runs the indexer. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") +val rmse = evaluator.evaluate(predictions) +println("Root Mean Squared Error (RMSE) on test data = " + rmse) + +val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] +println("Learned regression tree model:\n" + treeModel.toDebugString) +{% endhighlight %} +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a DecisionTree model. +DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures"); + +// Chain indexer and tree in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, dt}); + +// Train model. This also runs the indexer. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("label", "features").show(5); + +// Select (prediction, true label) and compute test error +RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); +double rmse = evaluator.evaluate(predictions); +System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + +DecisionTreeRegressionModel treeModel = + (DecisionTreeRegressionModel)(model.stages()[1]); +System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); +{% endhighlight %} +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.regression import DecisionTreeRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Automatically identify categorical features, and index them. +# We specify maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a DecisionTree model. +dt = DecisionTreeRegressor(featuresCol="indexedFeatures") + +# Chain indexer and tree in a Pipeline +pipeline = Pipeline(stages=[featureIndexer, dt]) + +# Train model. This also runs the indexer. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") +rmse = evaluator.evaluate(predictions) +print "Root Mean Squared Error (RMSE) on test data = %g" % rmse + +treeModel = model.stages[1] +print treeModel # summary only +{% endhighlight %} +
    + +
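Beyond `toDebugString`, the size of the fitted tree can be checked programmatically. The Scala sketch below is illustrative only; it assumes a `PipelineModel` named `model` fitted as in the regression examples above, with the decision tree as the second pipeline stage:

{% highlight scala %}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel

// `model` is assumed to be the PipelineModel returned by pipeline.fit(trainingData) above.
val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel]
println("Tree depth: " + treeModel.depth + ", number of nodes: " + treeModel.numNodes)
{% endhighlight %}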
    diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 9ff50e95fc47..58f566c9b4b5 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -11,11 +11,925 @@ displayTitle: ML - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -The Pipelines API supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) -## OneVsRest +## Tree Ensembles -[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. +The Pipelines API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). +Both use [MLlib decision trees](ml-decision-tree.html) as their base models. + +Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. + +The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: +* support for ML Pipelines +* separation of classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features +* a bit more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. + +### Random Forests + +[Random forests](http://en.wikipedia.org/wiki/Random_forest) +are ensembles of [decision trees](ml-decision-tree.html). +Random forests combine many decision trees in order to reduce the risk of overfitting. +MLlib supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features. + +This section gives examples of using random forests with the Pipelines API. +For more information on the algorithm, please see the [main MLlib docs on random forests](mllib-ensembles.html). + +#### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +##### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
Param name  | Type(s) | Default    | Description
------------|---------|------------|------------------
labelCol    | Double  | "label"    | Label to predict
featuresCol | Vector  | "features" | Feature vector
+
+##### Output Columns (Predictions)
+
Param name       | Type(s) | Default         | Description | Notes
-----------------|---------|-----------------|-------------|------
predictionCol    | Double  | "prediction"    | Predicted label |
rawPredictionCol | Vector  | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only
probabilityCol   | Vector  | "probability"   | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only
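To make the table above concrete, the output columns can be inspected directly on the transformed DataFrame. This is a minimal Scala sketch, assuming a `predictions` DataFrame produced by `model.transform(testData)` as in the classification example below:

{% highlight scala %}
// `predictions` is assumed to come from model.transform(testData) in the example below.
// rawPrediction holds per-class counts from the trees; probability is the normalized form.
predictions.select("prediction", "rawPrediction", "probability").show(5)
{% endhighlight %}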
    + +#### Example: Classification + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.RandomForestClassifier +import org.apache.spark.ml.classification.RandomForestClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator + +// Load and parse the data file, converting it to a DataFrame. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a RandomForest model. +val rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setNumTrees(10) + +// Convert indexed labels back to original labels. +val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + +// Chain indexers and forest in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) + +// Train model. This also runs the indexers. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") +val accuracy = evaluator.evaluate(predictions) +println("Test Error = " + (1.0 - accuracy)) + +val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] +println("Learned classification forest model:\n" + rfModel.toDebugString) +{% endhighlight %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.RandomForestClassifier; +import org.apache.spark.ml.classification.RandomForestClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +DataFrame data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a RandomForest model. +RandomForestClassifier rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + +// Convert indexed labels back to original labels. +IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + +// Chain indexers and forest in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter}); + +// Train model. This also runs the indexers. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); +double accuracy = evaluator.evaluate(predictions); +System.out.println("Test Error = " + (1.0 - accuracy)); + +RandomForestClassificationModel rfModel = + (RandomForestClassificationModel)(model.stages()[2]); +System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); +{% endhighlight %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator + +# Load and parse the data file, converting it to a DataFrame. +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +# Index labels, adding metadata to the label column. +# Fit on whole dataset to include all labels in index. +labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a RandomForest model. +rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + +# Chain indexers and forest in a Pipeline +pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) + +# Train model. This also runs the indexers. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "indexedLabel", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") +accuracy = evaluator.evaluate(predictions) +print "Test Error = %g" % (1.0 - accuracy) + +rfModel = model.stages[2] +print rfModel # summary only +{% endhighlight %} +
    +
    + +#### Example: Regression + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.RandomForestRegressor +import org.apache.spark.ml.regression.RandomForestRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator + +// Load and parse the data file, converting it to a DataFrame. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a RandomForest model. +val rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + +// Chain indexer and forest in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(featureIndexer, rf)) + +// Train model. This also runs the indexer. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") +val rmse = evaluator.evaluate(predictions) +println("Root Mean Squared Error (RMSE) on test data = " + rmse) + +val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] +println("Learned regression forest model:\n" + rfModel.toDebugString) +{% endhighlight %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.RandomForestRegressionModel; +import org.apache.spark.ml.regression.RandomForestRegressor; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +DataFrame data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a RandomForest model. +RandomForestRegressor rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures"); + +// Chain indexer and forest in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, rf}); + +// Train model. This also runs the indexer. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); +double rmse = evaluator.evaluate(predictions); +System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + +RandomForestRegressionModel rfModel = + (RandomForestRegressionModel)(model.stages()[1]); +System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); +{% endhighlight %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.regression import RandomForestRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator + +# Load and parse the data file, converting it to a DataFrame. +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a RandomForest model. +rf = RandomForestRegressor(featuresCol="indexedFeatures") + +# Chain indexer and forest in a Pipeline +pipeline = Pipeline(stages=[featureIndexer, rf]) + +# Train model. This also runs the indexer. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") +rmse = evaluator.evaluate(predictions) +print "Root Mean Squared Error (RMSE) on test data = %g" % rmse + +rfModel = model.stages[1] +print rfModel # summary only +{% endhighlight %} +
    +
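The random forest examples above rely mostly on default settings. The following is an illustrative Scala sketch (not a recommendation) of the parameters that are most commonly tuned; the column names match the classification example above:

{% highlight scala %}
import org.apache.spark.ml.classification.RandomForestClassifier

// Illustrative values only; see the API docs for the full parameter list.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(50)                   // more trees usually reduce variance
  .setMaxDepth(5)                    // depth of each individual tree
  .setFeatureSubsetStrategy("auto")  // features considered per split
  .setSubsamplingRate(0.8)           // fraction of data used per tree
{% endhighlight %}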
    + +### Gradient-Boosted Trees (GBTs) + +[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) +are ensembles of [decision trees](ml-decision-tree.html). +GBTs iteratively train decision trees in order to minimize a loss function. +MLlib supports GBTs for binary classification and for regression, +using both continuous and categorical features. + +This section gives examples of using GBTs with the Pipelines API. +For more information on the algorithm, please see the [main MLlib docs on GBTs](mllib-ensembles.html). + +#### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +##### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
Param name  | Type(s) | Default    | Description
------------|---------|------------|------------------
labelCol    | Double  | "label"    | Label to predict
featuresCol | Vector  | "features" | Feature vector
+
+Note that `GBTClassifier` currently only supports binary labels.
+
+##### Output Columns (Predictions)
+
Param name    | Type(s) | Default      | Description     | Notes
--------------|---------|--------------|-----------------|------
predictionCol | Double  | "prediction" | Predicted label |
    + +In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. + +#### Example: Classification + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.GBTClassifier +import org.apache.spark.ml.classification.GBTClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator + +// Load and parse the data file, converting it to a DataFrame. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a GBT model. +val gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + +// Convert indexed labels back to original labels. +val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + +// Chain indexers and GBT in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) + +// Train model. This also runs the indexers. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") +val accuracy = evaluator.evaluate(predictions) +println("Test Error = " + (1.0 - accuracy)) + +val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] +println("Learned classification GBT model:\n" + gbtModel.toDebugString) +{% endhighlight %} +
    + +
+
+Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details.
+
+{% highlight java %}
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.classification.GBTClassifier;
+import org.apache.spark.ml.classification.GBTClassificationModel;
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
+import org.apache.spark.ml.feature.*;
+import org.apache.spark.sql.DataFrame;
+
+// Load and parse the data file, converting it to a DataFrame.
+DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+
+// Index labels, adding metadata to the label column.
+// Fit on whole dataset to include all labels in index.
+StringIndexerModel labelIndexer = new StringIndexer()
+  .setInputCol("label")
+  .setOutputCol("indexedLabel")
+  .fit(data);
+// Automatically identify categorical features, and index them.
+// Set maxCategories so features with > 4 distinct values are treated as continuous.
+VectorIndexerModel featureIndexer = new VectorIndexer()
+  .setInputCol("features")
+  .setOutputCol("indexedFeatures")
+  .setMaxCategories(4)
+  .fit(data);
+
+// Split the data into training and test sets (30% held out for testing)
+DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
+DataFrame trainingData = splits[0];
+DataFrame testData = splits[1];
+
+// Train a GBT model.
+GBTClassifier gbt = new GBTClassifier()
+  .setLabelCol("indexedLabel")
+  .setFeaturesCol("indexedFeatures")
+  .setMaxIter(10);
+
+// Convert indexed labels back to original labels.
+IndexToString labelConverter = new IndexToString()
+  .setInputCol("prediction")
+  .setOutputCol("predictedLabel")
+  .setLabels(labelIndexer.labels());
+
+// Chain indexers and GBT in a Pipeline
+Pipeline pipeline = new Pipeline()
+  .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter});
+
+// Train model. This also runs the indexers.
+PipelineModel model = pipeline.fit(trainingData);
+
+// Make predictions.
+DataFrame predictions = model.transform(testData);
+
+// Select example rows to display.
+predictions.select("predictedLabel", "label", "features").show(5);
+
+// Select (prediction, true label) and compute test error
+MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
+  .setLabelCol("indexedLabel")
+  .setPredictionCol("prediction")
+  .setMetricName("precision");
+double accuracy = evaluator.evaluate(predictions);
+System.out.println("Test Error = " + (1.0 - accuracy));
+
+GBTClassificationModel gbtModel =
+  (GBTClassificationModel)(model.stages()[2]);
+System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());
+{% endhighlight %}
+
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator + +# Load and parse the data file, converting it to a DataFrame. +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +# Index labels, adding metadata to the label column. +# Fit on whole dataset to include all labels in index. +labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GBT model. +gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10) + +# Chain indexers and GBT in a Pipeline +pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt]) + +# Train model. This also runs the indexers. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "indexedLabel", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") +accuracy = evaluator.evaluate(predictions) +print "Test Error = %g" % (1.0 - accuracy) + +gbtModel = model.stages[2] +print gbtModel # summary only +{% endhighlight %} +
    +
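Similarly, the GBT examples above use near-default settings. A minimal, illustrative Scala sketch of the commonly tuned GBT parameters (the values are placeholders, not recommendations):

{% highlight scala %}
import org.apache.spark.ml.classification.GBTClassifier

// Illustrative values only.
val gbt = new GBTClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(20)           // number of boosting iterations (trees)
  .setMaxDepth(5)           // depth of each tree
  .setStepSize(0.1)         // a.k.a. learning rate, in (0, 1]
  .setLossType("logistic")  // loss to minimize; "logistic" for classification
{% endhighlight %}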
    + +#### Example: Regression + +Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not +be true in general. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.GBTRegressor +import org.apache.spark.ml.regression.GBTRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator + +// Load and parse the data file, converting it to a DataFrame. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a GBT model. +val gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + +// Chain indexer and GBT in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(featureIndexer, gbt)) + +// Train model. This also runs the indexer. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") +val rmse = evaluator.evaluate(predictions) +println("Root Mean Squared Error (RMSE) on test data = " + rmse) + +val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] +println("Learned regression GBT model:\n" + gbtModel.toDebugString) +{% endhighlight %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.GBTRegressionModel; +import org.apache.spark.ml.regression.GBTRegressor; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a GBT model. +GBTRegressor gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + +// Chain indexer and GBT in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, gbt}); + +// Train model. This also runs the indexer. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); +double rmse = evaluator.evaluate(predictions); +System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + +GBTRegressionModel gbtModel = + (GBTRegressionModel)(model.stages()[1]); +System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); +{% endhighlight %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.regression import GBTRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator + +# Load and parse the data file, converting it to a DataFrame. +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GBT model. +gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10) + +# Chain indexer and GBT in a Pipeline +pipeline = Pipeline(stages=[featureIndexer, gbt]) + +# Train model. This also runs the indexer. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") +rmse = evaluator.evaluate(predictions) +print "Root Mean Squared Error (RMSE) on test data = %g" % rmse + +gbtModel = model.stages[1] +print gbtModel # summary only +{% endhighlight %} +
    +
    + + +## One-vs-Rest (a.k.a. One-vs-All) + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." `OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. @@ -28,18 +942,20 @@ The example below demonstrates how to load the
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. + {% highlight scala %} import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Row, SQLContext} val sqlContext = new SQLContext(sc) // parse data into dataframe -val data = MLUtils.loadLibSVMFile(sc, - "data/mllib/sample_multiclass_classification_data.txt") -val Array(train, test) = data.toDF().randomSplit(Array(0.7, 0.3)) +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") +val Array(train, test) = data.randomSplit(Array(0.7, 0.3)) // instantiate multiclass learner and train val ovr = new OneVsRest().setClassifier(new LogisticRegression) @@ -64,9 +980,12 @@ println("label\tfpr\n") } {% endhighlight %}
    +
    -{% highlight java %} +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. + +{% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegression; @@ -74,9 +993,6 @@ import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -84,11 +1000,10 @@ SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); -RDD data = MLUtils.loadLibSVMFile(jsc.sc(), - "data/mllib/sample_multiclass_classification_data.txt"); +DataFrame dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); -DataFrame[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 12345); +DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345); DataFrame train = splits[0]; DataFrame test = splits[1]; diff --git a/docs/ml-features.md b/docs/ml-features.md index f88c0248c1a8..b70da4ac6384 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -55,7 +55,7 @@ rescaledData.select("features", "label").take(3).foreach(println)
    {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.HashingTF; @@ -70,7 +70,7 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, "Hi I heard about Spark"), RowFactory.create(0, "I wish Java could use case classes"), RowFactory.create(1, "Logistic regression models are neat") @@ -123,12 +123,21 @@ for features_label in rescaledData.select("features", "label").take(3): ## Word2Vec -`Word2Vec` is an `Estimator` which takes sequences of words that represents documents and trains a `Word2VecModel`. The model is a `Map(String, Vector)` essentially, which maps each word to an unique fix-sized vector. The `Word2VecModel` transforms each documents into a vector using the average of all words in the document, which aims to other computations of documents such as similarity calculation consequencely. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more details on Word2Vec. +`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a +`Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` +transforms each document into a vector using the average of all words in the document; this vector +can then be used for as features for prediction, document similarity calculations, etc. +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +details. -Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of them is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. +In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
    + +Refer to the [Word2Vec Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec) +for more details on the API. + {% highlight scala %} import org.apache.spark.ml.feature.Word2Vec @@ -152,8 +161,12 @@ result.select("result").take(3).foreach(println)
    + +Refer to the [Word2Vec Java docs](api/java/org/apache/spark/ml/feature/Word2Vec.html) +for more details on the API. + {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -167,10 +180,10 @@ JavaSparkContext jsc = ... SQLContext sqlContext = ... // Input data: Each row is a bag of words from a sentence or document. -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), - RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), - RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) )); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) @@ -192,6 +205,10 @@ for (Row r: result.select("result").take(3)) {
    + +Refer to the [Word2Vec Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Word2Vec) +for more details on the API. + {% highlight python %} from pyspark.ml.feature import Word2Vec @@ -211,35 +228,156 @@ for feature in result.select("result").take(3):
+## CountVectorizer
+
+`CountVectorizer` and `CountVectorizerModel` aim to help convert a collection of text documents
+ to vectors of token counts. When an a-priori dictionary is not available, `CountVectorizer` can
+ be used as an `Estimator` to extract the vocabulary and generate a `CountVectorizerModel`. The
+ model produces sparse representations for the documents over the vocabulary, which can then be
+ passed to other algorithms like LDA.
+
+ During the fitting process, `CountVectorizer` will select the top `vocabSize` words ordered by
+ term frequency across the corpus. An optional parameter "minDF" also affects the fitting process
+ by specifying the minimum number (or fraction if < 1.0) of documents a term must appear in to be
+ included in the vocabulary.
+
+**Examples**
+
+Assume that we have the following DataFrame with columns `id` and `texts`:
+
+~~~~
+ id | texts
+----|----------
+ 0  | Array("a", "b", "c")
+ 1  | Array("a", "b", "b", "c", "a")
+~~~~
+
+Each row in `texts` is a document of type Array[String].
+Invoking fit of `CountVectorizer` produces a `CountVectorizerModel` with vocabulary (a, b, c),
+then the output column "vector" after transformation contains:
+
+~~~~
+ id | texts                           | vector
+----|---------------------------------|---------------
+ 0  | Array("a", "b", "c")            | (3,[0,1,2],[1.0,1.0,1.0])
+ 1  | Array("a", "b", "b", "c", "a")  | (3,[0,1,2],[2.0,2.0,1.0])
+~~~~
+
+Each vector represents the token counts of the document over the vocabulary.
+
    +
+More details can be found in the API docs for
+[CountVectorizer](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizer) and
+[CountVectorizerModel](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizerModel).
+{% highlight scala %}
+import org.apache.spark.ml.feature.CountVectorizer
+import org.apache.spark.ml.feature.CountVectorizerModel
+
+val df = sqlContext.createDataFrame(Seq(
+  (0, Array("a", "b", "c")),
+  (1, Array("a", "b", "b", "c", "a"))
+)).toDF("id", "words")
+
+// fit a CountVectorizerModel from the corpus
+val cvModel: CountVectorizerModel = new CountVectorizer()
+  .setInputCol("words")
+  .setOutputCol("features")
+  .setVocabSize(3)
+  .setMinDF(2) // a term must appear in at least 2 documents to be included in the vocabulary
+  .fit(df)
+
+// alternatively, define CountVectorizerModel with a-priori vocabulary
+val cvm = new CountVectorizerModel(Array("a", "b", "c"))
+  .setInputCol("words")
+  .setOutputCol("features")
+
+cvModel.transform(df).select("features").show()
+{% endhighlight %}
+
    + +
    +More details can be found in the API docs for +[CountVectorizer](api/java/org/apache/spark/ml/feature/CountVectorizer.html) and +[CountVectorizerModel](api/java/org/apache/spark/ml/feature/CountVectorizerModel.html). +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.CountVectorizer; +import org.apache.spark.ml.feature.CountVectorizerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +// Input data: Each row is a bag of words from a sentence or document. +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("a", "b", "c")), + RowFactory.create(Arrays.asList("a", "b", "b", "c", "a")) +)); +StructType schema = new StructType(new StructField [] { + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); + +// fit a CountVectorizerModel from the corpus +CountVectorizerModel cvModel = new CountVectorizer() + .setInputCol("text") + .setOutputCol("feature") + .setVocabSize(3) + .setMinDF(2) // a term must appear in more or equal to 2 documents to be included in the vocabulary + .fit(df); + +// alternatively, define CountVectorizerModel with a-priori vocabulary +CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"}) + .setInputCol("text") + .setOutputCol("feature"); + +cvModel.transform(df).show(); +{% endhighlight %} +
    +
    + # Feature Transformers ## Tokenizer [Tokenization](http://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) is the process of taking text (such as a sentence) and breaking it into individual terms (usually words). A simple [Tokenizer](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) class provides this functionality. The example below shows how to split sentences into sequences of words. -Note: A more advanced tokenizer is provided via [RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer). +[RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) allows more + advanced tokenization based on regular expression (regex) matching. + By default, the parameter "pattern" (regex, default: \\s+) is used as delimiters to split the input text. + Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes + "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result.
    {% highlight scala %} -import org.apache.spark.ml.feature.Tokenizer +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} val sentenceDataFrame = sqlContext.createDataFrame(Seq( (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") )).toDF("label", "sentence") val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsDataFrame = tokenizer.transform(sentenceDataFrame) -wordsDataFrame.select("words", "label").take(3).foreach(println) +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) {% endhighlight %}
    {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; @@ -250,10 +388,10 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), @@ -267,22 +405,232 @@ for (Row r : wordsDataFrame.select("words", "label").take(3)) { for (String word : words) System.out.print(word + " "); System.out.println(); } + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); {% endhighlight %}
    {% highlight python %} -from pyspark.ml.feature import Tokenizer +from pyspark.ml.feature import Tokenizer, RegexTokenizer sentenceDataFrame = sqlContext.createDataFrame([ (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") ], ["label", "sentence"]) tokenizer = Tokenizer(inputCol="sentence", outputCol="words") wordsDataFrame = tokenizer.transform(sentenceDataFrame) for words_label in wordsDataFrame.select("words", "label").take(3): print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) +{% endhighlight %} +
    +
    + +## StopWordsRemover +[Stop words](https://en.wikipedia.org/wiki/Stop_words) are words which +should be excluded from the input, typically because the words appear +frequently and don't carry as much meaning. + +`StopWordsRemover` takes as input a sequence of strings (e.g. the output +of a [Tokenizer](ml-features.html#tokenizer)) and drops all the stop +words from the input sequences. The list of stopwords is specified by +the `stopWords` parameter. We provide [a list of stop +words](http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words) by +default, accessible by calling `getStopWords` on a newly instantiated +`StopWordsRemover` instance. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `raw`: + +~~~~ + id | raw +----|---------- + 0 | [I, saw, the, red, baloon] + 1 | [Mary, had, a, little, lamb] +~~~~ + +Applying `StopWordsRemover` with `raw` as the input column and `filtered` as the output +column, we should get the following: + +~~~~ + id | raw | filtered +----|-----------------------------|-------------------- + 0 | [I, saw, the, red, baloon] | [saw, red, baloon] + 1 | [Mary, had, a, little, lamb]|[Mary, little, lamb] +~~~~ + +In `filtered`, the stop words "I", "the", "had", and "a" have been +filtered out. + +
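The examples below use the default stop word list; the list can also be replaced via `setStopWords`. A minimal Scala sketch (the custom word list here is purely illustrative):

{% highlight scala %}
import org.apache.spark.ml.feature.StopWordsRemover

// Override the default English stop words with a custom (illustrative) list.
val customRemover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("a", "the", "had"))
  .setCaseSensitive(false)
{% endhighlight %}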
    + +
    + +[`StopWordsRemover`](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight scala %} +import org.apache.spark.ml.feature.StopWordsRemover + +val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") +val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) +)).toDF("id", "raw") + +remover.transform(dataSet).show() +{% endhighlight %} +
    + +
    + +[`StopWordsRemover`](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) +)); +StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame dataset = jsql.createDataFrame(rdd, schema); + +remover.transform(dataset).show(); +{% endhighlight %} +
    + +
    +[`StopWordsRemover`](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight python %} +from pyspark.ml.feature import StopWordsRemover + +sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) +], ["label", "raw"]) + +remover = StopWordsRemover(inputCol="raw", outputCol="filtered") +remover.transform(sentenceData).show(truncate=False) +{% endhighlight %} +
    +
    + +## $n$-gram + +An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. + +`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer)). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. + +
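The examples below use the default setting of `n` = 2 (bigrams); the n-gram length can be changed via the `n` parameter. A minimal Scala sketch:

{% highlight scala %}
import org.apache.spark.ml.feature.NGram

// Produce trigrams instead of the default bigrams (n = 2).
val trigram = new NGram()
  .setN(3)
  .setInputCol("words")
  .setOutputCol("trigrams")
{% endhighlight %}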
    + +
    + +[`NGram`](api/scala/index.html#org.apache.spark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight scala %} +import org.apache.spark.ml.feature.NGram + +val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) +)).toDF("label", "words") + +val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") +val ngramDataFrame = ngram.transform(wordDataFrame) +ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) +{% endhighlight %} +
    + +
    + +[`NGram`](api/java/org/apache/spark/ml/feature/NGram.html) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); +NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); +DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); +for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); +} +{% endhighlight %} +
    + +
    + +[`NGram`](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight python %} +from pyspark.ml.feature import NGram + +wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) +], ["label", "words"]) +ngram = NGram(inputCol="words", outputCol="ngrams") +ngramDataFrame = ngram.transform(wordDataFrame) +for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) {% endhighlight %}
    @@ -290,12 +638,15 @@ for words_label in wordsDataFrame.select("words", "label").take(3): ## Binarizer -Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. +Binarization is the process of thresholding numerical features to binary (0/1) features. -A simple [Binarizer](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) class provides this functionality. Besides the common parameters of `inputCol` and `outputCol`, `Binarizer` has the parameter `threshold` used for binarizing continuous numerical features. The features greater than the threshold, will be binarized to 1.0. The features equal to or less than the threshold, will be binarized to 0.0. The example below shows how to binarize numerical features. +`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` for binarization. Feature values greater than the threshold are binarized to 1.0; values equal to or less than the threshold are binarized to 0.0.
    + +Refer to the [Binarizer API doc](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details. + {% highlight scala %} import org.apache.spark.ml.feature.Binarizer import org.apache.spark.sql.DataFrame @@ -319,8 +670,11 @@ binarizedFeatures.collect().foreach(println)
    + +Refer to the [Binarizer API doc](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details. + {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.Binarizer; @@ -332,7 +686,7 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, 0.1), RowFactory.create(1, 0.8), RowFactory.create(2, 0.2) @@ -356,6 +710,9 @@ for (Row r : binarizedFeatures.collect()) {
    + +Refer to the [Binarizer API doc](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details. + {% highlight python %} from pyspark.ml.feature import Binarizer @@ -373,6 +730,92 @@ for binarized_feature, in binarizedFeatures.collect():
    +## PCA + +[PCA](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical procedure that uses an orthogonal transformation to convert a set of observations of possibly correlated variables into a set of values of linearly uncorrelated variables called principal components. A [PCA](api/scala/index.html#org.apache.spark.ml.feature.PCA) class trains a model to project vectors to a low-dimensional space using PCA. The example below shows how to project 5-dimensional feature vectors into 3-dimensional principal components. + +
    +
    +See the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.feature.PCA) for API details. +{% highlight scala %} +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) +val pcaDF = pca.transform(df) +val result = pcaDF.select("pcaFeatures") +result.show() +{% endhighlight %} +
    + +
+ +See the [Java API documentation](api/java/org/apache/spark/ml/feature/PCA.html) for API details. +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +JavaRDD<Row> data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); +DataFrame result = pca.transform(df).select("pcaFeatures"); +result.show(); +{% endhighlight %} +
    + +
    +See the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for API details. +{% highlight python %} +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] +df = sqlContext.createDataFrame(data,["features"]) +pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") +model = pca.fit(df) +result = model.transform(df).select("pcaFeatures") +result.show(truncate=False) +{% endhighlight %} +
    +
    + ## PolynomialExpansion [Polynomial expansion](http://en.wikipedia.org/wiki/Polynomial_expansion) is the process of expanding your features into a polynomial space, which is formulated by an n-degree combination of original dimensions. A [PolynomialExpansion](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) class provides this functionality. The example below shows how to expand your features into a 3-degree polynomial space. @@ -400,7 +843,7 @@ polyDF.select("polyFeatures").take(3).foreach(println)
    {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -421,7 +864,7 @@ PolynomialExpansion polyExpansion = new PolynomialExpansion() .setInputCol("features") .setOutputCol("polyFeatures") .setDegree(3); -JavaRDD data = jsc.parallelize(Lists.newArrayList( +JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(Vectors.dense(-2.0, 2.3)), RowFactory.create(Vectors.dense(0.0, 0.0)), RowFactory.create(Vectors.dense(0.6, -1.1)) @@ -456,12 +899,87 @@ for expanded in polyDF.select("polyFeatures").take(3):
    +## Discrete Cosine Transform (DCT) + +The [Discrete Cosine +Transform](https://en.wikipedia.org/wiki/Discrete_cosine_transform) +transforms a length $N$ real-valued sequence in the time domain into +another length $N$ real-valued sequence in the frequency domain. A +[DCT](api/scala/index.html#org.apache.spark.ml.feature.DCT) class +provides this functionality, implementing the +[DCT-II](https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II) +and scaling the result by $1/\sqrt{2}$ such that the representing matrix +for the transform is unitary. No shift is applied to the transformed +sequence (e.g. the $0$th element of the transformed sequence is the +$0$th DCT coefficient and _not_ the $N/2$th). + +
    +
    +{% highlight scala %} +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors + +val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) +val dctDf = dct.transform(df) +dctDf.select("featuresDCT").show(3) +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); +DataFrame dctDf = dct.transform(df); +dctDf.select("featuresDCT").show(3); +{% endhighlight %} +
    +
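To make the normalization above concrete, here is a minimal, non-Spark sketch of the DCT-II with the orthonormal scaling (sqrt(1/N) for the 0th coefficient, sqrt(2/N) otherwise), which is one standard way to make the transform matrix unitary; the exact scaling used by the `DCT` transformer should be taken from the API docs.

{% highlight scala %}
// Orthonormal DCT-II of a real-valued sequence; no shift is applied,
// so result(0) is the 0th DCT coefficient.
def dct2(x: Array[Double]): Array[Double] = {
  val n = x.length
  Array.tabulate(n) { k =>
    val scale = if (k == 0) math.sqrt(1.0 / n) else math.sqrt(2.0 / n)
    scale * x.indices.map(i => x(i) * math.cos(math.Pi / n * (i + 0.5) * k)).sum
  }
}

dct2(Array(0.0, 1.0, -2.0, 3.0))  // same length as the input, frequency domain
{% endhighlight %}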
    + ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies. So the most frequent label gets index `0`. -If the input column is numeric, we cast it to string and index the string values. +If the input column is numeric, we cast it to string and index the string +values. When downstream pipeline components such as `Estimator` or +`Transformer` make use of this string-indexed label, you must set the input +column of the component to this string-indexed column name. In many cases, +you can set the input column with `setInputCol`. **Examples** @@ -605,7 +1123,7 @@ encoded.select("id", "categoryVec").foreach(println)
    {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.OneHotEncoder; @@ -619,7 +1137,7 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, "a"), RowFactory.create(1, "b"), RowFactory.create(2, "c"), @@ -687,9 +1205,9 @@ In the example below, we read in a dataset of labeled points and then use `Vecto
{% highlight scala %} import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -708,16 +1226,12 @@ val indexedData = indexerModel.transform(data) {% highlight java %} import java.util.Map; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD rdd = MLUtils.loadLibSVMFile(sc.sc(), - "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -738,9 +1252,9 @@ DataFrame indexedData = indexerModel.transform(data);
{% highlight python %} from pyspark.ml.feature import VectorIndexer -from pyspark.mllib.util import MLUtils -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) indexerModel = indexer.fit(data) @@ -761,10 +1275,9 @@ The following example demonstrates how to load a dataset in libsvm format and th
    {% highlight scala %} import org.apache.spark.ml.feature.Normalizer -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") // Normalize each Vector using $L^1$ norm. val normalizer = new Normalizer() @@ -780,15 +1293,11 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi
{% highlight java %} -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Normalize each Vector using $L^1$ norm. Normalizer normalizer = new Normalizer() @@ -805,11 +1314,10 @@ DataFrame lInfNormData =
{% highlight python %} -from pyspark.mllib.util import MLUtils from pyspark.ml.feature import Normalizer -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) +dataFrame = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") # Normalize each Vector using $L^1$ norm. normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) @@ -843,10 +1351,9 @@ The following example demonstrates how to load a dataset in libsvm format and th
    {% highlight scala %} import org.apache.spark.ml.feature.StandardScaler -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val scaler = new StandardScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -863,15 +1370,12 @@ val scaledData = scalerModel.transform(dataFrame)
{% highlight java %} -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.ml.feature.StandardScalerModel; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); StandardScaler scaler = new StandardScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -888,11 +1392,10 @@ DataFrame scaledData = scalerModel.transform(dataFrame);
{% highlight python %} -from pyspark.mllib.util import MLUtils from pyspark.ml.feature import StandardScaler -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) +dataFrame = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False) @@ -905,6 +1408,72 @@ scaledData = scalerModel.transform(dataFrame)
+## MinMaxScaler + +`MinMaxScaler` transforms a dataset of `Vector` rows, rescaling each feature to a specific range (often [0, 1]). It takes parameters: + +* `min`: 0.0 by default. Lower bound after transformation, shared by all features. +* `max`: 1.0 by default. Upper bound after transformation, shared by all features. + +`MinMaxScaler` computes summary statistics on a dataset and produces a `MinMaxScalerModel`. The model can then transform each feature individually such that it is in the given range. + +The rescaled value for a feature E is calculated as, +`\begin{equation} + Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min +\end{equation}` +For the case `E_{max} == E_{min}`, `Rescaled(e_i) = 0.5 * (max + min)` + +Note that since zero values will probably be transformed to non-zero values, the output of the transformer will be `DenseVector` even for sparse input. + +The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [0, 1]. + +
    +
    +More details can be found in the API docs for +[MinMaxScaler](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScaler) and +[MinMaxScalerModel](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel). +{% highlight scala %} +import org.apache.spark.ml.feature.MinMaxScaler + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + +// Compute summary statistics and generate MinMaxScalerModel +val scalerModel = scaler.fit(dataFrame) + +// rescale each feature to range [min, max]. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %} +
    + +
+More details can be found in the API docs for +[MinMaxScaler](api/java/org/apache/spark/ml/feature/MinMaxScaler.html) and +[MinMaxScalerModel](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html). +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + +// Compute summary statistics and generate MinMaxScalerModel +MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + +// rescale each feature to range [min, max]. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %} +
    +
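As a complement to the pipeline examples above, here is a minimal, non-Spark sketch of the rescaling formula applied to a single feature column; the feature values and the default [0.0, 1.0] target range are hypothetical.

{% highlight scala %}
// Rescaled(e) = (e - E_min) / (E_max - E_min) * (max - min) + min,
// with the degenerate E_max == E_min case mapped to 0.5 * (max + min).
def rescale(values: Seq[Double], min: Double = 0.0, max: Double = 1.0): Seq[Double] = {
  val eMin = values.min
  val eMax = values.max
  values.map { e =>
    if (eMax == eMin) 0.5 * (max + min)
    else (e - eMin) / (eMax - eMin) * (max - min) + min
  }
}

rescale(Seq(1.0, 3.0, 5.0))  // Seq(0.0, 0.5, 1.0)
{% endhighlight %}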
    + ## Bucketizer `Bucketizer` transforms a column of continuous features to a column of feature buckets, where the buckets are specified by users. It takes a parameter: @@ -942,7 +1511,7 @@ val bucketedData = bucketizer.transform(dataFrame)
    {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; @@ -954,7 +1523,7 @@ import org.apache.spark.sql.types.StructType; double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; -JavaRDD data = jsc.parallelize(Lists.newArrayList( +JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), RowFactory.create(0.0), @@ -1019,7 +1588,7 @@ v_N This example below demonstrates how to transform vectors using a transforming vector value.
    -
    +
    {% highlight scala %} import org.apache.spark.ml.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors @@ -1036,14 +1605,14 @@ val transformer = new ElementwiseProduct() .setOutputCol("transformedVector") // Batch transform the vectors to create new column: -val transformedData = transformer.transform(dataFrame) +transformer.transform(dataFrame).show() {% endhighlight %}
    -
    +
    {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.ElementwiseProduct; @@ -1059,7 +1628,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; // Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) )); @@ -1074,10 +1643,25 @@ ElementwiseProduct transformer = new ElementwiseProduct() .setInputCol("vector") .setOutputCol("transformedVector"); // Batch transform the vectors to create new column: -DataFrame transformedData = transformer.transform(dataFrame); +transformer.transform(dataFrame).show(); + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] +df = sqlContext.createDataFrame(data, ["vector"]) +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") +transformer.transform(df).show() {% endhighlight %}
    +
    ## VectorAssembler @@ -1196,3 +1780,242 @@ print(output.select("features", "clicked").first()) # Feature Selectors +## VectorSlicer + +`VectorSlicer` is a transformer that takes a feature vector and outputs a new feature vector with a +sub-array of the original features. It is useful for extracting features from a vector column. + +`VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column +whose values are selected via those indices. There are two types of indices, + + 1. Integer indices that represents the indices into the vector, `setIndices()`; + + 2. String indices that represents the names of features into the vector, `setNames()`. + *This requires the vector column to have an `AttributeGroup` since the implementation matches on + the name field of an `Attribute`.* + +Specification by integer and string are both acceptable. Moreover, you can use integer index and +string name simultaneously. At least one feature must be selected. Duplicate features are not +allowed, so there can be no overlap between selected indices and names. Note that if names of +features are selected, an exception will be threw out when encountering with empty input attributes. + +The output vector will order features with the selected indices first (in the order given), +followed by the selected names (in the order given). + +**Examples** + +Suppose that we have a DataFrame with the column `userFeatures`: + +~~~ + userFeatures +------------------ + [0.0, 10.0, 0.5] +~~~ + +`userFeatures` is a vector column that contains three user features. Assuming that the first column +of `userFeatures` are all zeros, so we want to remove it and only the last two columns are selected. +The `VectorSlicer` selects the last two elements with `setIndices(1, 2)` then produces a new vector +column named `features`: + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] +~~~ + +Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +`["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] + ["f1", "f2", "f3"] | ["f2", "f3"] +~~~ + +
    +
+ +[`VectorSlicer`](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) takes an input +column name with specified indices or names and an output column name. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0) +) + +val defaultAttr = NumericAttribute.defaultAttr +val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) +val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + +val dataRDD = sc.parallelize(data).map(Row.apply) +val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + +val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + +slicer.setIndices(Array(1)).setNames(Array("f3")) +// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + +val output = slicer.transform(dataset) +println(output.select("userFeatures", "features").first()) +{% endhighlight %} +
    + +
+ +[`VectorSlicer`](api/java/org/apache/spark/ml/feature/VectorSlicer.html) takes an input column name +with specified indices or names and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") +}; +AttributeGroup group = new AttributeGroup("userFeatures", attrs); + +JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) +)); + +DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + +VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + +vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); +// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + +DataFrame output = vectorSlicer.transform(dataset); + +System.out.println(output.select("userFeatures", "features").first()); +{% endhighlight %} +
    +
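To summarize the ordering rule described above, here is a minimal, non-Spark sketch: selected indices come first (in the order given), followed by selected names (in the order given), with names resolved through the column's attribute names. The values mirror the tables above and are otherwise hypothetical.

{% highlight scala %}
// Slice [0.0, 10.0, 0.5] by index 1 and by name "f3" (attributes f1, f2, f3).
val attrNames = Array("f1", "f2", "f3")
val userFeatures = Array(0.0, 10.0, 0.5)

val selectedIndices = Array(1)      // as with setIndices
val selectedNames = Array("f3")     // as with setNames

val sliced = selectedIndices.map(i => userFeatures(i)) ++
  selectedNames.map(n => userFeatures(attrNames.indexOf(n)))
// Array(10.0, 0.5)
{% endhighlight %}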
    + +## RFormula + +`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`: + +~~~ +id | country | hour | clicked +---|---------|------|--------- + 7 | "US" | 18 | 1.0 + 8 | "CA" | 12 | 0.0 + 9 | "NZ" | 15 | 0.0 +~~~ + +If we use `RFormula` with a formula string of `clicked ~ country + hour`, which indicates that we want to +predict `clicked` based on `country` and `hour`, after transformation we should get the following DataFrame: + +~~~ +id | country | hour | clicked | features | label +---|---------|------|---------|------------------|------- + 7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0 + 8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0 + 9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0 +~~~ + +
    +
    + +[`RFormula`](api/scala/index.html#org.apache.spark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight scala %} +import org.apache.spark.ml.feature.RFormula + +val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) +)).toDF("id", "country", "hour", "clicked") +val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") +val output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %} +
    + +
    + +[`RFormula`](api/java/org/apache/spark/ml/feature/RFormula.html) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) +}); +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) +)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + +DataFrame output = formula.fit(dataset).transform(dataset); +output.select("features", "label").show(); +{% endhighlight %} +
    + +
    + +[`RFormula`](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight python %} +from pyspark.ml.feature import RFormula + +dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) +formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") +output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %} +
    +
    diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c74cb1f1ef8e..c5d7f990021f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,75 +3,109 @@ layout: global title: Spark ML Programming Guide --- -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical +machine learning pipelines. +See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of +`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. + +**Table of contents** -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. - -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. - -Guides for sub-packages of `spark.ml` include: +* This will become a table of contents (this text will be scraped). +{:toc} -* [Feature Extraction, Transformation, and Selection](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API -* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API +# Algorithm guides +We provide several algorithm guides specific to the Pipelines API. +Several of these algorithms, such as certain feature transformers, are not in the `spark.mllib` API. +Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., random forests +provide class probabilities, and linear models provide model summaries. -**Table of Contents** +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision Trees for classification and regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) -* This will become a table of contents (this text will be scraped). -{:toc} -# Main Concepts +# Main concepts in Pipelines -Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API. +Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. 
-* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types. -E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions. +* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. * **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. -E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions. +E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. * **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. -E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model. +E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. * **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Param`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. -## ML Dataset +## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept. +Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. `DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." -## ML Algorithms +## Pipeline components ### Transformers -A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns. +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. 
For example: -* A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset. -* A learning model might take a dataset, read the column containing feature vectors, predict the label for each feature vector, append the labels as a new column, and output the updated dataset. +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. +* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. ### Estimators -An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`. -For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`. +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. -### Properties of ML Algorithms +### Properties of pipeline components -`Transformer`s and `Estimator`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). @@ -84,15 +118,16 @@ E.g., a simple text document processing workflow might include several stages: * Convert each document's words into a numerical feature vector. * Learn a prediction model using the feature vectors and labels. -Spark ML represents such a workflow as a [`Pipeline`](api/scala/index.html#org.apache.spark.ml.Pipeline), -which consists of a sequence of [`PipelineStage`s](api/scala/index.html#org.apache.spark.ml.PipelineStage) (`Transformer`s and `Estimator`s) to be run in a specific order. We will use this simple workflow as a running example in this section. +Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. +We will use this simple workflow as a running example in this section. -### How It Works +### How it works A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. -These stages are run in order, and the input dataset is modified as it passes through each stage. -For `Transformer` stages, the `transform()` method is called on the dataset. -For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the dataset. 
+These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. @@ -108,14 +143,17 @@ We illustrate this for the simple text document workflow. The figure below is f Above, the top row represents a `Pipeline` with three stages. The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. -The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels. -The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset. -The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` method on the dataset before passing the dataset to the next stage. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. A `Pipeline` is an `Estimator`. -Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel` which is a `Transformer`. This `PipelineModel` is used at *test time*; the figure below illustrates this usage. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage.

    In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. -When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed through the `Pipeline` in order. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. Each stage's `transform()` method updates the dataset and passes it to the next stage. `Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. @@ -136,28 +175,43 @@ Each stage's `transform()` method updates the dataset and passes it to the next *DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. -*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`. +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. + +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. -A [`Param`](api/scala/index.html#org.apache.spark.ml.param.Param) is a named parameter with self-contained documentation. -A [`ParamMap`](api/scala/index.html#org.apache.spark.ml.param.ParamMap) is a set of (parameter, value) pairs. +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. There are two main ways to pass parameters to an algorithm: -1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. This API resembles the API used in MLlib. +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. 2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. 
Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Code Examples +# Code examples This section gives code examples illustrating the functionality discussed above. -There is not yet documentation for specific algorithms in Spark ML. For more info, please refer to the [API Documentation](api/scala/index.html#org.apache.spark.ml.package). Spark ML algorithms are currently wrappers for MLlib algorithms, and the [MLlib programming guide](mllib-guide.html) has details on specific algorithms. +For more info, please refer to the API documentation +([Scala](api/scala/index.html#org.apache.spark.ml.package), +[Java](api/java/org/apache/spark/ml/package-summary.html), +and [Python](api/python/pyspark.ml.html)). +Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the +[MLlib programming guide](mllib-guide.html) has details on specific algorithms. ## Example: Estimator, Transformer, and Param @@ -167,26 +221,18 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.

    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{Row, SQLContext} - -val conf = new SparkConf().setAppName("SimpleParamsExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ +import org.apache.spark.sql.Row -// Prepare training data. -// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into DataFrames, where it uses the case class metadata to infer the schema. -val training = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) +// Prepare training data from a list of (label, features) tuples. +val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) +)).toDF("label", "features") // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() @@ -198,7 +244,7 @@ lr.setMaxIter(10) .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training.toDF) +val model1 = lr.fit(training) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -208,8 +254,8 @@ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) -paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. -paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name @@ -217,58 +263,52 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training.toDF, paramMapCombined) +val model2 = lr.fit(training, paramMapCombined) println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. -val test = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) +val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) +)).toDF("label", "features") // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
// Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. -model2.transform(test.toDF) +model2.transform(test) .select("features", "label", "myProbability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => println(s"($features, $label) -> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
    {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; + import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; -SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans // into DataFrames, where it uses the bean metadata to infer the schema. -List localTraining = Lists.newArrayList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) +), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -288,14 +328,14 @@ LogisticRegressionModel model1 = lr.fit(training); System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. -ParamMap paramMap = new ParamMap(); -paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. -paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. -paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. +ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. // One can also combine ParamMaps. -ParamMap paramMap2 = new ParamMap(); -paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name +ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -304,11 +344,11 @@ LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. -List localTest = Lists.newArrayList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) +), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. 
// LogisticRegression.transform will only use the 'features' column. @@ -320,7 +360,68 @@ for (Row r: results.select("features", "label", "myProbability", "prediction").c + ", prediction=" + r.get(3)); } -jsc.stop(); +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params + +# Prepare training data from a list of (label, features) tuples. +training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training, paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + {% endhighlight %}
    @@ -334,30 +435,19 @@ This example follows the simple text document `Pipeline` illustrated in the figu
    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) +import org.apache.spark.sql.Row -// Set up contexts. Import implicit conversions to DataFrame from sqlContext. -val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0))) +// Prepare training documents from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -374,14 +464,15 @@ val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. -val model = pipeline.fit(training.toDF) +val model = pipeline.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. model.transform(test.toDF) @@ -391,16 +482,14 @@ model.transform(test.toDF) println(s"($id, $text) --> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
    {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; + import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -409,7 +498,6 @@ import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -441,18 +529,13 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -// Set up contexts. -SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training documents, which are labeled. -List localTraining = Lists.newArrayList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(3L, "hadoop mapreduce", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -472,12 +555,12 @@ Pipeline pipeline = new Pipeline() PipelineModel model = pipeline.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Lists.newArrayList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. DataFrame predictions = model.transform(test); @@ -486,28 +569,23 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
    {% highlight python %} -from pyspark import SparkContext from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row, SQLContext - -sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlContext = SQLContext(sc) +from pyspark.sql import Row -# Prepare training documents, which are labeled. +# Prepare training documents from a list of (id, text, label) tuples. LabeledDocument = Row("id", "text", "label") -training = sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) \ - .map(lambda x: LabeledDocument(*x)).toDF() +training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") @@ -518,27 +596,25 @@ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. model = pipeline.fit(training) -# Prepare test documents, which are unlabeled. -Document = Row("id", "text") -test = sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) \ - .map(lambda x: Document(*x)).toDF() +# Prepare test documents, which are unlabeled (id, text) tuples. +test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) # Make predictions on test documents and print columns of interest. prediction = model.transform(test) selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) -sc.stop() {% endhighlight %}
    -## Example: Model Selection via Cross-Validation +## Example: model selection via cross-validation An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. @@ -546,6 +622,13 @@ An important task in ML is *model selection*, or using data to find the best mod Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. + +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` +method in each of these evaluators. + The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. `CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. @@ -562,39 +645,29 @@ However, it is also a well-established method for choosing parameters which is m
    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) - -val conf = new SparkConf().setAppName("CrossValidatorExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0), - LabeledDocument(4L, "b spark who", 1.0), - LabeledDocument(5L, "g d a y", 0.0), - LabeledDocument(6L, "spark fly", 1.0), - LabeledDocument(7L, "was mapreduce", 0.0), - LabeledDocument(8L, "e spark program", 1.0), - LabeledDocument(9L, "a e c l", 0.0), - LabeledDocument(10L, "spark compile", 1.0), - LabeledDocument(11L, "hadoop software", 0.0))) +import org.apache.spark.sql.Row + +// Prepare training data from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -608,12 +681,6 @@ val lr = new LogisticRegression() val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -621,37 +688,45 @@ val paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) .addGrid(lr.regParam, Array(0.1, 0.01)) .build() -crossval.setEstimatorParamMaps(paramGrid) -crossval.setNumFolds(2) // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. 
+val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -val cvModel = crossval.fit(training.toDF) +val cvModel = cv.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test.toDF) +cvModel.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") -} + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } -sc.stop() {% endhighlight %}
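One way to see which parameters cross-validation actually selected is to inspect the fitted model. This is a minimal sketch, assuming the `pipeline` and `cvModel` values from the example above (the logistic regression is the third pipeline stage):

{% highlight scala %}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.LogisticRegressionModel

// The best model found by CrossValidator is the fitted Pipeline itself.
val bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]

// Pull out the fitted logistic regression stage and print the parameters it was trained with.
val bestLRModel = bestPipelineModel.stages(2).asInstanceOf[LogisticRegressionModel]
println(bestLRModel.parent.extractParamMap())
{% endhighlight %}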
    {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; + import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; @@ -664,7 +739,6 @@ import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -696,12 +770,9 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. -List localTraining = Lists.newArrayList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), @@ -713,8 +784,8 @@ List localTraining = Lists.newArrayList( new LabeledDocument(8L, "e spark program", 1.0), new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(11L, "hadoop software", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -730,12 +801,6 @@ LogisticRegression lr = new LogisticRegression() Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -CrossValidator crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()); // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -743,19 +808,28 @@ ParamMap[] paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) .addGrid(lr.regParam(), new double[]{0.1, 0.01}) .build(); -crossval.setEstimatorParamMaps(paramGrid); -crossval.setNumFolds(2); // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. 
-CrossValidatorModel cvModel = crossval.fit(training); +CrossValidatorModel cvModel = cv.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Lists.newArrayList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). DataFrame predictions = cvModel.transform(test); @@ -764,40 +838,121 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
-# Dependencies +## Example: model selection via train validation split +In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyper-parameter tuning. +`TrainValidationSplit` evaluates each combination of parameters only once, as opposed to k times in +the case of `CrossValidator`. It is therefore less expensive, +but will not produce as reliable results when the training dataset is not sufficiently large. + +`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, +and an `Evaluator`. +It begins by splitting the dataset into two parts using the `trainRatio` parameter, +which are used as separate training and test datasets. For example, with `$trainRatio=0.75$` (the default), +`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. +Like `CrossValidator`, `TrainValidationSplit` iterates through the set of `ParamMap`s. +For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. +The `ParamMap` which produces the best evaluation metric is selected as the best option. +`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. + +
    -Spark ML currently depends on MLlib and has the same dependencies. -Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info. +
    +{% highlight scala %} +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +import org.apache.spark.mllib.util.MLUtils -Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. +// Prepare training and test data. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) -# Migration Guide +val lr = new LinearRegression() -## From 1.3 to 1.4 +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() -Several major API changes occurred, including: -* `Param` and other APIs for specifying parameters -* `uid` unique IDs for Pipeline components -* Reorganization of certain classes -Since the `spark.ml` API was an Alpha Component in Spark 1.3, we do not list all changes here. +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) -However, now that `spark.ml` is no longer an Alpha Component, we will provide details on any API changes for future releases. +// Run train validation split, and choose the best set of parameters. +val model = trainValidationSplit.fit(training) -## From 1.2 to 1.3 +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show() -The main API changes are from Spark SQL. We list the most important changes here: +{% endhighlight %} +
    -* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. -* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. -* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. +
    +{% highlight java %} +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +DataFrame data = sqlContext.createDataFrame( + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), + LabeledPoint.class); + +// Prepare training and test data. +DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); +DataFrame training = splits[0]; +DataFrame test = splits[1]; + +LinearRegression lr = new LinearRegression(); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + +// Run train validation split, and choose the best set of parameters. +TrainValidationSplitModel model = trainValidationSplit.fit(training); + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show(); -Other changes were in `LogisticRegression`: +{% endhighlight %} +
    -* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). -* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. +
    diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md new file mode 100644 index 000000000000..4e94e2f9c708 --- /dev/null +++ b/docs/ml-linear-methods.md @@ -0,0 +1,350 @@ +--- +layout: global +title: Linear Methods - ML +displayTitle: ML - Linear Methods +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In MLlib, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details. In `spark.ml`, we also include Pipelines API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +`\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]` +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. +We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + +## Example: Logistic Regression + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.classification.LogisticRegression + +// Load training data +val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for logistic regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") +{% endhighlight %} +
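As the introduction above notes, the L1 and L2 special cases follow from `elasticNetParam` alone. A small sketch, reusing the `training` DataFrame loaded above (the variable names here are illustrative):

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression

// elasticNetParam = 1.0 gives a pure L1 (lasso-style) penalty;
// elasticNetParam = 0.0 gives a pure L2 (ridge-style) penalty.
val lassoLike = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(1.0)
val ridgeLike = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.0)

val lassoModel = lassoLike.fit(training)
val ridgeModel = ridgeLike.fit(training)
{% endhighlight %}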
    + +
+{% highlight java %} +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Logistic Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + String path = "data/mllib/sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sqlContext.read().format("libsvm").load(path); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for logistic regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + } +} +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.ml.classification import LogisticRegression + +# Load training data +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for logistic regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) +{% endhighlight %} +
    + +
+ +The `spark.ml` implementation of logistic regression also supports +extracting a summary of the model over the training set. Note that the +predictions and metrics which are stored as `DataFrame`s in +`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +only available on the driver. + +
    + +
+ +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% highlight scala %} +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary +import org.apache.spark.sql.functions.max +// Needed for the $-notation on columns used below. +import sqlContext.implicits._ + +// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example +val trainingSummary = lrModel.summary + +// Obtain the objective per iteration. +val objectiveHistory = trainingSummary.objectiveHistory +objectiveHistory.foreach(loss => println(loss)) + +// Obtain the metrics useful to judge performance on test data. +// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a +// binary classification problem. +val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] + +// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. +val roc = binarySummary.roc +roc.show() +println(binarySummary.areaUnderROC) + +// Set the model threshold to maximize F-Measure +val fMeasure = binarySummary.fMeasureByThreshold +val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) +val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). + select("threshold").head().getDouble(0) +lrModel.setThreshold(bestThreshold) +{% endhighlight %} +
    + +
    +[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% highlight java %} +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.sql.functions; + +// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example +LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + +// Obtain the loss per iteration. +double[] objectiveHistory = trainingSummary.objectiveHistory(); +for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); +} + +// Obtain the metrics useful to judge performance on test data. +// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a +// binary classification problem. +BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary; + +// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. +DataFrame roc = binarySummary.roc(); +roc.show(); +roc.select("FPR").show(); +System.out.println(binarySummary.areaUnderROC()); + +// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with +// this selected threshold. +DataFrame fMeasure = binarySummary.fMeasureByThreshold(); +double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); +double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). + select("threshold").head().getDouble(0); +lrModel.setThreshold(bestThreshold); +{% endhighlight %} +
    + + +
    +Logistic regression model summary is not yet supported in Python. +
    + +
    + +## Example: Linear Regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.regression.LinearRegression + +// Load training data +val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for linear regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") + +// Summarize the model over the training set and print out some metrics +val trainingSummary = lrModel.summary +println(s"numIterations: ${trainingSummary.totalIterations}") +println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") +trainingSummary.residuals.show() +println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") +println(s"r2: ${trainingSummary.r2}") +{% endhighlight %} +
    + +
+{% highlight java %} +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.regression.LinearRegressionModel; +import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LinearRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Linear Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + String path = "data/mllib/sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sqlContext.read().format("libsvm").load(path); + + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LinearRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for linear regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + + // Summarize the model over the training set and print out some metrics + LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); + System.out.println("numIterations: " + trainingSummary.totalIterations()); + System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); + trainingSummary.residuals().show(); + System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); + System.out.println("r2: " + trainingSummary.r2()); + } +} +{% endhighlight %} +
    + +
    + +{% highlight python %} +from pyspark.ml.regression import LinearRegression + +# Load training data +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + +lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for linear regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) + +# Linear regression model summary is not yet supported in Python. +{% endhighlight %} +
    + +
    + +# Optimization + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory +QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. + diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index dcaa3784be87..c2711cf82deb 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed. **Examples** @@ -327,11 +328,17 @@ which contains the computed clustering assignments. import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel} import org.apache.spark.mllib.linalg.Vectors -val similarities: RDD[(Long, Long, Double)] = ... +// Load and parse the data +val data = sc.textFile("data/mllib/pic_data.txt") +val similarities = data.map { line => + val parts = line.split(' ') + (parts(0).toLong, parts(1).toLong, parts(2).toDouble) +} +// Cluster the data into two classes using PowerIterationClustering val pic = new PowerIterationClustering() - .setK(3) - .setMaxIterations(20) + .setK(2) + .setMaxIterations(10) val model = pic.run(similarities) model.assignments.foreach { a => @@ -363,11 +370,22 @@ import scala.Tuple2; import scala.Tuple3; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.PowerIterationClustering; import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; -JavaRDD> similarities = ... +// Load and parse the data +JavaRDD data = sc.textFile("data/mllib/pic_data.txt"); +JavaRDD> similarities = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(" "); + return new Tuple3<>(new Long(parts[0]), new Long(parts[1]), new Double(parts[2])); + } + } +); +// Cluster the data into two classes using PowerIterationClustering PowerIterationClustering pic = new PowerIterationClustering() .setK(2) .setMaxIterations(10); @@ -383,6 +401,35 @@ PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc. {% endhighlight %}
    +
    + +[`PowerIterationClustering`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering) +implements the PIC algorithm. +It takes an `RDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering), +which contains the computed clustering assignments. + +{% highlight python %} +from __future__ import print_function +from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel + +# Load and parse the data +data = sc.textFile("data/mllib/pic_data.txt") +similarities = data.map(lambda line: tuple([float(x) for x in line.split(' ')])) + +# Cluster the data into two classes using PowerIterationClustering +model = PowerIterationClustering.train(similarities, 2, 10) + +model.assignments().foreach(lambda x: print(str(x.id) + " -> " + str(x.cluster))) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") +{% endhighlight %} +
    +
    ## Latent Dirichlet allocation (LDA) @@ -391,28 +438,129 @@ PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc. is a topic model which infers topics from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: -* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. -* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. -* Rather than estimating a clustering using a traditional distance, LDA uses a function based - on a statistical model of how text documents are generated. - -LDA takes in a collection of documents as vectors of word counts. -It supports different inference algorithms via `setOptimizer` function. EMLDAOptimizer learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) -on the likelihood function and yields comprehensive results, while OnlineLDAOptimizer uses iterative mini-batch sampling for [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) and is generally memory friendly. After fitting on the documents, LDA provides: +* Topics correspond to cluster centers, and documents correspond to +examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature +vectors are vectors of word counts (bag of words). +* Rather than estimating a clustering using a traditional distance, LDA +uses a function based on a statistical model of how text documents are +generated. -* Topics: Inferred topics, each of which is a probability distribution over terms (words). -* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. (EM only) +LDA supports different inference algorithms via `setOptimizer` function. +`EMLDAOptimizer` learns clustering using +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function and yields comprehensive results, while +`OnlineLDAOptimizer` uses iterative mini-batch sampling for [online +variational +inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) +and is generally memory friendly. -LDA takes the following parameters: +LDA takes in a collection of documents as vectors of word counts and the +following parameters (set using the builder pattern): * `k`: Number of topics (i.e., cluster centers) -* `maxIterations`: Limit on the number of iterations of EM used for learning -* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. -* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. -* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. - -*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet -support prediction on new documents, and it does not have a Python API. These will be added in the future. 
+* `optimizer`: Optimizer to use for learning the LDA model, either +`EMLDAOptimizer` or `OnlineLDAOptimizer` +* `docConcentration`: Dirichlet parameter for prior over documents' +distributions over topics. Larger values encourage smoother inferred +distributions. +* `topicConcentration`: Dirichlet parameter for prior over topics' +distributions over terms (words). Larger values encourage smoother +inferred distributions. +* `maxIterations`: Limit on the number of iterations. +* `checkpointInterval`: If using checkpointing (set in the Spark +configuration), this parameter specifies the frequency with which +checkpoints will be created. If `maxIterations` is large, using +checkpointing can help reduce shuffle file sizes on disk and help with +failure recovery. + + +All of MLlib's LDA models support: + +* `describeTopics`: Returns topics as arrays of most important terms and +term weights +* `topicsMatrix`: Returns a `vocabSize` by `k` matrix where each column +is a topic + +*Note*: LDA is still an experimental feature under active development. +As a result, certain features are only available in one of the two +optimizers / models generated by the optimizer. Currently, a distributed +model can be converted into a local model, but not vice-versa. + +The following discussion will describe each optimizer/model pair +separately. + +**Expectation Maximization** + +Implemented in +[`EMLDAOptimizer`](api/scala/index.html#org.apache.spark.mllib.clustering.EMLDAOptimizer) +and +[`DistributedLDAModel`](api/scala/index.html#org.apache.spark.mllib.clustering.DistributedLDAModel). + +For the parameters provided to `LDA`: + +* `docConcentration`: Only symmetric priors are supported, so all values +in the provided `k`-dimensional vector must be identical. All values +must also be $> 1.0$. Providing `Vector(-1)` results in default behavior +(uniform `k` dimensional vector with value $(50 / k) + 1$ +* `topicConcentration`: Only symmetric priors supported. Values must be +$> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. +* `maxIterations`: The maximum number of EM iterations. + +*Note*: It is important to do enough iterations. In early iterations, EM often has useless topics, +but those topics improve dramatically after more iterations. Using at least 20 and possibly +50-100 iterations is often reasonable, depending on your dataset. + +`EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only +the inferred topics but also the full training corpus and topic +distributions for each document in the training corpus. A +`DistributedLDAModel` supports: + + * `topTopicsPerDocument`: The top topics and their weights for + each document in the training corpus + * `topDocumentsPerTopic`: The top documents for each topic and + the corresponding weight of the topic in the documents. + * `logPrior`: log probability of the estimated topics and + document-topic distributions given the hyperparameters + `docConcentration` and `topicConcentration` + * `logLikelihood`: log likelihood of the training corpus, given the + inferred topics and document-topic distributions + +**Online Variational Bayes** + +Implemented in +[`OnlineLDAOptimizer`](api/scala/org/apache/spark/mllib/clustering/OnlineLDAOptimizer.html) +and +[`LocalLDAModel`](api/scala/org/apache/spark/mllib/clustering/LocalLDAModel.html). + +For the parameters provided to `LDA`: + +* `docConcentration`: Asymmetric priors can be used by passing in a +vector with values equal to the Dirichlet parameter in each of the `k` +dimensions. 
Values should be $>= 0$. Providing `Vector(-1)` results in +default behavior (uniform `k` dimensional vector with value $(1.0 / k)$) +* `topicConcentration`: Only symmetric priors supported. Values must be +$>= 0$. Providing `-1` results in defaulting to a value of $(1.0 / k)$. +* `maxIterations`: Maximum number of minibatches to submit. + +In addition, `OnlineLDAOptimizer` accepts the following parameters: + +* `miniBatchFraction`: Fraction of corpus sampled and used at each +iteration +* `optimizeDocConcentration`: If set to true, performs maximum-likelihood +estimation of the hyperparameter `docConcentration` (aka `alpha`) +after each minibatch and sets the optimized `docConcentration` in the +returned `LocalLDAModel` +* `tau0` and `kappa`: Used for learning-rate decay, which is computed by +$(\tau_0 + iter)^{-\kappa}$ where $iter$ is the current number of iterations. + +`OnlineLDAOptimizer` produces a `LocalLDAModel`, which only stores the +inferred topics. A `LocalLDAModel` supports: + +* `logLikelihood(documents)`: Calculates a lower bound on the provided +`documents` given the inferred topics. +* `logPerplexity(documents)`: Calculates an upper bound on the +perplexity of the provided `documents` given the inferred topics. **Examples** @@ -425,7 +573,7 @@ to the algorithm. We then output the topics, represented as probability distribu
    {% highlight scala %} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel} import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -445,6 +593,11 @@ for (topic <- Range(0, 3)) { for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } println() } + +// Save and load model. +ldaModel.save(sc, "myLDAModel") +val sameModel = DistributedLDAModel.load(sc, "myLDAModel") + {% endhighlight %}
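The example above uses the default EM optimizer. Switching to the online variational optimizer described earlier is a one-line change; this is a minimal sketch, assuming the same `corpus` RDD prepared above and purely illustrative parameter values:

{% highlight scala %}
import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel, OnlineLDAOptimizer}

// Use online variational Bayes instead of the default EM optimizer.
val onlineLDA = new LDA()
  .setK(3)
  .setMaxIterations(50)
  .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(0.05))

// The online optimizer produces a LocalLDAModel, which stores only the inferred topics.
val localModel = onlineLDA.run(corpus).asInstanceOf[LocalLDAModel]

// Lower bound on the log likelihood and upper bound on the perplexity of the training corpus.
println(localModel.logLikelihood(corpus))
println(localModel.logPerplexity(corpus))
{% endhighlight %}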
    @@ -504,11 +657,42 @@ public class JavaLDAExample { } System.out.println(); } + + ldaModel.save(sc.sc(), "myLDAModel"); + DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel"); } } {% endhighlight %}
    +
+{% highlight python %} +from pyspark.mllib.clustering import LDA, LDAModel +from pyspark.mllib.linalg import Vectors + +# Load and parse the data +data = sc.textFile("data/mllib/sample_lda_data.txt") +parsedData = data.map(lambda line: Vectors.dense([float(x) for x in line.strip().split(' ')])) +# Index documents with unique IDs +corpus = parsedData.zipWithIndex().map(lambda x: [x[1], x[0]]).cache() + +# Cluster the documents into three topics using LDA +ldaModel = LDA.train(corpus, k=3) + +# Output topics. Each is a distribution over words (matching word count vectors) +print("Learned topics (as distributions over vocab of " + str(ldaModel.vocabSize()) + " words):") +topics = ldaModel.topicsMatrix() +for topic in range(3): + print("Topic " + str(topic) + ":") + for word in range(0, ldaModel.vocabSize()): + print(" " + str(topics[word][topic])) + +# Save and load model +ldaModel.save(sc, "myModelPath") +sameModel = LDAModel.load(sc, "myModelPath") +{% endhighlight %} +
    +
    ## Streaming k-means diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index dfdf6216b270..eedc23424ad5 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -77,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -149,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -210,7 +210,7 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index d824dab1d7f7..d8c7bdc63c70 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -144,7 +144,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)); // Create a labeled point with a negative label and a sparse feature vector. -LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); +LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); {% endhighlight %}
    @@ -226,7 +226,8 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") A local matrix has integer-typed row and column indices and double-typed values, stored on a single machine. MLlib supports dense matrices, whose entry values are stored in a single double array in -column major. For example, the following matrix `\[ \begin{pmatrix} +column-major order, and sparse matrices, whose non-zero entry values are stored in the Compressed Sparse +Column (CSC) format in column-major order. For example, the following dense matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ 5.0 & 6.0 @@ -238,28 +239,33 @@ is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the m
    The base class of local matrices is -[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one -implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). +[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseMatrix). We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight scala %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +val sm: Matrix = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 2, 1), Array(9, 6, 8)) {% endhighlight %}
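For readers new to the CSC layout, the following sketch spells out how the three arrays passed to `Matrices.sparse` encode the sparse matrix above (this is an equivalent encoding that lists each column's row indices in sorted order; the variable name is illustrative):

{% highlight scala %}
import org.apache.spark.mllib.linalg.Matrices

// The same 3 x 2 matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) in CSC form:
//   values:     the non-zero entries, column by column        -> [9.0, 8.0, 6.0]
//   rowIndices: the row of each non-zero entry                -> [0, 1, 2]
//   colPtrs:    column j owns entries colPtrs(j) until colPtrs(j + 1)
//               (column 0 owns entry 0; column 1 owns entries 1 and 2) -> [0, 1, 3]
val smSorted = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 1, 2), Array(9.0, 8.0, 6.0))
{% endhighlight %}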
    The base class of local matrices is -[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one -implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). +[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide two +implementations: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html), +and [`SparseMatrix`](api/java/org/apache/spark/mllib/linalg/SparseMatrix.html). We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight java %} import org.apache.spark.mllib.linalg.Matrix; @@ -267,6 +273,30 @@ import org.apache.spark.mllib.linalg.Matrices; // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +Matrix sm = Matrices.sparse(3, 2, new int[] {0, 1, 3}, new int[] {0, 2, 1}, new double[] {9, 6, 8}); +{% endhighlight %} +
    + +
+ +The base class of local matrices is +[`Matrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseMatrix). +We recommend using the factory methods implemented +in [`Matrices`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) to create local +matrices. Remember, local matrices in MLlib are stored in column-major order. + +{% highlight python %} +from pyspark.mllib.linalg import Matrix, Matrices + +# Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) +dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6]) + +# Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) {% endhighlight %}
    @@ -307,7 +337,10 @@ limited by the integer range but it should be much smaller in practice.
    A [`RowMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) can be -created from an `RDD[Vector]` instance. Then we can compute its column summary statistics. +created from an `RDD[Vector]` instance. Then we can compute its column summary statistics and decompositions. +[QR decomposition](https://en.wikipedia.org/wiki/QR_decomposition) is of the form A = QR where Q is an orthogonal matrix and R is an upper triangular matrix. +For [singular value decomposition (SVD)](https://en.wikipedia.org/wiki/Singular_value_decomposition) and [principal component analysis (PCA)](https://en.wikipedia.org/wiki/Principal_component_analysis), please refer to [Dimensionality reduction](mllib-dimensionality-reduction.html). + {% highlight scala %} import org.apache.spark.mllib.linalg.Vector @@ -320,6 +353,9 @@ val mat: RowMatrix = new RowMatrix(rows) // Get its size. val m = mat.numRows() val n = mat.numCols() + +// QR decomposition +val qrResult = mat.tallSkinnyQR(true) {% endhighlight %}
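+As a small follow-up to the example above (an illustrative sketch, not part of the original snippet),
+the `QRDecomposition` returned by `tallSkinnyQR(true)` simply carries the two factors, so they can be
+inspected directly:
+
+{% highlight scala %}
+// Q is returned as a RowMatrix because computeQ = true was passed above;
+// R is a local upper-triangular Matrix.
+val Q = qrResult.Q
+val R = qrResult.R
+println(s"R is ${R.numRows} x ${R.numCols}")
+{% endhighlight %}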
    @@ -340,14 +376,42 @@ RowMatrix mat = new RowMatrix(rows.rdd()); // Get its size. long m = mat.numRows(); long n = mat.numCols(); + +// QR decomposition +QRDecomposition result = mat.tallSkinnyQR(true); {% endhighlight %}
    + +
    + +A [`RowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) can be +created from an `RDD` of vectors. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import RowMatrix + +# Create an RDD of vectors. +rows = sc.parallelize([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) + +# Create a RowMatrix from an RDD of vectors. +mat = RowMatrix(rows) + +# Get its size. +m = mat.numRows() # 4 +n = mat.numCols() # 3 + +# Get the rows as an RDD of vectors again. +rowsRDD = mat.rows +{% endhighlight %} +
    +
    ### IndexedRowMatrix An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by -an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector. +an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local +vector.
    @@ -401,7 +465,51 @@ long n = mat.numCols(); // Drop its row indices. RowMatrix rowMat = mat.toRowMatrix(); {% endhighlight %} -
    +
    + +
    + +An [`IndexedRowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRowMatrix) +can be created from an `RDD` of `IndexedRow`s, where +[`IndexedRow`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRow) is a +wrapper over `(long, vector)`. An `IndexedRowMatrix` can be converted to a `RowMatrix` by dropping +its row indices. + +{% highlight python %} +from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix + +# Create an RDD of indexed rows. +# - This can be done explicitly with the IndexedRow class: +indexedRows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + IndexedRow(1, [4, 5, 6]), + IndexedRow(2, [7, 8, 9]), + IndexedRow(3, [10, 11, 12])]) +# - or by using (long, vector) tuples: +indexedRows = sc.parallelize([(0, [1, 2, 3]), (1, [4, 5, 6]), + (2, [7, 8, 9]), (3, [10, 11, 12])]) + +# Create an IndexedRowMatrix from an RDD of IndexedRows. +mat = IndexedRowMatrix(indexedRows) + +# Get its size. +m = mat.numRows() # 4 +n = mat.numCols() # 3 + +# Get the rows as an RDD of IndexedRows. +rowsRDD = mat.rows + +# Convert to a RowMatrix by dropping the row indices. +rowMat = mat.toRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() + +# Convert to a BlockMatrix. +blockMat = mat.toBlockMatrix() +{% endhighlight %} +
    + +
    ### CoordinateMatrix @@ -465,6 +573,45 @@ long n = mat.numCols(); IndexedRowMatrix indexedRowMatrix = mat.toIndexedRowMatrix(); {% endhighlight %}
    + +
+
+A [`CoordinateMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.CoordinateMatrix)
+can be created from an `RDD` of `MatrixEntry` entries, where
+[`MatrixEntry`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.MatrixEntry) is a
+wrapper over `(long, long, float)`. A `CoordinateMatrix` can be converted to a `RowMatrix` by
+calling `toRowMatrix`, or to an `IndexedRowMatrix` with sparse rows by calling `toIndexedRowMatrix`.
+
+{% highlight python %}
+from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry
+
+# Create an RDD of coordinate entries.
+#   - This can be done explicitly with the MatrixEntry class:
+entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(2, 1, 3.7)])
+#   - or using (long, long, float) tuples:
+entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)])
+
+# Create a CoordinateMatrix from an RDD of MatrixEntries.
+mat = CoordinateMatrix(entries)
+
+# Get its size.
+m = mat.numRows()  # 3
+n = mat.numCols()  # 2
+
+# Get the entries as an RDD of MatrixEntries.
+entriesRDD = mat.entries
+
+# Convert to a RowMatrix.
+rowMat = mat.toRowMatrix()
+
+# Convert to an IndexedRowMatrix.
+indexedRowMat = mat.toIndexedRowMatrix()
+
+# Convert to a BlockMatrix.
+blockMat = mat.toBlockMatrix()
+{% endhighlight %}
+
    +
    ### BlockMatrix @@ -529,4 +676,39 @@ matA.validate(); BlockMatrix ata = matA.transpose().multiply(matA); {% endhighlight %}
    + +
    + +A [`BlockMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.BlockMatrix) +can be created from an `RDD` of sub-matrix blocks, where a sub-matrix block is a +`((blockRowIndex, blockColIndex), sub-matrix)` tuple. + +{% highlight python %} +from pyspark.mllib.linalg import Matrices +from pyspark.mllib.linalg.distributed import BlockMatrix + +# Create an RDD of sub-matrix blocks. +blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + +# Create a BlockMatrix from an RDD of sub-matrix blocks. +mat = BlockMatrix(blocks, 3, 2) + +# Get its size. +m = mat.numRows() # 6 +n = mat.numCols() # 2 + +# Get the blocks as an RDD of sub-matrix blocks. +blocksRDD = mat.blocks + +# Convert to a LocalMatrix. +localMat = mat.toLocalMatrix() + +# Convert to an IndexedRowMatrix. +indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() +{% endhighlight %} +
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 7521fb14a7bd..1e00b2083ed7 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBosotedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md new file mode 100644 index 000000000000..7066d5c97418 --- /dev/null +++ b/docs/mllib-evaluation-metrics.md @@ -0,0 +1,1497 @@ +--- +layout: global +title: Evaluation Metrics - MLlib +displayTitle: MLlib - Evaluation Metrics +--- + +* Table of contents +{:toc} + +Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance +of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +suite of metrics for the purpose of evaluating the performance of machine learning models. + +Specific machine learning algorithms fall under broader types of machine learning applications like classification, +regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +metrics that are currently available in Spark's MLlib are detailed in this section. + +## Classification model evaluation + +While there are many different types of classification algorithms, the evaluation of classification models all share +similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification), +there exists a true output and a model-generated predicted output for each data point. For this reason, the results for +each data point can be assigned to one of four categories: + +* True Positive (TP) - label is positive and prediction is also positive +* True Negative (TN) - label is negative and prediction is also negative +* False Positive (FP) - label is negative but prediction is positive +* False Negative (FN) - label is positive but prediction is negative + +These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering +classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric. The +reason for this is because a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from +a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier +that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like +[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into +account the *type* of error. 
In most applications there is some desired balance between precision and recall, which can
+be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score).
+
+### Binary classification
+
+[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given
+dataset into one of two possible groups (e.g. fraud or not fraud); this is a special case of multiclass classification.
+Most binary classification metrics can be generalized to multiclass classification metrics.
+
+#### Threshold tuning
+
+It is important to understand that many classification models actually output a "score" (often a probability) for
+each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for
+each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be some cases where
+the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a
+credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold*
+which determines what the predicted class will be based on the probabilities that the model outputs.
+
+Tuning the prediction threshold will change the precision and recall of the model and is an important part of model
+optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold, it is
+common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision,
+recall) points for different threshold values, while a
+[receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve
+plots (recall, false positive rate) points.
+
+**Available metrics**
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision (Positive Predictive Value)</td>
+      <td>$PPV=\frac{TP}{TP + FP}$</td>
+    </tr>
+    <tr>
+      <td>Recall (True Positive Rate)</td>
+      <td>$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$</td>
+    </tr>
+    <tr>
+      <td>F-measure</td>
+      <td>$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}{\beta^2 \cdot PPV + TPR}\right)$</td>
+    </tr>
+    <tr>
+      <td>Receiver Operating Characteristic (ROC)</td>
+      <td>$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$</td>
+    </tr>
+    <tr>
+      <td>Area Under ROC Curve</td>
+      <td>$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$</td>
+    </tr>
+    <tr>
+      <td>Area Under Precision-Recall Curve</td>
+      <td>$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$</td>
+    </tr>
+  </tbody>
+</table>
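+For instance, setting $\beta = 1$ in the F-measure above recovers the familiar F1 score, the harmonic
+mean of precision and recall. With illustrative values $PPV = 0.75$ and $TPR = 0.5$ (not taken from any
+dataset used below):
+
+$$F(1) = 2 \cdot \frac{PPV \cdot TPR}{PPV + TPR} = 2 \cdot \frac{0.375}{1.25} = 0.6$$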
    + + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the +data, and evaluate the performance of the algorithm by several binary evaluation metrics. + +
+
+{% highlight scala %}
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data in LIBSVM format
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
+
+// Split data into training (60%) and test (40%)
+val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+training.cache()
+
+// Run training algorithm to build the model
+val model = new LogisticRegressionWithLBFGS()
+  .setNumClasses(2)
+  .run(training)
+
+// Clear the prediction threshold so the model will return probabilities
+model.clearThreshold
+
+// Compute raw scores on the test set
+val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+  val prediction = model.predict(features)
+  (prediction, label)
+}
+
+// Instantiate metrics object
+val metrics = new BinaryClassificationMetrics(predictionAndLabels)
+
+// Precision by threshold
+val precision = metrics.precisionByThreshold
+precision.foreach { case (t, p) =>
+  println(s"Threshold: $t, Precision: $p")
+}
+
+// Recall by threshold
+val recall = metrics.recallByThreshold
+recall.foreach { case (t, r) =>
+  println(s"Threshold: $t, Recall: $r")
+}
+
+// Precision-Recall Curve
+val PRC = metrics.pr
+
+// F-measure
+val f1Score = metrics.fMeasureByThreshold
+f1Score.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 1")
+}
+
+val beta = 0.5
+val fScore = metrics.fMeasureByThreshold(beta)
+fScore.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 0.5")
+}
+
+// AUPRC
+val auPRC = metrics.areaUnderPR
+println("Area under precision-recall curve = " + auPRC)
+
+// Compute thresholds used in ROC and PR curves
+val thresholds = precision.map(_._1)
+
+// ROC Curve
+val roc = metrics.roc
+
+// AUROC
+val auROC = metrics.areaUnderROC
+println("Area under ROC = " + auROC)
+
+{% endhighlight %}
+
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class BinaryClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call (Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Several of the methods available in scala are currently missing from pyspark + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = BinaryClassificationMetrics(predictionAndLabels) + +# Area under precision-recall curve +print("Area under PR = %s" % metrics.areaUnderPR) + +# Area under ROC curve +print("Area under ROC = %s" % metrics.areaUnderROC) + +{% endhighlight %} + +
    +
+
+### Multiclass classification
+
+A [multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) describes a classification
+problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary
+classification problem). For example, classifying handwriting samples as one of the digits 0 to 9 gives 10 possible classes.
+
+For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still
+be positive or negative, but they must be considered in the context of a particular class. Each label and prediction
+take on the value of one of the multiple classes, and so they are said to be positive for their particular class and negative
+for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative
+occurs when neither the prediction nor the label take on the value of a given class. By this convention, there can be
+multiple true negatives for a given data sample. The extension of false negatives and false positives from the former
+definitions of positive and negative labels is straightforward.
+
+#### Label based metrics
+
+In contrast to binary classification, where there are only two possible labels, multiclass classification problems have many
+possible labels, and so the concept of label-based metrics is introduced. Overall precision measures precision across all
+labels - the number of times any class was predicted correctly (true positives) normalized by the number of data
+points. Precision by label considers only one class, and measures the number of times a specific label was predicted
+correctly normalized by the number of times that label appears in the output.
+
+**Available metrics**
+
+Define the class, or label, set as
+
+$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$
+
+The true output vector $\mathbf{y}$ consists of $N$ elements
+
+$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$
+
+A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements
+
+$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$
+
+For this section, a modified delta function $\hat{\delta}(x)$ will prove useful
+
+$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Confusion Matrix</td>
+      <td>
+        $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\
+          \left( \begin{array}{ccc}
+            \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots &
+            \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N) \\
+            \vdots & \ddots & \vdots \\
+            \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots &
+            \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N)
+          \end{array} \right)$
+      </td>
+    </tr>
+    <tr>
+      <td>Overall Precision</td>
+      <td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall Recall</td>
+      <td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall F1-measure</td>
+      <td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}{PPV + TPR}\right)$</td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell) = \frac{TP}{TP + FP} =
+        \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
+        {\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P} =
+        \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
+        {\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>F-measure by label</td>
+      <td>$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}
+        {\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Weighted precision</td>
+      <td>$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted recall</td>
+      <td>$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted F-measure</td>
+      <td>$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+  </tbody>
+</table>
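+To see how these definitions behave on a tiny case (made-up values, independent of the examples below),
+take $N = 3$ samples with true labels $\mathbf{y} = (0, 1, 2)$ and predictions $\hat{\mathbf{y}} = (0, 2, 2)$. Then
+
+$$PPV = \frac{1}{3}\sum_{i=0}^{2}\hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right) = \frac{2}{3},
+\qquad PPV(2) = \frac{1}{2}, \qquad TPR(2) = \frac{1}{1} = 1$$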
    + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on +the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. + +
    + +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new MulticlassMetrics(predictionAndLabels) + +// Confusion matrix +println("Confusion matrix:") +println(metrics.confusionMatrix) + +// Overall Statistics +val precision = metrics.precision +val recall = metrics.recall // same as true positive rate +val f1Score = metrics.fMeasure +println("Summary Statistics") +println(s"Precision = $precision") +println(s"Recall = $recall") +println(s"F1 Score = $f1Score") + +// Precision by label +val labels = metrics.labels +labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) +} + +// Recall by label +labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) +} + +// False positive rate by label +labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) +} + +// F-measure by label +labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) +} + +// Weighted stats +println(s"Weighted precision: ${metrics.weightedPrecision}") +println(s"Weighted recall: ${metrics.weightedRecall}") +println(s"Weighted F1 score: ${metrics.weightedFMeasure}") +println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MulticlassClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = MulticlassMetrics(predictionAndLabels) + +# Overall statistics +precision = metrics.precision() +recall = metrics.recall() +f1Score = metrics.fMeasure() +print("Summary Stats") +print("Precision = %s" % precision) +print("Recall = %s" % recall) +print("F1 Score = %s" % f1Score) + +# Statistics by class +labels = data.map(lambda lp: lp.label).distinct().collect() +for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + +# Weighted stats +print("Weighted recall = %s" % metrics.weightedRecall) +print("Weighted precision = %s" % metrics.weightedPrecision) +print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) +print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) +print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) +{% endhighlight %} + +
    +
    + +### Multilabel classification + +A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping +each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not +mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both +science and politics. + +Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather +than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to +operations on sets. For example, a true positive for a given class now occurs when that class exists in the predicted +set and it exists in the true label set, for a specific data point. + +**Available metrics** + +Here we define a set $D$ of $N$ documents + +$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$ + +Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$ +to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that +correspond to document $d_i$. + +The set of all unique labels is given by + +$$L = \bigcup_{k=0}^{N-1} L_k$$ + +The following definition of indicator function $I_A(x)$ on a set $A$ will be necessary + +$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Recall</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Accuracy</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|}
+        {\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$</td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell)=\frac{TP}{TP + FP}=
+        \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}{\sum_{i=0}^{N-1} I_{P_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P}=
+        \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}{\sum_{i=0}^{N-1} I_{L_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>F1-measure by label</td>
+      <td>$F1(\ell) = 2 \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}{PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Hamming Loss</td>
+      <td>$\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left|L_i\right| + \left|P_i\right| -
+        2\left|L_i \cap P_i\right|$</td>
+    </tr>
+    <tr>
+      <td>Subset Accuracy</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$</td>
+    </tr>
+    <tr>
+      <td>F1 Measure</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} 2 \frac{\left|P_i \cap L_i\right|}{\left|P_i\right| + \left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro precision</td>
+      <td>$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
+        {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro recall</td>
+      <td>$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
+        {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro F1 Measure</td>
+      <td>$2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot
+        \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1}
+        \left|P_i - L_i\right|}$</td>
+    </tr>
+  </tbody>
+</table>
+
+**Examples**
+
+The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples
+use the fake prediction and label data for multilabel classification that is shown below; a worked Hamming loss
+computation on this data follows the lists.
+
+Document predictions:
+
+* doc 0 - predict 0, 1 - class 0, 2
+* doc 1 - predict 0, 2 - class 0, 1
+* doc 2 - predict none - class 0
+* doc 3 - predict 2 - class 2
+* doc 4 - predict 2, 0 - class 2, 0
+* doc 5 - predict 0, 1, 2 - class 0, 1
+* doc 6 - predict 1 - class 1, 2
+
+Predicted classes:
+
+* class 0 - doc 0, 1, 4, 5 (total 4)
+* class 1 - doc 0, 5, 6 (total 3)
+* class 2 - doc 1, 3, 4, 5 (total 4)
+
+True classes:
+
+* class 0 - doc 0, 1, 2, 4, 5 (total 5)
+* class 1 - doc 1, 5, 6 (total 3)
+* class 2 - doc 0, 3, 4, 6 (total 4)
+
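+As a sanity check on the definitions (a hand computation on the data above, not output produced by MLlib),
+the Hamming loss for these 7 documents and 3 labels works out to
+
+$$HL = \frac{1}{7 \cdot 3}\sum_{i=0}^{6}\left(\left|L_i\right| + \left|P_i\right| - 2\left|L_i \cap P_i\right|\right)
+     = \frac{2 + 2 + 1 + 0 + 0 + 1 + 1}{21} = \frac{7}{21} \approx 0.33$$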
    + +
    + +{% highlight scala %} +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD; + +val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + +// Instantiate metrics object +val metrics = new MultilabelMetrics(scoreAndLabels) + +// Summary stats +println(s"Recall = ${metrics.recall}") +println(s"Precision = ${metrics.precision}") +println(s"F1 measure = ${metrics.f1Measure}") +println(s"Accuracy = ${metrics.accuracy}") + +// Individual label stats +metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) +metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) +metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + +// Micro stats +println(s"Micro recall = ${metrics.microRecall}") +println(s"Micro precision = ${metrics.microPrecision}") +println(s"Micro F1 measure = ${metrics.microF1Measure}") + +// Hamming loss +println(s"Hamming loss = ${metrics.hammingLoss}") + +// Subset accuracy +println(s"Subset accuracy = ${metrics.subsetAccuracy}") + +{% endhighlight %} + +
    + +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.evaluation.MultilabelMetrics;
+import org.apache.spark.SparkConf;
+import java.util.Arrays;
+import java.util.List;
+
+public class MultilabelClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    List<Tuple2<double[], double[]>> data = Arrays.asList(
+        new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
+        new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
+        new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
+        new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
+        new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
+        new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
+        new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
+    );
+    JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
+
+    // Instantiate metrics object
+    MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());
+
+    // Summary stats
+    System.out.format("Recall = %f\n", metrics.recall());
+    System.out.format("Precision = %f\n", metrics.precision());
+    System.out.format("F1 measure = %f\n", metrics.f1Measure());
+    System.out.format("Accuracy = %f\n", metrics.accuracy());
+
+    // Stats by labels
+    for (int i = 0; i < metrics.labels().length; i++) {
+      System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+      System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+      System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i]));
+    }
+
+    // Micro stats
+    System.out.format("Micro recall = %f\n", metrics.microRecall());
+    System.out.format("Micro precision = %f\n", metrics.microPrecision());
+    System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());
+
+    // Hamming loss
+    System.out.format("Hamming loss = %f\n", metrics.hammingLoss());
+
+    // Subset accuracy
+    System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
+
+  }
+}
+
+{% endhighlight %}
+
    + +
    + +{% highlight python %} +from pyspark.mllib.evaluation import MultilabelMetrics + +scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + +# Instantiate metrics object +metrics = MultilabelMetrics(scoreAndLabels) + +# Summary stats +print("Recall = %s" % metrics.recall()) +print("Precision = %s" % metrics.precision()) +print("F1 measure = %s" % metrics.f1Measure()) +print("Accuracy = %s" % metrics.accuracy) + +# Individual label stats +labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() +for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + +# Micro stats +print("Micro precision = %s" % metrics.microPrecision) +print("Micro recall = %s" % metrics.microRecall) +print("Micro F1 measure = %s" % metrics.microF1Measure) + +# Hamming loss +print("Hamming loss = %s" % metrics.hammingLoss) + +# Subset accuracy +print("Subset accuracy = %s" % metrics.subsetAccuracy) + +{% endhighlight %} + +
    +
+
+### Ranking systems
+
+The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system))
+is to return to the user a set of relevant items or documents based on some training data. The definition of relevance
+may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these
+rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth
+set of relevant documents, while other metrics may incorporate numerical ratings explicitly.
+
+**Available metrics**
+
+A ranking system usually deals with a set of $M$ users
+
+$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$
+
+Each user ($u_i$) has a set of $N$ ground truth relevant documents
+
+$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+and a list of $Q$ recommended documents, in order of decreasing relevance
+
+$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$
+
+The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the
+sets and the effectiveness of the algorithms can be measured using the metrics listed below.
+
+It is necessary to define a function which, provided a recommended document and a set of ground truth relevant
+documents, returns a relevance score for the recommended document.
+
+$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th><th>Notes</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision at k</td>
+      <td>$p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} rel_{D_i}(R_i(j))}$</td>
+      <td>Precision at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents averaged across all users. In this metric, the order of the recommendations is not taken into account.</td>
+    </tr>
+    <tr>
+      <td>Mean Average Precision</td>
+      <td>$MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$</td>
+      <td>MAP is a measure of how many of the recommended documents are in the set of true relevant documents, where the
+        order of the recommendations is taken into account (i.e. mistakes near the top of the list are penalized more heavily).</td>
+    </tr>
+    <tr>
+      <td>Normalized Discounted Cumulative Gain</td>
+      <td>$NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1}
+          \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\
+        \text{Where} \\
+        \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\
+        \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$</td>
+      <td>NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents averaged across all users. In contrast to precision at k, this metric takes into account the order
+        of the recommendations (documents are assumed to be in order of decreasing relevance).</td>
+    </tr>
+  </tbody>
+</table>
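+As a small worked illustration of the first two definitions (hypothetical data, not the MovieLens set used below),
+suppose a single user has relevant documents $D = \{a, b, c\}$ and receives the ranked recommendations $R = [a, d, b]$:
+
+$$p(2) = \frac{1}{2}\left(rel_D(a) + rel_D(d)\right) = \frac{1}{2}, \qquad
+MAP = \frac{1}{3}\left(\frac{1}{1} + \frac{0}{2} + \frac{1}{3}\right) = \frac{4}{9} \approx 0.44$$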
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation
+model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the
+methodology is provided below.
+
+MovieLens ratings are on a scale of 1-5:
+
+ * 5: Must see
+ * 4: Will enjoy
+ * 3: It's okay
+ * 2: Fairly bad
+ * 1: Awful
+
+So we should not recommend a movie if the predicted rating is less than 3.
+To map ratings to confidence scores, we use:
+
+ * 5 -> 2.5
+ * 4 -> 1.5
+ * 3 -> 0.5
+ * 2 -> -0.5
+ * 1 -> -1.5.
+
+This mapping means unobserved entries are generally between "It's okay" and "Fairly bad". The semantics of 0 in this
+expanded world of non-positive weights are "the same as never having interacted at all."
+
    + +
    + +{% highlight scala %} +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} + +// Read in the ratings data +val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) +}.cache() + +// Map ratings to 1 or 0, 1 indicating a movie that should be recommended +val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() + +// Summarize ratings +val numRatings = ratings.count() +val numUsers = ratings.map(_.user).distinct().count() +val numMovies = ratings.map(_.product).distinct().count() +println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + +// Build the model +val numIterations = 10 +val rank = 10 +val lambda = 0.01 +val model = ALS.train(ratings, rank, numIterations, lambda) + +// Define a function to scale ratings from 0 to 1 +def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) +} + +// Get sorted top ten predictions for each user and then scale from [0, 1] +val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => + (user, recs.map(scaledRating)) +} + +// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document +// Compare with top ten most relevant documents +val userMovies = binarizedRatings.groupBy(_.user) +val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) +} + +// Instantiate metrics object +val metrics = new RankingMetrics(relevantDocuments) + +// Precision at K +Array(1, 3, 5).foreach{ k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") +} + +// Mean average precision +println(s"Mean average precision = ${metrics.meanAveragePrecision}") + +// Normalized discounted cumulative gain +Array(1, 3, 5).foreach{ k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") +} + +// Get predictions for each data point +val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) +val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) +val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => + (predicted, actual) +} + +// Get the RMSE using regression metrics +val regressionMetrics = new RegressionMetrics(predictionsAndLabels) +println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${regressionMetrics.r2}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.Function; +import java.util.*; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.Rating; + +// Read in the ratings data +public class Ranking { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); + JavaSparkContext sc = new JavaSparkContext(conf); + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } + else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average precision = %f\n", 
metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r){ + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r){ + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + } +} + +{% endhighlight %} + +
    + +
+
+{% highlight python %}
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+
+# Read in the ratings data
+lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+def parseLine(line):
+    fields = line.split("::")
+    return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+ratings = lines.map(lambda r: parseLine(r))
+
+# Train a model to predict user-product ratings
+model = ALS.train(ratings, 10, 10, 0.01)
+
+# Get predicted ratings on all existing user-product pairs
+testData = ratings.map(lambda p: (p.user, p.product))
+predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+# Instantiate regression metrics to compare predicted and actual ratings
+metrics = RegressionMetrics(scoreAndLabels)
+
+# Root mean squared error
+print("RMSE = %s" % metrics.rootMeanSquaredError)
+
+# R-squared
+print("R-squared = %s" % metrics.r2)
+
+{% endhighlight %}
+
    +
    + +## Regression model evaluation + +[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output +variable from a number of independent variables. + +**Available metrics** + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Mean Squared Error (MSE)</td>
+      <td>$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$</td>
+    </tr>
+    <tr>
+      <td>Root Mean Squared Error (RMSE)</td>
+      <td>$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$</td>
+    </tr>
+    <tr>
+      <td>Mean Absolute Error (MAE)</td>
+      <td>$MAE = \frac{\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|}{N}$</td>
+    </tr>
+    <tr>
+      <td>Coefficient of Determination $(R^2)$</td>
+      <td>$R^2=1 - \frac{N \cdot MSE}{\text{VAR}(\mathbf{y}) \cdot (N-1)}=1-\frac{\sum_{i=0}^{N-1}
+        (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$</td>
+    </tr>
+    <tr>
+      <td>Explained Variance</td>
+      <td>$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$</td>
+    </tr>
+  </tbody>
+</table>
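+As a quick numeric check of the first three definitions (made-up values, unrelated to the dataset used below),
+take $\mathbf{y} = (3, -0.5, 2)$ and $\hat{\mathbf{y}} = (2.5, 0, 2)$:
+
+$$MSE = \frac{0.25 + 0.25 + 0}{3} \approx 0.167, \qquad
+RMSE = \sqrt{MSE} \approx 0.408, \qquad
+MAE = \frac{0.5 + 0.5 + 0}{3} \approx 0.333$$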
    + +**Examples** + +
    +The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data, +and evaluate the performance of the algorithm by several regression metrics. + +
    + +{% highlight scala %} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils + +// Load the data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + +// Build the model +val numIterations = 100 +val model = LinearRegressionWithSGD.train(data, numIterations) + +// Get predictions +val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) +} + +// Instantiate metrics object +val metrics = new RegressionMetrics(valuesAndPreds) + +// Squared error +println(s"MSE = ${metrics.meanSquaredError}") +println(s"RMSE = ${metrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${metrics.r2}") + +// Mean absolute error +println(s"MAE = ${metrics.meanAbsoluteError}") + +// Explained variance +println(s"Explained variance = ${metrics.explainedVariance}") + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; + +public class LinearRegression { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = + LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); + } +} + +{% endhighlight %} + +
    + +
    + +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector + +# Load and parse the data +def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) + +data = sc.textFile("data/mllib/sample_linear_regression_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LinearRegressionWithSGD.train(parsedData) + +# Get predictions +valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + +# Instantiate metrics object +metrics = RegressionMetrics(valuesAndPreds) + +# Squared Error +print("MSE = %s" % metrics.meanSquaredError) +print("RMSE = %s" % metrics.rootMeanSquaredError) + +# R-squared +print("R-squared = %s" % metrics.r2) + +# Mean absolute error +print("MAE = %s" % metrics.meanAbsoluteError) + +# Explained variance +print("Explained variance = %s" % metrics.explainedVariance) + +{% endhighlight %} + +
    +
    \ No newline at end of file diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 83e937635a55..7e417ed5f37a 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -221,7 +221,7 @@ model = word2vec.fit(inp) synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) {% endhighlight %}
    @@ -380,35 +380,43 @@ data2 = labels.zip(normalizer2.transform(features)) -## Feature selection -[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. +## ChiSqSelector -### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) tries to identify relevant +features for use in model construction. It reduces the size of the feature space, which can improve +both speed and statistical learning behavior. -#### Model Fitting +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements +Chi-Squared feature selection. It operates on labeled data with categorical features. +`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, +and then filters (selects) the top features which the class label depends on the most. +This is akin to yielding the features with the most predictive power. -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the -following parameters in the constructor: +The number of features to select can be tuned using a held-out validation set. -* `numTopFeatures` number of top features that the selector will select (filter). +### Model Fitting -We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in -`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then -return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that +the selector will select. -This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) -which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes +an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then +returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +The `ChiSqSelectorModel` can be applied either to a `Vector` to produce a reduced `Vector`, or to an `RDD[Vector]` to produce a reduced `RDD[Vector]`. Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). -#### Example +### Example -The following example shows the basic use of ChiSqSelector. +The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
    -
    +
    + +Refer to the [`ChiSqSelector` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) +for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors @@ -419,10 +427,11 @@ import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category val discretizedData = data.map { lp => - LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) } -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features val selector = new ChiSqSelector(50) // Create ChiSqSelector model (selecting features) val transformer = selector.fit(discretizedData) @@ -433,7 +442,11 @@ val filteredData = discretizedData.map { lp => {% endhighlight %}
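+As noted above, a `ChiSqSelectorModel` can also be constructed by hand from an array of selected
+feature indices. A minimal sketch (the indices below are arbitrary and purely illustrative):
+
+{% highlight scala %}
+import org.apache.spark.mllib.feature.ChiSqSelectorModel
+
+// Hypothetical feature indices, chosen only for illustration; they must be sorted ascending
+val manualTransformer = new ChiSqSelectorModel(Array(10, 20, 30))
+val manuallyFilteredData = discretizedData.map { lp =>
+  LabeledPoint(lp.label, manualTransformer.transform(lp.features))
+}
+{% endhighlight %}
+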
    -
    +
    + +Refer to the [`ChiSqSelector` Java docs](api/java/org/apache/spark/mllib/feature/ChiSqSelector.html) +for details on the API. + {% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; @@ -451,19 +464,20 @@ JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category JavaRDD discretizedData = points.map( new Function() { @Override public LabeledPoint call(LabeledPoint lp) { final double[] discretizedFeatures = new double[lp.features().size()]; for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = lp.features().apply(i) / 16; + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); } return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); } }); -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features ChiSqSelector selector = new ChiSqSelector(50); // Create ChiSqSelector model (selecting features) final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); @@ -484,7 +498,12 @@ sc.stop(); ## ElementwiseProduct -ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. +`ElementwiseProduct` multiplies each input vector by a provided "weight" vector, using element-wise +multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This +represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) +between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. +Qu8T948*1# +Denoting the `scalingVec` as "`w`," this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ @@ -504,7 +523,7 @@ v_N [`ElementwiseProduct`](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) has the following parameter in the constructor: -* `w`: the transforming vector. +* `scalingVec`: the transforming vector. `ElementwiseProduct` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the weighting on a `Vector` to produce a transformed `Vector` or on an `RDD[Vector]` to produce a transformed `RDD[Vector]`. diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index bcc066a18552..4d4f5cfdc564 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -41,16 +41,23 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters: [`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. -It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +It take a `RDD` of transactions, where each transaction is an `Array` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) -that stores the frequent itemsets with their frequencies. 
+that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. + {% highlight scala %} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.mllib.fpm.FPGrowth + +val data = sc.textFile("data/mllib/sample_fpgrowth.txt") -val transactions: RDD[Array[String]] = ... +val transactions: RDD[Array[String]] = data.map(s => s.trim.split(' ')) val fpg = new FPGrowth() .setMinSupport(0.2) @@ -60,6 +67,14 @@ val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) } + +val minConfidence = 0.8 +model.generateAssociationRules(minConfidence).collect().foreach { rule => + println( + rule.antecedent.mkString("[", ",", "]") + + " => " + rule.consequent .mkString("[", ",", "]") + + ", " + rule.confidence) +} {% endhighlight %}
    @@ -68,21 +83,38 @@ model.freqItemsets.collect().foreach { itemset => [`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the FP-growth algorithm. -It take an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +It take an `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) -that stores the frequent itemsets with their frequencies. +that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.base.Joiner; - import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.AssociationRules; import org.apache.spark.mllib.fpm.FPGrowth; import org.apache.spark.mllib.fpm.FPGrowthModel; -JavaRDD> transactions = ... +SparkConf conf = new SparkConf().setAppName("FP-growth Example"); +JavaSparkContext sc = new JavaSparkContext(conf); + +JavaRDD data = sc.textFile("data/mllib/sample_fpgrowth.txt"); + +JavaRDD> transactions = data.map( + new Function>() { + public List call(String line) { + String[] parts = line.split(" "); + return Arrays.asList(parts); + } + } +); FPGrowth fpg = new FPGrowth() .setMinSupport(0.2) @@ -90,9 +122,202 @@ FPGrowth fpg = new FPGrowth() FPGrowthModel model = fpg.run(transactions); for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { - System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); + System.out.println("[" + itemset.javaItems() + "], " + itemset.freq()); +} + +double minConfidence = 0.8; +for (AssociationRules.Rule rule + : model.generateAssociationRules(minConfidence).toJavaRDD().collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); } {% endhighlight %}
    + +
+ +
+[`FPGrowth`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) implements the
+FP-growth algorithm.
+It takes an `RDD` of transactions, where each transaction is a `List` of items of a generic type.
+Calling `FPGrowth.train` with transactions returns an
+[`FPGrowthModel`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowthModel)
+that stores the frequent itemsets with their frequencies.
+
+{% highlight python %}
+from pyspark.mllib.fpm import FPGrowth
+
+data = sc.textFile("data/mllib/sample_fpgrowth.txt")
+
+transactions = data.map(lambda line: line.strip().split(' '))
+
+model = FPGrowth.train(transactions, minSupport=0.2, numPartitions=10)
+
+result = model.freqItemsets().collect()
+for fi in result:
+    print(fi)
+{% endhighlight %}
+

    + +
    + +## Association Rules + +
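+The rules produced below have the form `antecedent => consequent`, where both sides are itemsets
+and the consequent contains a single item. As a short editorial sketch of the quantity reported in
+the examples, the confidence of a rule is `freq(antecedent ∪ consequent) / freq(antecedent)`, i.e.
+the fraction of transactions containing the antecedent that also contain the consequent; only
+rules meeting the configured `minConfidence` are generated.
+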
    +
    +[AssociationRules](api/scala/index.html#org.apache.spark.mllib.fpm.AssociationRules) +implements a parallel rule generation algorithm for constructing rules +that have a single item as the consequent. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.fpm.AssociationRules +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset + +val freqItemsets = sc.parallelize(Seq( + new FreqItemset(Array("a"), 15L), + new FreqItemset(Array("b"), 35L), + new FreqItemset(Array("a", "b"), 12L) +)); + +val ar = new AssociationRules() + .setMinConfidence(0.8) +val results = ar.run(freqItemsets) + +results.collect().foreach { rule => + println("[" + rule.antecedent.mkString(",") + + "=>" + + rule.consequent.mkString(",") + "]," + rule.confidence) +} +{% endhighlight %} + +
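+As a worked illustration of the frequencies above (an editorial note; the original example does
+not show its output): the candidate rule `{a} => {b}` has confidence 12 / 15 = 0.8 and meets the
+0.8 threshold, while `{b} => {a}` has confidence 12 / 35 ≈ 0.34 and is filtered out.
+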
    + +
+
+[AssociationRules](api/java/org/apache/spark/mllib/fpm/AssociationRules.html)
+implements a parallel rule generation algorithm for constructing rules
+that have a single item as the consequent.
+
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.AssociationRules;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+
+JavaRDD<FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
+  new FreqItemset<String>(new String[] {"a"}, 15L),
+  new FreqItemset<String>(new String[] {"b"}, 35L),
+  new FreqItemset<String>(new String[] {"a", "b"}, 12L)
+));
+
+AssociationRules arules = new AssociationRules()
+  .setMinConfidence(0.8);
+JavaRDD<AssociationRules.Rule<String>> results = arules.run(freqItemsets);
+
+for (AssociationRules.Rule<String> rule: results.collect()) {
+  System.out.println(
+    rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence());
+}
+{% endhighlight %}
+
+ +
    +
+ +
+## PrefixSpan
+
+PrefixSpan is a sequential pattern mining algorithm described in
+[Pei et al., Mining Sequential Patterns by Pattern-Growth: The
+PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer
+the reader to the referenced paper for a formal definition of the
+sequential pattern mining problem.
+
+MLlib's PrefixSpan implementation takes the following parameters:
+
+* `minSupport`: the minimum support required to be considered a frequent
+  sequential pattern.
+* `maxPatternLength`: the maximum length of a frequent sequential
+  pattern. Any frequent pattern exceeding this length will not be
+  included in the results.
+* `maxLocalProjDBSize`: the maximum number of items allowed in a
+  prefix-projected database before local iterative processing of the
+  projected database begins. This parameter should be tuned with respect
+  to the size of your executors.
+
+**Examples**
+
+The following example illustrates PrefixSpan running on the sequences
+(using the same notation as Pei et al):
+
+~~~
+  <(12)3>
+  <1(32)(12)>
+  <(12)5>
+  <6>
+~~~
+
+ +
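+Here `<(12)3>` denotes a sequence whose first element is the itemset {1, 2} and whose second
+element is the itemset {3}; this is why the Scala example below encodes it as
+`Array(Array(1, 2), Array(3))`. As an editorial aside (not part of the original example), with
+`minSupport = 0.5` and four input sequences, a pattern must occur in at least 0.5 × 4 = 2 of the
+sequences to be reported as frequent.
+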
    +
+ +
+[`PrefixSpan`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpan) implements the
+PrefixSpan algorithm.
+Calling `PrefixSpan.run` returns a
+[`PrefixSpanModel`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpanModel)
+that stores the frequent sequences with their frequencies.
+
+{% highlight scala %}
+import org.apache.spark.mllib.fpm.PrefixSpan
+
+val sequences = sc.parallelize(Seq(
+  Array(Array(1, 2), Array(3)),
+  Array(Array(1), Array(3, 2), Array(1, 2)),
+  Array(Array(1, 2), Array(5)),
+  Array(Array(6))
+  ), 2).cache()
+val prefixSpan = new PrefixSpan()
+  .setMinSupport(0.5)
+  .setMaxPatternLength(5)
+val model = prefixSpan.run(sequences)
+model.freqSequences.collect().foreach { freqSequence =>
+  println(
+    freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + ", " + freqSequence.freq)
+}
+{% endhighlight %}
+
+ +
    + +
+ +
+[`PrefixSpan`](api/java/org/apache/spark/mllib/fpm/PrefixSpan.html) implements the
+PrefixSpan algorithm.
+Calling `PrefixSpan.run` returns a
+[`PrefixSpanModel`](api/java/org/apache/spark/mllib/fpm/PrefixSpanModel.html)
+that stores the frequent sequences with their frequencies.
+
+{% highlight java %}
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.fpm.PrefixSpan;
+import org.apache.spark.mllib.fpm.PrefixSpanModel;
+
+JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+  Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
+  Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
+  Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
+  Arrays.asList(Arrays.asList(6))
+), 2);
+PrefixSpan prefixSpan = new PrefixSpan()
+  .setMinSupport(0.5)
+  .setMaxPatternLength(5);
+PrefixSpanModel<Integer> model = prefixSpan.run(sequences);
+for (PrefixSpan.FreqSequence<Integer> freqSeq: model.freqSequences().toJavaRDD().collect()) {
+  System.out.println(freqSeq.javaSequence() + ", " + freqSeq.freq());
+}
+{% endhighlight %}
+
+ +
    +
    + diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index d2d1cc93fe00..91e50ccfecec 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -5,37 +5,44 @@ displayTitle: Machine Learning Library (MLlib) Guide description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- -MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, -including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives. -Guides for individual algorithms are listed below. +MLlib is Spark's machine learning (ML) library. +Its goal is to make practical machine learning scalable and easy. +It consists of common learning algorithms and utilities, including classification, regression, +clustering, collaborative filtering, dimensionality reduction, as well as lower-level optimization +primitives and higher-level pipeline APIs. -The API is divided into 2 parts: +It divides into two packages: -* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. -* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. +* [`spark.mllib`](mllib-guide.html#data-types-algorithms-and-utilities) contains the original API + built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). +* [`spark.ml`](ml-guide.html) provides higher-level API + built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. -We list major functionality from both below, with links to detailed guides. +Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. +But we will keep supporting `spark.mllib` along with the development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and expect more features coming. +Developers should contribute new algorithms to `spark.ml` if they fit the ML pipeline concept well, +e.g., feature extractors and transformers. -# MLlib types, algorithms and utilities +We list major functionality from both below, with links to detailed guides. -This lists functionality included in `spark.mllib`, the main MLlib API. 
+# spark.mllib: data types, algorithms, and utilities * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) - * summary statistics - * correlations - * stratified sampling - * hypothesis testing - * random data generation + * [summary statistics](mllib-statistics.html#summary-statistics) + * [correlations](mllib-statistics.html#correlations) + * [stratified sampling](mllib-statistics.html#stratified-sampling) + * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) - * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [ensembles of trees (Random Forests and Gradient-Boosted Trees)](mllib-ensembles.html) * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) - * alternating least squares (ALS) + * [alternating least squares (ALS)](mllib-collaborative-filtering.html#collaborative-filtering) * [Clustering](mllib-clustering.html) * [k-means](mllib-clustering.html#k-means) * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) @@ -43,78 +50,76 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) - * singular value decomposition (SVD) - * principal component analysis (PCA) + * [singular value decomposition (SVD)](mllib-dimensionality-reduction.html#singular-value-decomposition-svd) + * [principal component analysis (PCA)](mllib-dimensionality-reduction.html#principal-component-analysis-pca) * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) - * FP-growth -* [Optimization (developer)](mllib-optimization.html) - * stochastic gradient descent - * limited-memory BFGS (L-BFGS) + * [FP-growth](mllib-frequent-pattern-mining.html#fp-growth) + * [association rules](mllib-frequent-pattern-mining.html#association-rules) + * [PrefixSpan](mllib-frequent-pattern-mining.html#prefix-span) +* [Evaluation metrics](mllib-evaluation-metrics.html) * [PMML model export](mllib-pmml-model-export.html) - -MLlib is under active development. -The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and the migration guide below will explain all changes between releases. +* [Optimization (developer)](mllib-optimization.html) + * [stochastic gradient descent](mllib-optimization.html#stochastic-gradient-descent-sgd) + * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) # spark.ml: high-level APIs for ML pipelines -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. - -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. +**[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major +concepts. 
It also contains sections on using algorithms within the Pipelines API, for example: -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. - -More detailed guides for `spark.ml` include: - -* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts -* [Feature transformers](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API -* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision trees for classification and regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) # Dependencies -MLlib uses the linear algebra package -[Breeze](http://www.scalanlp.org/), which depends on -[netlib-java](https://github.com/fommil/netlib-java) for optimised -numerical processing. If natives are not available at runtime, you -will see a warning message and a pure JVM implementation will be used -instead. +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised numerical processing. +If natives libraries[^1] are not available at runtime, you will see a warning message and a pure JVM +implementation will be used instead. -To learn more about the benefits and background of system optimised -natives, you may wish to watch Sam Halliday's ScalaX talk on -[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)). +Due to licensing issues with runtime proprietary binaries, we do not include `netlib-java`'s native +proxies by default. +To configure `netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with `-Pnetlib-lgpl`) as a dependency of your +project and read the [netlib-java](https://github.com/fommil/netlib-java) documentation for your +platform's additional installation instructions. -Due to licensing issues with runtime proprietary binaries, we do not -include `netlib-java`'s native proxies by default. To configure -`netlib-java` / Breeze to use system optimised binaries, include -`com.github.fommil.netlib:all:1.1.2` (or build Spark with -`-Pnetlib-lgpl`) as a dependency of your project and read the -[netlib-java](https://github.com/fommil/netlib-java) documentation for -your platform's additional installation instructions. +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) -version 1.4 or newer. +[^1]: To learn more about the benefits and background of system optimised natives, you may wish to + watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). ---- +# Migration guide -# Migration Guide +MLlib is under active development. 
+The APIs marked `Experimental`/`DeveloperApi` may change in future releases, +and the migration guide below will explain all changes between releases. + +## From 1.4 to 1.5 -For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). +In the `spark.mllib` package, there are no break API changes but several behavior changes: -## From 1.3 to 1.4 +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. -In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: +In the `spark.ml` package, there exists one break API change and one behavior change: -* Gradient-Boosted Trees - * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. - * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. -* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. -## Previous Spark Versions +## Previous Spark versions Earlier migration guides are archived [on this page](mllib-migration-guides.html). + +--- diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 5732bc4c7e79..6aa881f74918 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -160,4 +160,39 @@ model.save(sc.sc(), "myModelPath"); IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + +
+
+Data are read from a file where each line has the format `label,feature`,
+e.g. `4710.28,500.00`. The data are split into training and test sets.
+A model is created using the training set, and the mean squared error is
+calculated from the predicted and actual labels on the test set.
+
+{% highlight python %}
+import math
+from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel
+
+data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt")
+
+# Create label, feature, weight tuples from input data with weight set to default value 1.0.
+parsedData = data.map(lambda line: tuple([float(x) for x in line.split(',')]) + (1.0,))
+
+# Split data into training (60%) and test (40%) sets.
+training, test = parsedData.randomSplit([0.6, 0.4], 11)
+
+# Create isotonic regression model from training data.
+# The isotonic parameter defaults to True, so it is not passed explicitly here.
+model = IsotonicRegression.train(training)
+
+# Create tuples of predicted and real labels.
+predictionAndLabel = test.map(lambda p: (model.predict(p[1]), p[0]))
+
+# Calculate mean squared error between predicted and real labels.
+meanSquaredError = predictionAndLabel.map(lambda pl: math.pow((pl[0] - pl[1]), 2)).mean()
+print("Mean Squared Error = " + str(meanSquaredError))
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = IsotonicRegressionModel.load(sc, "myModelPath")
+{% endhighlight %}
+
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3dc8cc902fa7..e9b2d276cd38 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -10,7 +10,7 @@ displayTitle: MLlib - Linear Methods `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} @@ -18,10 +18,10 @@ displayTitle: MLlib - Linear Methods \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` @@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function `$f$` that depends on a variable vector -`$\wv$` (called `weights` in the code), which has `$d$` entries. +`$\wv$` (called `weights` in the code), which has `$d$` entries. Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} @@ -39,7 +39,7 @@ the objective function is of the form \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and -`$y_i\in\R$` are their corresponding labels, which we want to predict. +`$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. @@ -99,6 +99,9 @@ regularizers in MLlib: L1$\|\wv\|_1$$\mathrm{sign}(\wv)$ + + elastic net$\alpha \|\wv\|_1 + (1-\alpha)\frac{1}{2}\|\wv\|_2^2$$\alpha \mathrm{sign}(\wv) + (1-\alpha) \wv$ + @@ -107,7 +110,7 @@ of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. -It is not recommended to train models without any regularization, +[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization, especially when the number of training examples is small. ### Optimization @@ -499,9 +502,8 @@ Note that the Python API does not yet support multiclass classification and mode will in the future. 
{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -518,6 +520,10 @@ model = LogisticRegressionWithLBFGS.train(parsedData) labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %} @@ -527,7 +533,7 @@ print("Training Error = " + str(trainErr)) ### Linear least squares, Lasso, and ridge regression -Linear least squares is the most common formulation for regression problems. +Linear least squares is the most common formulation for regression problems. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the squared loss: `\[ @@ -535,8 +541,8 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` Various related regression methods are derived by using different types of regularization: -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is @@ -548,7 +554,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -610,7 +616,7 @@ public class LinearRegression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); JavaSparkContext sc = new JavaSparkContext(conf); - + // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); @@ -630,7 +636,7 @@ public class LinearRegression { // Building the model int numIterations = 100; - final LinearRegressionModel model = + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error @@ -661,15 +667,14 @@ public class LinearRegression {
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from numpy import array +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel # Load and parse the data def parsePoint(line): @@ -686,6 +691,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -698,8 +707,8 @@ a dependency. ###Streaming linear regression -When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -714,7 +723,7 @@ online to the first stream, and make predictions on the second stream.
    -First, we import the necessary classes for parsing our input data and creating the model. +First, we import the necessary classes for parsing our input data and creating the model. {% highlight scala %} @@ -726,7 +735,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) -for more info. For this example, we use labeled points in training and testing streams, +for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} @@ -746,7 +755,7 @@ val model = new StreamingLinearRegressionWithSGD() {% endhighlight %} -Now we register the streams for training and testing and start the job. +Now we register the streams for training and testing and start the job. Printing predictions alongside true labels lets us easily see the result. {% highlight scala %} @@ -756,14 +765,66 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! + +
    + +
    + +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create our model by initializing the weights to 0 + +{% highlight python %} +numFeatures = 3 +model = StreamingLinearRegressionWithSGD() +model.setInitialWeights([0.0, 0.0, 0.0]) +{% endhighlight %} + +Now we register the streams for training and testing and start the job. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() {% endhighlight %} We can now save text files with data to the training or testing folders. -Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. -As you feed more data to the training directory, the predictions +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions will get better!
    diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 8df68d81f3c7..774b85d1f773 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,25 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. + +In the `spark.ml` package, several major API changes occurred, including: + +* `Param` and other APIs for specifying parameters +* `uid` unique IDs for Pipeline components +* Reorganization of certain classes + +Since the `spark.ml` API was an alpha component in Spark 1.3, we do not list all changes here. +However, since 1.4 `spark.ml` is no longer an alpha component, we will provide details on any API +changes for future releases. + ## From 1.2 to 1.3 In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. @@ -23,6 +42,17 @@ In the `spark.mllib` package, there were several breaking changes. The first ch * In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. +* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). 
The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. + ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index bf6d124fd5d8..e73bd30f3a90 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -119,7 +119,7 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint @@ -140,6 +140,10 @@ model = NaiveBayes.train(training, 1.0) # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 887eae7f4f07..6acfc71d7b01 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -95,9 +95,9 @@ mat = ... # an RDD of Vectors # Compute column summary statistics. summary = Statistics.colStats(mat) -print summary.mean() -print summary.variance() -print summary.numNonzeros() +print(summary.mean()) +print(summary.variance()) +print(summary.numNonzeros()) {% endhighlight %} @@ -183,12 +183,12 @@ seriesY = ... # must have the same number of partitions and cardinality as serie # Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a # method is not specified, Pearson's method will be used by default. -print Statistics.corr(seriesX, seriesY, method="pearson") +print(Statistics.corr(seriesX, seriesY, method="pearson")) data = ... # an RDD of Vectors # calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. # If a method is not specified, Pearson's method will be used by default. -print Statistics.corr(data, method="pearson") +print(Statistics.corr(data, method="pearson")) {% endhighlight %} @@ -283,7 +283,7 @@ approxSample = data.sampleByKey(False, fractions); Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically significant, whether this result occurred by chance or not. MLlib currently supports Pearson's -chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine +chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. @@ -398,14 +398,14 @@ vec = Vectors.dense(...) # a vector composed of the frequencies of events # compute the goodness of fit. 
If a second vector to test against is not supplied as a parameter, # the test runs against a uniform distribution. goodnessOfFitTestResult = Statistics.chiSqTest(vec) -print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, - # test statistic, the method used, and the null hypothesis. +print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. mat = Matrices.dense(...) # a contingency matrix # conduct Pearson's independence test on the input contingency matrix independenceTestResult = Statistics.chiSqTest(mat) -print independenceTestResult # summary of the test including the p-value, degrees of freedom... +print(independenceTestResult) # summary of the test including the p-value, degrees of freedom... obs = sc.parallelize(...) # LabeledPoint(feature, label) . @@ -415,13 +415,91 @@ obs = sc.parallelize(...) # LabeledPoint(feature, label) . featureTestResults = Statistics.chiSqTest(obs) for i, result in enumerate(featureTestResults): - print "Column $d:" % (i + 1) - print result + print("Column $d:" % (i + 1)) + print(result) {% endhighlight %} +Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +for equality of probability distributions. By providing the name of a theoretical distribution +(currently solely supported for the normal distribution) and its parameters, or a function to +calculate the cumulative distribution according to a given theoretical distribution, the user can +test the null hypothesis that their sample is drawn from that distribution. In the case that the +user tests against the normal distribution (`distName="norm"`), but does not provide distribution +parameters, the test initializes to the standard normal distribution and logs an appropriate +message. + +
    +
    +[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight scala %} +import org.apache.spark.mllib.stat.Statistics + +val data: RDD[Double] = ... // an RDD of sample data + +// run a KS test for the sample versus a standard normal distribution +val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) +println(testResult) // summary of the test including the p-value, test statistic, + // and null hypothesis + // if our p-value indicates significance, we can reject the null hypothesis + +// perform a KS test using a cumulative distribution function of our making +val myCDF: Double => Double = ... +val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) +{% endhighlight %} +
    + +
    +[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.stat.Statistics; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; + +JavaSparkContext jsc = ... +JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, ...)); +KolmogorovSmirnovTestResult testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0); +// summary of the test including the p-value, test statistic, +// and null hypothesis +// if our p-value indicates significance, we can reject the null hypothesis +System.out.println(testResult); +{% endhighlight %} +
    + +
    +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight python %} +from pyspark.mllib.stat import Statistics + +parallelData = sc.parallelize([1.0, 2.0, ... ]) + +# run a KS test for the sample versus a standard normal distribution +testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0, 1) +print(testResult) # summary of the test including the p-value, test statistic, + # and null hypothesis + # if our p-value indicates significance, we can reject the null hypothesis +# Note that the Scala functionality of calling Statistics.kolmogorovSmirnovTest with +# a lambda to calculate the CDF is not made available in the Python API +{% endhighlight %} +
    +
+ + ## Random data generation  Random data generation is useful for randomized algorithms, prototyping, and performance testing. @@ -493,5 +571,82 @@ u = RandomRDDs.uniformRDD(sc, 1000000L, 10) v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %} + + +## Kernel density estimation + +[Kernel density estimation](https://en.wikipedia.org/wiki/Kernel_density_estimation) is a technique +useful for visualizing empirical probability distributions without requiring assumptions about the +particular distribution that the observed samples are drawn from. It computes an estimate of the +probability density function of a random variable, evaluated at a given set of points. It achieves +this estimate by expressing the PDF of the empirical distribution at a particular point as the +mean of PDFs of normal distributions centered around each of the samples. + +
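+In symbols, for samples $x_1, \ldots, x_n$ and bandwidth $h$ (the standard deviation of the
+Gaussian kernels), the estimate at a point $x$ is the following (a sketch of the standard
+Gaussian kernel density estimate; the notation here is editorial, not MLlib's):
+
+`\[
+\hat{f}(x) = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{\sqrt{2 \pi} h} \exp\left( -\frac{(x - x_i)^2}{2 h^2} \right)
+\]`
+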
    + +
    +[`KernelDensity`](api/scala/index.html#org.apache.spark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight scala %} +import org.apache.spark.mllib.stat.KernelDensity +import org.apache.spark.rdd.RDD + +val data: RDD[Double] = ... // an RDD of sample data + +// Construct the density estimator with the sample data and a standard deviation for the Gaussian +// kernels +val kd = new KernelDensity() + .setSample(data) + .setBandwidth(3.0) + +// Find density estimates for the given values +val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) +{% endhighlight %} +
    + +
+
+[`KernelDensity`](api/java/org/apache/spark/mllib/stat/KernelDensity.html) provides methods
+to compute kernel density estimates from an RDD of samples. The following example demonstrates how
+to do so.
+
+{% highlight java %}
+import org.apache.spark.mllib.stat.KernelDensity;
+import org.apache.spark.rdd.RDD;
+
+RDD<Double> data = ... // an RDD of sample data
+
+// Construct the density estimator with the sample data and a standard deviation for the Gaussian
+// kernels
+KernelDensity kd = new KernelDensity()
+  .setSample(data)
+  .setBandwidth(3.0);
+
+// Find density estimates for the given values
+double[] densities = kd.estimate(new double[] {-1.0, 2.0, 5.0});
+{% endhighlight %}
+
    + +
    +[`KernelDensity`](api/python/pyspark.mllib.html#pyspark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight python %} +from pyspark.mllib.stat import KernelDensity + +data = ... # an RDD of sample data + +# Construct the density estimator with the sample data and a standard deviation for the Gaussian +# kernels +kd = KernelDensity() +kd.setSample(data) +kd.setBandwidth(3.0) + +# Find density estimates for the given values +densities = kd.estimate([-1.0, 2.0, 5.0]) +{% endhighlight %} +
    diff --git a/docs/monitoring.md b/docs/monitoring.md index bcf885fe4e68..cedceb295802 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -48,7 +48,7 @@ follows: Environment VariableMeaning SPARK_DAEMON_MEMORY - Memory to allocate to the history server (default: 512m). + Memory to allocate to the history server (default: 1g). SPARK_DAEMON_JAVA_OPTS diff --git a/docs/programming-guide.md b/docs/programming-guide.md index ae712d62746f..4cf83bb39263 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -85,8 +85,8 @@ import org.apache.spark.SparkConf
-Spark {{site.SPARK_VERSION}} works with Python 2.6 or higher (but not Python 3). It uses the standard CPython interpreter, -so C libraries like NumPy can be used. +Spark {{site.SPARK_VERSION}} works with Python 2.6+ or Python 3.4+. It can use the standard CPython interpreter, +so C libraries like NumPy can be used. It also works with PyPy 2.3+. To run Spark applications in Python, use the `bin/spark-submit` script located in the Spark directory. This script will load Spark's Java/Scala libraries and allow you to submit applications to a cluster. @@ -104,6 +104,14 @@ Finally, you need to import some Spark classes into your program. Add the follow from pyspark import SparkContext, SparkConf {% endhighlight %} +PySpark requires the same minor version of Python in both driver and workers. It uses the default Python version found on the `PATH`; +you can specify which version of Python you want to use by setting `PYSPARK_PYTHON`, for example: + +{% highlight bash %} +$ PYSPARK_PYTHON=python3.4 bin/pyspark +$ PYSPARK_PYTHON=/opt/pypy-2.5/bin/pypy bin/spark-submit examples/src/main/python/pi.py +{% endhighlight %} +
    @@ -541,7 +549,7 @@ returning only its answer to the driver program. If we also wanted to use `lineLengths` again later, we could add: {% highlight java %} -lineLengths.persist(); +lineLengths.persist(StorageLevel.MEMORY_ONLY()); {% endhighlight %} before the `reduce`, which would cause `lineLengths` to be saved in memory after the first time it is computed. diff --git a/docs/quick-start.md b/docs/quick-start.md index bb39e4111f24..d481fe0ea6d7 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -126,7 +126,7 @@ scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (w wordCounts: spark.RDD[(String, Int)] = spark.ShuffledAggregatedRDD@71f027b8 {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight scala %} scala> wordCounts.collect() @@ -163,7 +163,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word, 1)).reduceByKey(lambda a, b: a+b) {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight python %} >>> wordCounts.collect() @@ -217,13 +217,13 @@ a cluster, as described in the [programming guide](programming-guide.html#initia # Self-Contained Applications -Now say we wanted to write a self-contained application using the Spark API. We will walk through a -simple application in both Scala (with SBT), Java (with Maven), and Python. +Suppose we wish to write a self-contained application using the Spark API. We will walk through a +simple application in Scala (with sbt), Java (with Maven), and Python.
    -We'll create a very simple Spark application in Scala. So simple, in fact, that it's +We'll create a very simple Spark application in Scala--so simple, in fact, that it's named `SimpleApp.scala`: {% highlight scala %} @@ -259,7 +259,7 @@ object which contains information about our application. Our application depends on the Spark API, so we'll also include an sbt configuration file, -`simple.sbt` which explains that Spark is a dependency. This file also adds a repository that +`simple.sbt`, which explains that Spark is a dependency. This file also adds a repository that Spark depends on: {% highlight scala %} @@ -302,7 +302,7 @@ Lines with a: 46, Lines with b: 23
-This example will use Maven to compile an application jar, but any similar build system will work. +This example will use Maven to compile an application JAR, but any similar build system will work. We'll create a very simple Spark application, `SimpleApp.java`: @@ -374,7 +374,7 @@ $ find . Now, we can package the application using Maven and execute it with `./bin/spark-submit`. {% highlight bash %} -# Package a jar containing your application +# Package a JAR containing your application $ mvn package ... [INFO] Building jar: {..}/{..}/target/simple-project-1.0.jar @@ -406,7 +406,7 @@ logData = sc.textFile(logFile).cache() numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() -print "Lines with a: %i, lines with b: %i" % (numAs, numBs) +print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) {% endhighlight %} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5f1d6daeb27f..247e6ecfbdb8 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -45,7 +45,7 @@ frameworks. You can install Mesos either from source or using prebuilt packages To install Apache Mesos from source, follow these steps: 1. Download a Mesos release from a - [mirror](http://www.apache.org/dyn/closer.cgi/mesos/{{site.MESOS_VERSION}}/) + [mirror](http://www.apache.org/dyn/closer.lua/mesos/{{site.MESOS_VERSION}}/) 2. Follow the Mesos [Getting Started](http://mesos.apache.org/gettingstarted) page for compiling and installing Mesos @@ -157,6 +157,8 @@ From the client, you can submit a job to Mesos cluster by running `spark-submit` to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the Spark cluster Web UI. +Note that JARs or Python files that are passed to `spark-submit` should be URIs reachable by Mesos slaves. + # Mesos Run Modes Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained". @@ -184,6 +186,14 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +You may also make use of `spark.mesos.constraints` to set attribute-based constraints on Mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, if `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then each resource offer will be checked to see whether it meets both of these constraints, and only then will it be accepted to start new executors. + # Mesos Docker Support Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` @@ -208,6 +218,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop). In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos. +# Dynamic Resource Allocation with Mesos + +Mesos supports dynamic allocation only with coarse-grained mode, which can resize the number of executors based on statistics +of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse-grained scheduler only supports scaling down +since it is already designed to run one executor per slave with the configured amount of resources.
However, after scaling down the number of executors, the coarse-grained scheduler +can scale back up to the same number of executors when Spark signals that more executors are needed. + +Users who want to use this feature should launch the Mesos Shuffle Service, which +provides shuffle data cleanup functionality on top of the Shuffle Service, since Mesos doesn't yet provide notification of another framework's +termination. To launch or stop the Mesos Shuffle Service, use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh +scripts. + +The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos +is to launch the Shuffle Service via Marathon with a unique host constraint. # Configuration @@ -298,6 +322,50 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.uris + (none) + + A list of URIs to be downloaded to the sandbox when a driver or executor is launched by Mesos. + This applies to both coarse-grained and fine-grained mode. + + + + spark.mesos.principal + Framework principal to authenticate to Mesos + + Set the principal that the Spark framework will use to authenticate with Mesos. + + + + spark.mesos.secret + Framework secret to authenticate to Mesos + + Set the secret that the Spark framework will use to authenticate with Mesos. + + + + spark.mesos.role + Role for the Spark framework + + Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations + and resource weight sharing. + + + + spark.mesos.constraints + Attribute-based constraints to be matched against when accepting resource offers. + + Attribute-based constraints on Mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
      +
• Scalar constraints are matched with "less than or equal" semantics, i.e., the value in the constraint must be less than or equal to the value in the resource offer.
    • +
• Range constraints are matched with "contains" semantics, i.e., the value in the constraint must be within the resource offer's value.
    • +
• Set constraints are matched with "subset of" semantics, i.e., the value in the constraint must be a subset of the resource offer's value.
    • +
• Text constraints are matched with "equality" semantics, i.e., the value in the constraint must be exactly equal to the resource offer's value.
    • +
• If there is no value present as a part of the constraint, any offer with the corresponding attribute will be accepted (without a value check). A simplified sketch of this matching logic is shown below.
    • +
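+The following minimal Scala sketch is illustrative only (it is not Spark's scheduler code): it shows how a constraint
+string such as `tachyon=true;us-east-1=false` could be parsed and checked against an offer's text attributes using the
+"equality" and attribute-presence semantics described above; scalar, range, and set attributes would follow the same
+pattern with their respective comparisons.
+
+{% highlight scala %}
+// Parse "attr=value;attr2=value2" into a map; a missing value means "the attribute only needs to be present".
+def parseConstraints(spec: String): Map[String, Option[String]] =
+  spec.split(";").filter(_.nonEmpty).map { kv =>
+    kv.split("=", 2) match {
+      case Array(k, v) => k -> Some(v)
+      case Array(k) => k -> None
+    }
+  }.toMap
+
+// Check an offer's text attributes against the parsed constraints ("equality" semantics).
+def offerMatches(offerAttributes: Map[String, String],
+                 constraints: Map[String, Option[String]]): Boolean =
+  constraints.forall { case (name, required) =>
+    offerAttributes.get(name).exists(value => required.forall(_ == value))
+  }
+
+// An offer with attributes tachyon=true and us-east-1=false satisfies the example constraint.
+offerMatches(Map("tachyon" -> "true", "us-east-1" -> "false"),
+             parseConstraints("tachyon=true;us-east-1=false")) // true
+{% endhighlight %}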
    + + # Troubleshooting and Debugging diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 96cf612c54fd..d1244323edff 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -7,16 +7,93 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Launching Spark on YARN + +Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to write to HDFS and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). + +There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. + +Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.html) modes, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. + +To launch a Spark application in `yarn-cluster` mode: + + $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. + +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. The following shows how you can run `spark-shell` in `yarn-client` mode: + + $ ./bin/spark-shell --master yarn-client + +## Adding Other JARs + +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. + + $ ./bin/spark-submit --class my.main.Class \ + --master yarn-cluster \ + --jars my-other-jar.jar,my-other-other-jar.jar + my-main-jar.jar + app_arg1 app_arg2 + + # Preparations -Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the Spark project website. 
+Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. +Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration Most of the configs are the same for Spark on YARN as for other deployment modes. See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. +# Debugging your Application + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId <app ID> + +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. + +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and do not require running the MapReduce history server. + +To review the per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a +large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +on the nodes on which containers are launched. This directory contains the launch script, JARs, and +all environment variables used for launching each container. This process is useful for debugging +classpath problems in particular. (Note that enabling this requires admin privileges on cluster +settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). + +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files + to be uploaded with the application. +- add `-Dlog4j.configuration=<location of configuration file>` to `spark.driver.extraJavaOptions` + (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, + the `file:` protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g.
trying to write +to the same log file). + +If you need a reference to the proper location to put log files in YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting the file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. + #### Spark Properties @@ -50,8 +127,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -122,7 +199,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -189,8 +266,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -206,7 +283,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -242,6 +319,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes running against earlier versions, this property will be ignored. + + + + + @@ -258,85 +343,46 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Principal to be used to login to KDC, while running on secure HDFS. + + + + + + + + + + + + + + +
    spark.yarn.am.waitTime 100s - In yarn-cluster mode, time for the application master to wait for the - SparkContext to be initialized. In yarn-client mode, time for the application master to wait + In `yarn-cluster` mode, time for the application master to wait for the + SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait for the driver to connect to it.
    spark.executor.instances 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used.
    Add the environment variable specified by EnvironmentVariableName to the Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In yarn-cluster mode this controls - the environment of the SPARK driver and in yarn-client mode it only controls + these and to set multiple environment variables. In `yarn-cluster` mode this controls + the environment of the SPARK driver and in `yarn-client` mode it only controls the environment of the executor launcher.
    (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use `spark.driver.extraJavaOptions` instead.
    spark.yarn.tags(none) + Comma-separated list of strings to pass through as YARN application tags appearing + in YARN ApplicationReports, which can be used for filtering when querying YARN apps. +
    spark.yarn.keytab (none)
    spark.yarn.config.gatewayPath(none) + A path that is valid on the gateway host (the host where a Spark application is started) but may + differ for paths for the same resource in other nodes in the cluster. Coupled with + spark.yarn.config.replacementPath, this is used to support clusters with + heterogeneous configurations, so that Spark can correctly launch remote processes. +

    + The replacement path normally will contain a reference to some environment variable exported by + YARN (and, thus, visible to Spark containers). +

    + For example, if the gateway node has Hadoop libraries installed on /disk1/hadoop, and + the location of the Hadoop install is exported by YARN as the HADOOP_HOME + environment variable, setting this value to /disk1/hadoop and the replacement path to + $HADOOP_HOME will make sure that paths used to launch remote processes properly + reference the local YARN configuration. +

    spark.yarn.config.replacementPath(none) + See spark.yarn.config.gatewayPath. +
    spark.yarn.security.tokens.${service}.enabledtrue + Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. + By default, delegation tokens for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run. +

    + Currently supported services are: hive, hbase +

    -# Launching Spark on YARN - -Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. The -configuration contained in this directory will be distributed to the YARN cluster so that all -containers used by the application use the same configuration. If the configuration references -Java system properties or environment variables not managed by YARN, they should also be set in the -Spark application's configuration (driver, executors, and the AM when running in client mode). - -There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. - -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". - -To launch a Spark application in yarn-cluster mode: - - ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - -For example: - - $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ - --driver-memory 4g \ - --executor-memory 2g \ - --executor-cores 1 \ - --queue thequeue \ - lib/spark-examples*.jar \ - 10 - -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. - -To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: - - $ ./bin/spark-shell --master yarn-client - -## Adding Other JARs - -In yarn-cluster mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. - - $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar - app_arg1 app_arg2 - -# Debugging your Application - -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. - - yarn logs -applicationId - -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. 
The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). - -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. - -To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` -on the nodes on which containers are launched. This directory contains the launch script, JARs, and -all environment variables used for launching each container. This process is useful for debugging -classpath problems in particular. (Note that enabling this requires admin privileges on cluster -settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). - -To use a custom log4j configuration for the application master or executors, there are two options: - -- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files - to be uploaded with the application. -- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` - (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, - the `file:` protocol should be explicitly provided, and the file needs to exist locally on all - the nodes. - -Note that for the first option, both executors and the application master will share the same -log4j configuration, which may cause issues when they run on the same node (e.g. trying to write -to the same log file). - -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. - # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 4f71fbc086cd..2fe9ec3542b2 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -152,7 +152,7 @@ You can optionally configure the cluster further by setting environment variable SPARK_DAEMON_MEMORY - Memory to allocate to the Spark master and worker daemons themselves (default: 512m). + Memory to allocate to the Spark master and worker daemons themselves (default: 1g). SPARK_DAEMON_JAVA_OPTS diff --git a/docs/sparkr.md b/docs/sparkr.md index 4d82129921a3..7139d16b4a06 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -11,7 +11,8 @@ title: SparkR (R on Spark) SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that supports operations like selection, filtering, aggregation etc. 
(similar to R data frames, -[dplyr](https://github.com/hadley/dplyr)) but on large datasets. +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed +machine learning using MLlib. # SparkR DataFrames @@ -27,9 +28,9 @@ All of the examples on this page use sample data included in R or the Spark dist
The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name -etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the -SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should -already be created for you. +, any Spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, +which can be created from the SparkContext. If you are working from the SparkR shell, the +`SQLContext` and `SparkContext` should already be created for you. {% highlight r %} sc <- sparkR.init() @@ -62,7 +63,16 @@ head(df) SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. -The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by +specifying `--packages` with the `spark-submit` or `sparkR` commands, or, if creating the context through `init`, +you can specify the packages with the `sparkPackages` argument. + +
    +{% highlight r %} +sc <- sparkR.init(sparkPackages="com.databricks:spark-csv_2.11:1.0.3") +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
    We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -107,7 +117,7 @@ sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results <- hiveContext.sql("FROM src SELECT key, value") +results <- sql(hiveContext, "FROM src SELECT key, value") # results is now a DataFrame head(results) @@ -221,3 +231,37 @@ head(teenagers) {% endhighlight %}
+ +# Machine Learning + +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently, the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows how to build a gaussian GLM model using SparkR. + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) + +# Fit a linear model over the dataset. +model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) 2.2513930 +##Sepal_Width 0.8035609 +##Species_versicolor 1.4587432 +##Species_virginica 1.9468169 + +# Make predictions based on the model. +predictions <- predict(model, newData = df) +head(select(predictions, "Sepal_Length", "prediction")) +## Sepal_Length prediction +##1 5.1 5.063856 +##2 4.9 4.662076 +##3 4.7 4.822788 +##4 4.6 4.742432 +##5 5.0 5.144212 +##6 5.4 5.385281 +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 26c036f6648d..a0b911d20724 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,7 +11,7 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. -For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. # DataFrames @@ -22,13 +22,13 @@ The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark. All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. -## Starting Point: `SQLContext` +## Starting Point: SQLContext
    The entry point into all functionality in Spark SQL is the -[`SQLContext`](api/scala/index.html#org.apache.spark.sql.`SQLContext`) class, or one of its +[`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight scala %} @@ -213,6 +213,11 @@ df.groupBy("age").count().show() // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.DataFrame). + +
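+As a small, hedged illustration of such functions (assuming the DataFrame `df` with `name` and `age` columns from the example above):
+
+{% highlight scala %}
+import org.apache.spark.sql.functions._
+
+// Upper-case every name and round each age down to the nearest multiple of ten
+df.select(upper(col("name")), floor(col("age") / 10) * 10).show()
+{% endhighlight %}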
    @@ -263,6 +268,10 @@ df.groupBy("age").count().show(); // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +
    @@ -320,6 +329,10 @@ df.groupBy("age").count().show() {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +
    @@ -370,10 +383,13 @@ showDF(count(groupBy(df, "age"))) {% endhighlight %} -
    +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html).
    +
    ## Running SQL Queries Programmatically @@ -570,7 +586,7 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 1 # The results of SQL queries are RDDs and support all the normal RDD operations. teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %}
@@ -752,7 +768,7 @@ results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. names = results.map(lambda p: "Name: " + p.name) for name in names.collect(): - print name + print(name) {% endhighlight %} @@ -828,7 +844,7 @@ using this syntax. {% highlight scala %} val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") -df.select("name", "age").write.format("json").save("namesAndAges.json") +df.select("name", "age").write.format("parquet").save("namesAndAges.parquet") {% endhighlight %} @@ -870,12 +886,11 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Thus, it is not safe to have multiple writers attempting to write to the same location. -Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +atomic. Additionally, when performing an `Overwrite`, the data will be deleted before writing out the new data. - + @@ -1006,7 +1021,7 @@ parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %} @@ -1036,6 +1051,15 @@ for (teenName in collect(teenNames)) { +
    + +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.sql("REFRESH TABLE my_table") +{% endhighlight %} + +
    +
    {% highlight sql %} @@ -1054,7 +1078,7 @@ SELECT * FROM parquetTable
    -### Partition discovery +### Partition Discovery Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in @@ -1108,13 +1132,20 @@ can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, w `true`. When type inference is disabled, string type will be used for the partitioning columns. -### Schema merging +### Schema Merging Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with a simple schema, and gradually add more columns to the schema as needed. In this way, users may end up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. +Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we +turned it off by default starting from 1.5.0. You may enable it by + +1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the + examples below), or +2. setting the global SQL option `spark.sql.parquet.mergeSchema` to `true`. +
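+For the second option, a minimal sketch (assuming an existing `sqlContext`) would be:
+
+{% highlight scala %}
+// Enable Parquet schema merging globally instead of passing the option on every read
+sqlContext.setConf("spark.sql.parquet.mergeSchema", "true")
+{% endhighlight %}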
    @@ -1134,7 +1165,7 @@ val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.read.parquet("data/test_table") +val df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1156,16 +1187,16 @@ df3.printSchema() # Create a simple DataFrame, stored into a partition directory df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ .map(lambda i: Row(single=i, double=i * 2))) -df1.save("data/test_table/key=1", "parquet") +df1.write.parquet("data/test_table/key=1") # Create another DataFrame in a new partition directory, # adding a new column and dropping an existing column df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) .map(lambda i: Row(single=i, triple=i * 3))) -df2.save("data/test_table/key=2", "parquet") +df2.write.parquet("data/test_table/key=2") # Read the partitioned table -df3 = sqlContext.load("data/test_table", "parquet") +df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together @@ -1192,7 +1223,7 @@ saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") # Read the partitioned table -df3 <- loadDF(sqlContext, "data/test_table", "parquet") +df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema="true") printSchema(df3) # The final schema consists of all 3 columns in the Parquet files together @@ -1208,6 +1239,79 @@ printSchema(df3)
+### Hive metastore Parquet table conversion + +When reading from and writing to Hive metastore Parquet tables, Spark SQL will try to use its own +Parquet support instead of Hive SerDe for better performance. This behavior is controlled by the +`spark.sql.hive.convertMetastoreParquet` configuration, and is turned on by default. + +#### Hive/Parquet Schema Reconciliation + +There are two key differences between Hive and Parquet from the perspective of table schema +processing. + +1. Hive is case insensitive, while Parquet is not +1. Hive considers all columns nullable, while nullability in Parquet is significant + +For this reason, we must reconcile the Hive metastore schema with the Parquet schema when converting a +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: + +1. Fields that have the same name in both schemas must have the same data type regardless of + nullability. The reconciled field should have the data type of the Parquet side, so that + nullability is respected. + +1. The reconciled schema contains exactly those fields defined in the Hive metastore schema. + + - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. + - Any fields that only appear in the Hive metastore schema are added as nullable fields in the + reconciled schema. + +#### Metadata Refreshing + +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, the metadata of those converted tables is also cached. If these tables are +updated by Hive or other external tools, you need to refresh them manually to ensure consistent +metadata. + +
    + +
    + +{% highlight scala %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
+ +{% highlight java %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table"); +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight sql %} +REFRESH TABLE my_table; +{% endhighlight %} + +
    + +
    + ### Configuration Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running @@ -1219,7 +1323,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
    @@ -1228,8 +1332,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` @@ -1250,13 +1353,8 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` - - + + @@ -1266,6 +1364,47 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support. + + + + + + + + + +
    Scala/JavaPythonMeaning
    Scala/JavaAny LanguageMeaning
    SaveMode.ErrorIfExists (default) "error" (default)spark.sql.parquet.binaryAsString false - Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do + Some other Parquet-producing systems, in particular Impala, Hive, and older versions of Spark SQL, do not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also - store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems.
    spark.sql.parquet.filterPushdownfalse - Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Parquet 1.6.0rc3 (PARQUET-136). - However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn - this feature on. - trueEnables Parquet filter push-down optimization when set to true.
    spark.sql.hive.convertMetastoreParquet
    spark.sql.parquet.output.committer.classorg.apache.parquet.hadoop.
    ParquetOutputCommitter
    +

    + The output committer class used by Parquet. The specified class needs to be a subclass of + org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a + subclass of org.apache.parquet.hadoop.ParquetOutputCommitter. +

    +

    + Note: +

      +
    • + This option is automatically ignored if spark.speculation is turned on. +
    • +
    • + This option must be set via Hadoop Configuration rather than Spark + SQLConf. +
    • +
    • + This option overrides spark.sql.sources.
      outputCommitterClass
      . +
    • +
    +

    +

    + Spark SQL comes with a builtin + org.apache.spark.sql.
    parquet.DirectParquetOutputCommitter
, which can be more + efficient than the default Parquet output committer when writing data to S3. +

    +
    spark.sql.parquet.mergeSchemafalse +

    + When true, the Parquet data source merges schemas collected from all data files, otherwise the + schema is picked from the summary file or a random data file if no summary file is available. +

    +
    ## JSON Datasets @@ -1445,8 +1584,8 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running -the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. @@ -1527,7 +1666,7 @@ sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = sqlContext.sql("FROM src SELECT key, value").collect() +results <- collect(sql(sqlContext, "FROM src SELECT key, value")) {% endhighlight %} @@ -1537,21 +1676,21 @@ results = sqlContext.sql("FROM src SELECT key, value").collect() ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL +will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). -Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` -and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive -jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the -version specified by users. An isolated classloader is used here to avoid dependency conflicts. +The following options can be used to configure the version of Hive that is used to retrieve metadata: - + @@ -1562,12 +1701,16 @@ version specified by users. An isolated classloader is used here to avoid depend property can be one of three options:
    1. builtin
    2. - Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + Use Hive 1.2.1, which is bundled with the Spark assembly jar when -Phive is enabled. When this option is chosen, spark.sql.hive.metastore.version must be - either 0.13.1 or not defined. + either 1.2.1 or not defined.
    3. maven
    4. - Use Hive jars of specified version downloaded from Maven repositories. -
    5. A classpath in the standard format for both Hive and Hadoop.
    6. + Use Hive jars of specified version downloaded from Maven repositories. This configuration + is not generally recommended for production deployments. +
7. A classpath in the standard format for the JVM. This classpath must include all of Hive + and its dependencies, including the correct version of Hadoop. These jars only need to be + present on the driver, but if you are running in `yarn-cluster` mode then you must ensure + they are packaged with your application.
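+As a hedged illustration (the version shown is only an example and must match your metastore; the jars
+property named here, `spark.sql.hive.metastore.jars`, is the standard key for the `builtin`/`maven`/classpath
+choice described above, though its name is not shown in this excerpt), these options can be supplied on the
+application's `SparkConf` before the `HiveContext` is created:
+
+{% highlight scala %}
+import org.apache.spark.SparkConf
+
+// Example only: talk to a Hive 0.13.1 metastore using jars resolved from Maven
+val conf = new SparkConf()
+  .set("spark.sql.hive.metastore.version", "0.13.1")
+  .set("spark.sql.hive.metastore.jars", "maven")
+{% endhighlight %}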
    @@ -1663,9 +1806,9 @@ the Data Sources API. The following options are supported:
    {% highlight scala %} -val jdbcDF = sqlContext.load("jdbc", Map( - "url" -> "jdbc:postgresql:dbserver", - "dbtable" -> "schema.tablename")) +val jdbcDF = sqlContext.read.format("jdbc").options( + Map("url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")).load() {% endhighlight %}
    @@ -1678,7 +1821,7 @@ Map options = new HashMap(); options.put("url", "jdbc:postgresql:dbserver"); options.put("dbtable", "schema.tablename"); -DataFrame jdbcDF = sqlContext.load("jdbc", options) +DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% endhighlight %} @@ -1688,7 +1831,7 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql:dbserver', dbtable='schema.tablename').load() {% endhighlight %} @@ -1779,12 +1922,11 @@ that these options will be deprecated in future release as more optimizations ar - - + + @@ -1794,9 +1936,9 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - + - + @@ -1884,12 +2026,40 @@ options. # Migration Guide +## Upgrading From Spark SQL 1.4 to 1.5 + + - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with + code generation for expression evaluation. These features can both be disabled by setting + `spark.sql.tungsten.enabled` to `false. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.parquet.mergeSchema` to `true`. + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + - In-memory columnar storage partition pruning is on by default. It can be disabled by setting + `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. + - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. + - Timestamps are now stored at a precision of 1us, rather than 1ns + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + unchanged. + - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). + - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe + and thus this output committer will not be used when speculation is on, independent of configuration. + - JSON data source will not automatically load new files that are created by other applications + (i.e. files that are not inserted to the dataset through Spark SQL). + For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), + users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method + to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate + the DataFrame and the new DataFrame will include new files. + ## Upgrading from Spark SQL 1.3 to 1.4 #### DataFrame data reader/writer interface Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) -and writing data out (`DataFrame.write`), +and writing data out (`DataFrame.write`), and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). 
See the API docs for `SQLContext.read` ( @@ -1905,7 +2075,8 @@ See the API docs for `SQLContext.read` ( #### DataFrame.groupBy retains grouping columns -Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`. +Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the +grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
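The column-resolution and `groupBy` notes above are easier to see in a small PySpark sketch. Only the backtick escaping and the `spark.sql.retainGroupColumns` setting come from the notes themselves; the DataFrame, its column names, and the app name below are made-up assumptions for illustration:

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext(appName="SQLMigrationNotesSketch")
sqlContext = SQLContext(sc)

# Hypothetical DataFrame whose second column name contains literal dots.
df = sqlContext.createDataFrame([(1, 10.0), (1, 20.0), (2, 30.0)], ["id", "the.value"])

# Dots now qualify columns or nested fields, so a literal dotted name must be
# escaped with backticks when referenced as a string.
df.select("`the.value`").show()

# groupBy().agg() retains the grouping column ("id") in the result by default.
df.groupBy("id").count().show()

# To fall back to the Spark 1.3 behaviour (grouping columns dropped):
sqlContext.setConf("spark.sql.retainGroupColumns", "false")
{% endhighlight %}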
    @@ -2042,7 +2213,7 @@ Python UDF registration is unchanged. When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of referencing a singleton. -## Migration Guide for Shark User +## Migration Guide for Shark Users ### Scheduling To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, @@ -2118,6 +2289,7 @@ Spark SQL supports the vast majority of Hive features, such as: * User defined functions (UDF) * User defined aggregation functions (UDAF) * User defined serialization formats (SerDes) +* Window functions * Joins * `JOIN` * `{LEFT|RIGHT|FULL} OUTER JOIN` @@ -2128,7 +2300,7 @@ Spark SQL supports the vast majority of Hive features, such as: * `SELECT col FROM ( SELECT a + b AS col from t1) t2` * Sampling * Explain -* Partitioned tables +* Partitioned tables including dynamic partition insertion * View * All Hive DDL Functions, including: * `CREATE TABLE` @@ -2190,8 +2362,9 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. +# Reference -# Data Types +## Data Types Spark SQL and DataFrames support the following data types: @@ -2804,3 +2977,13 @@ from pyspark.sql.types import *
+## NaN Semantics + +There is special handling for not-a-number (NaN) when dealing with `float` or `double` types that +does not exactly match standard floating point semantics. +Specifically: + + - NaN = NaN returns true. + - In aggregations, all NaN values are grouped together. + - NaN is treated as a normal value in join keys. + - NaN values go last when in ascending order, larger than any other numeric value. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 8d6e74370918..383d954409ce 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,8 +5,6 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. -Python API Flume is not yet available in the Python API. - ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts as an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -58,6 +56,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java).
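A quick way to observe the NaN semantics listed above from PySpark is sketched below; the DataFrame contents are made up, and only the grouping and ordering behaviour comes from the list itself:

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext(appName="NaNSemanticsSketch")
sqlContext = SQLContext(sc)

nan = float("nan")
df = sqlContext.createDataFrame(
    [Row(k="a", v=1.0), Row(k="b", v=nan), Row(k="c", v=nan), Row(k="d", v=2.0)])

# In aggregations, all NaN values end up in a single group.
df.groupBy("v").count().show()

# In ascending order NaN sorts last, larger than any other numeric value.
df.orderBy("v").show()
{% endhighlight %}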
    +
+ from pyspark.streaming.flume import FlumeUtils + + flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + By default, the Python API will decode the Flume event body as UTF-8 encoded strings. You can specify a custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). +
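As a sketch of the custom decoding mentioned above, assuming the `bodyDecoder` keyword argument described in the FlumeUtils API docs; the hostname, port, and latin-1 decoder are arbitrary examples rather than part of the guide:

{% highlight python %}
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.flume import FlumeUtils

sc = SparkContext(appName="FlumeCustomDecoderSketch")
ssc = StreamingContext(sc, 2)

# Hypothetical decoder: interpret each event body as latin-1 text instead of UTF-8.
def latin1_decoder(body):
    return None if body is None else bytes(body).decode("latin-1")

flumeStream = FlumeUtils.createStream(ssc, "localhost", 12345,
                                      bodyDecoder=latin1_decoder)
flumeStream.pprint()

ssc.start()
ssc.awaitTermination()
{% endhighlight %}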
    Note that the hostname should be the same as the one used by the resource manager in the @@ -135,6 +142,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
+ from pyspark.streaming.flume import FlumeUtils + + addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] + flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) + + By default, the Python API will decode the Flume event body as UTF-8 encoded strings. You can specify a custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). +
    See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 775d508d4879..5db39ae54a27 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -82,7 +82,7 @@ This approach has the following advantages over the receiver-based approach (i.e - *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. 
However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). @@ -152,7 +152,7 @@ Next, we discuss how to use this approach in your streaming application.
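The Java change below keeps the offsets in an `AtomicReference`; a rough Python sketch of the same offset-tracking pattern is shown here. It assumes a `directKafkaStream` created with `KafkaUtils.createDirectStream` and that the underlying RDDs expose `offsetRanges()` as described in the Python API docs, so treat it as a sketch to verify rather than a definitive recipe:

{% highlight python %}
offsetRanges = []

def storeOffsetRanges(rdd):
    # Capture the offset ranges of this batch before any downstream shuffling.
    global offsetRanges
    offsetRanges = rdd.offsetRanges()
    return rdd

def printOffsetRanges(rdd):
    for o in offsetRanges:
        print("%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset))

directKafkaStream \
    .transform(storeOffsetRanges) \
    .foreachRDD(printOffsetRanges)
{% endhighlight %}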
    // Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference(); + final AtomicReference offsetRanges = new AtomicReference<>(); directKafkaStream.transformToPair( new Function, JavaPairRDD>() { diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index aa9749afbc86..238a911a9199 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -51,6 +51,17 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. +
    +
    + from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + + kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) + + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example. +
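For concreteness, here is the same Python call with made-up values filled in for the bracketed placeholders; the application name, stream name, endpoint, region, and checkpoint interval are all hypothetical:

{% highlight python %}
from pyspark import SparkContext, StorageLevel
from pyspark.streaming import StreamingContext
from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream

sc = SparkContext(appName="KinesisSketch")
ssc = StreamingContext(sc, 2)

# All of these values are placeholders for illustration only.
kinesisStream = KinesisUtils.createStream(
    ssc, "MyKinesisApp", "myKinesisStream",
    "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
    InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2)

kinesisStream.pprint()
ssc.start()
ssc.awaitTermination()
{% endhighlight %}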
    @@ -80,7 +91,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - Kinesis data processing is ordered per partition and occurs at-least once per message. - - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamodDB. + - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamoDB. - A single Kinesis stream shard is processed by one input DStream at a time. @@ -135,6 +146,14 @@ To run the example, bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + +
    + + bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + [Kinesis app name] [Kinesis stream name] [endpoint URL] [region name] +
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index b784d59666fe..c751dbb41785 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -50,13 +50,7 @@ all of which are presented in this guide. You will find tabs throughout this guide that let you choose between code snippets of different languages. -**Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream -transformations and almost all the output operations available in Scala and Java interfaces. -However, it only has support for basic sources like text files and text data over sockets. -APIs for additional sources, like Kafka and Flume, will be available in the future. -Further information about available features in the Python API are mentioned throughout this -document; look out for the tag -Python API. +**Note:** There are a few APIs that are either different or not available in Python. Throughout this guide, you will find the tag Python API highlighting these differences. *************************************************************************************************** @@ -683,7 +677,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, Kafka, Kinesis, Flume and MQTT are available in the Python API. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts @@ -725,9 +719,9 @@ Some of these advanced sources are as follows. - **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. -- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.6.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. -- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. +- **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. - **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information @@ -854,6 +848,8 @@ it with new information. To use this, you will have to do two steps. 1. Define the state update function - Specify with a function how to update the state using the previous state and the new values from an input stream. +In every batch, Spark will apply the state update function for all existing keys, regardless of whether they have new data in a batch or not. If the update function returns `None` then the key-value pair will be eliminated. + Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. 
We define the update function as: @@ -1139,7 +1135,7 @@ val joinedStream = stream1.join(stream2) {% highlight java %} JavaPairDStream stream1 = ... JavaPairDStream stream2 = ... -JavaPairDStream joinedStream = stream1.join(stream2); +JavaPairDStream> joinedStream = stream1.join(stream2); {% endhighlight %}
    @@ -1164,7 +1160,7 @@ val joinedStream = windowedStream1.join(windowedStream2) {% highlight java %} JavaPairDStream windowedStream1 = stream1.window(Durations.seconds(20)); JavaPairDStream windowedStream2 = stream2.window(Durations.minutes(1)); -JavaPairDStream joinedStream = windowedStream1.join(windowedStream2); +JavaPairDStream> joinedStream = windowedStream1.join(windowedStream2); {% endhighlight %}
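Circling back to the state-update note added above: the update function is called for every known key in every batch, and a key is dropped whenever the function returns `None`. A minimal PySpark word-count sketch looks like the following (the socket source on localhost:9999 is just an assumed example input):

{% highlight python %}
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext(appName="StatefulWordCountSketch")
ssc = StreamingContext(sc, 1)
ssc.checkpoint("checkpoint")  # stateful operations require a checkpoint directory

def updateFunction(newValues, runningCount):
    # Called every batch for every known key, even when newValues is empty.
    count = (runningCount or 0) + sum(newValues)
    # Returning None instead would remove the key from the state entirely.
    return count

pairs = ssc.socketTextStream("localhost", 9999) \
    .flatMap(lambda line: line.split(" ")) \
    .map(lambda word: (word, 1))

runningCounts = pairs.updateStateByKey(updateFunction)
runningCounts.pprint()

ssc.start()
ssc.awaitTermination()
{% endhighlight %}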
@@ -1523,7 +1519,7 @@ def getSqlContextInstance(sparkContext): words = ... # DStream of strings def process(time, rdd): - print "========= %s =========" % str(time) + print("========= %s =========" % str(time)) try: # Get the singleton instance of SQLContext sqlContext = getSqlContextInstance(rdd.context) @@ -1700,7 +1696,7 @@ context.awaitTermination(); If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. If the directory does not exist (i.e., running for the first time), then the function `contextFactory` will be called to create a new -context and set up the DStreams. See the Scala example +context and set up the DStreams. See the Java example [JavaRecoverableNetworkWordCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). This example appends the word counts of network data into a file. @@ -1811,7 +1807,7 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *[Since Spark 1.2] Configuring write ahead logs* - Since Spark 1.2, +- *Configuring write ahead logs* - Since Spark 1.2, we have introduced _write ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into a write ahead log in the configuration checkpoint directory. This prevents data loss on driver @@ -1826,6 +1822,17 @@ To run a Spark Streaming applications, you need to have the following. stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. +- *Setting the max receiving rate* - If the cluster resources are not large enough for the streaming + application to process data as fast as it is being received, the receivers can be rate limited + by setting a maximum rate limit in terms of records / sec. + See the [configuration parameters](configuration.html#spark-streaming) + `spark.streaming.receiver.maxRate` for receivers and `spark.streaming.kafka.maxRatePerPartition` + for the Direct Kafka approach. In Spark 1.5, we have introduced a feature called *backpressure* that + eliminates the need to set this rate limit, as Spark Streaming automatically figures out the + rate limits and dynamically adjusts them if the processing conditions change. This backpressure + can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) + `spark.streaming.backpressure.enabled` to `true`. + ### Upgrading Application Code {:.no_toc} diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e58645274e52..7ea4d6f1a3f8 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -65,8 +65,8 @@ For Python applications, simply pass a `.py` file in the place of ` 0." + + "EBS volumes are only attached if --ebs-vol-size > 0.
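Relating to the rate-limit and backpressure settings described in the streaming-guide addition above: they are plain Spark configuration keys, so a typical way to apply them from PySpark is sketched below (the numeric limits are arbitrary example values):

{% highlight python %}
from pyspark import SparkConf, SparkContext
from pyspark.streaming import StreamingContext

# Example values only: cap receivers at 1000 records/sec, cap the direct Kafka
# stream per partition, and let backpressure adjust the rates dynamically.
conf = (SparkConf()
        .setAppName("RateLimitSketch")
        .set("spark.streaming.receiver.maxRate", "1000")
        .set("spark.streaming.kafka.maxRatePerPartition", "1000")
        .set("spark.streaming.backpressure.enabled", "true"))

sc = SparkContext(conf=conf)
ssc = StreamingContext(sc, 2)
{% endhighlight %}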
" + "Only support up to 8 EBS volumes.") parser.add_option( "--placement-group", type="string", default=None, @@ -289,6 +293,10 @@ def parse_args(): parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--additional-tags", type="string", default="", + help="Additional tags to set on the machines; tags are comma-separated, while name and " + + "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") parser.add_option( "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") @@ -302,6 +310,13 @@ def parse_args(): "--private-ips", action="store_true", default=False, help="Use private IPs for instances rather than public if VPC/subnet " + "requires that.") + parser.add_option( + "--instance-initiated-shutdown-behavior", default="stop", + choices=["stop", "terminate"], + help="Whether instances should terminate when shut down or just stop") + parser.add_option( + "--instance-profile-name", default=None, + help="IAM profile name to launch instances under") (opts, args) = parser.parse_args() if len(args) != 2: @@ -314,14 +329,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) @@ -358,7 +375,7 @@ def get_validate_spark_version(version, repo): # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-05-08 +# Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. 
EC2_INSTANCE_TYPES = { "c1.medium": "pvm", @@ -400,6 +417,11 @@ def get_validate_spark_version(version, repo): "m3.large": "hvm", "m3.xlarge": "hvm", "m3.2xlarge": "hvm", + "m4.large": "hvm", + "m4.xlarge": "hvm", + "m4.2xlarge": "hvm", + "m4.4xlarge": "hvm", + "m4.10xlarge": "hvm", "r3.large": "hvm", "r3.xlarge": "hvm", "r3.2xlarge": "hvm", @@ -409,6 +431,7 @@ def get_validate_spark_version(version, repo): "t2.micro": "hvm", "t2.small": "hvm", "t2.medium": "hvm", + "t2.large": "hvm", } @@ -488,6 +511,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # Rstudio (GUI for R) needs port 8787 for web access + master_group.authorize('tcp', 8787, 8787, authorized_address) # HDFS NFS gateway requires 111,2049,4242 for tcp & udp master_group.authorize('tcp', 111, 111, authorized_address) master_group.authorize('udp', 111, 111, authorized_address) @@ -592,7 +617,8 @@ def launch_cluster(conn, opts, cluster_name): block_device_map=block_map, subnet_id=opts.subnet_id, placement_group=opts.placement_group, - user_data=user_data_content) + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -637,16 +663,19 @@ def launch_cluster(conn, opts, cluster_name): for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + slave_res = image.run( + key_name=opts.key_pair, + security_group_ids=[slave_group.id] + additional_group_ids, + instance_type=opts.instance_type, + placement=zone, + min_count=num_slaves_this_zone, + max_count=num_slaves_this_zone, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) slave_nodes += slave_res.instances print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( s=num_slaves_this_zone, @@ -668,32 +697,43 @@ def launch_cluster(conn, opts, cluster_name): master_type = opts.instance_type if opts.zone == 'all': opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + master_res = image.run( + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + placement=opts.zone, + min_count=1, + max_count=1, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) master_nodes = master_res.instances print("Launched master in 
%s, regid = %s" % (zone, master_res.id)) # This wait time corresponds to SPARK-4983 print("Waiting for AWS to propagate instance metadata...") - time.sleep(5) - # Give the instances descriptive names + time.sleep(15) + + # Give the instances descriptive names and set additional tags + additional_tags = {} + if opts.additional_tags.strip(): + additional_tags = dict( + map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') + ) + for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + master.add_tags( + dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + ) + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + slave.add_tags( + dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + ) # Return all the instances return (master_nodes, slave_nodes) @@ -757,7 +797,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon'] + 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] if opts.hadoop_major_version == "1": modules = list(filter(lambda x: x != "mapreduce", modules)) @@ -911,7 +951,7 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-05-08 + # Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, @@ -953,6 +993,11 @@ def get_num_disks(instance_type): "m3.large": 1, "m3.xlarge": 2, "m3.2xlarge": 2, + "m4.large": 0, + "m4.xlarge": 0, + "m4.2xlarge": 0, + "m4.4xlarge": 0, + "m4.10xlarge": 0, "r3.large": 1, "r3.xlarge": 1, "r3.2xlarge": 1, @@ -962,6 +1007,7 @@ def get_num_disks(instance_type): "t2.micro": 0, "t2.small": 0, "t2.medium": 0, + "t2.large": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -1113,8 +1159,8 @@ def ssh(host, opts, command): # If this was an ssh failure, provide the user with hints. if e.returncode == 255: raise UsageError( - "Failed to SSH to remote host {0}.\n" + - "Please check that you have provided the correct --identity-file and " + + "Failed to SSH to remote host {0}.\n" + "Please check that you have provided the correct --identity-file and " "--key-pair parameters and try again.".format(host)) else: raise e diff --git a/examples/pom.xml b/examples/pom.xml index e6884b09dca9..f5ab2a7fdc09 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 9df26ffca577..a377694507d2 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -124,7 +124,7 @@ public String uid() { /** * Param for max number of iterations - *

    + *

    * NOTE: The usual way to add a parameter to a model or algorithm is to include: * - val myParamName: ParamType * - def getMyParamName @@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) { /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. - *

    + *

    * This is used for the defaul implementation of [[transform()]]. * * In Java, we have to make this method public since Java does not understand Scala's protected @@ -230,6 +230,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + .setParent(parent()); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java new file mode 100644 index 000000000000..be2bf0c7b465 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.KMeansModel; +import org.apache.spark.ml.clustering.KMeans; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + + +/** + * An example demonstrating a k-means clustering. + * Run with + *

+ * bin/run-example ml.JavaKMeansExample
    + * 
    + */ +public class JavaKMeansExample { + + private static class ParsePoint implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + if (args.length != 2) { + System.err.println("Usage: ml.JavaKMeansExample "); + System.exit(1); + } + String inputFile = args[0]; + int k = Integer.parseInt(args[1]); + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParsePoint()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a k-means model + KMeans kmeans = new KMeans() + .setK(k); + KMeansModel model = kmeans.fit(dataset); + + // Shows the result + Vector[] centers = model.clusterCenters(); + System.out.println("Cluster Centers: "); + for (Vector center: centers) { + System.out.println(center); + } + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index 75063dbf800d..e7f2f6f61507 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -178,6 +178,7 @@ private static Params parse(String[] args) { return params; } + @SuppressWarnings("static") private static Options generateCommandlineOptions() { Option input = OptionBuilder.withArgName("input") .hasArg() diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index dac649d1d5ae..94beeced3d47 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,8 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + double thresholds[] = {0.45, 0.55}; + paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java new file mode 100644 index 000000000000..23f834ab4332 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} + * using linear regression. + * + * Run with + * {{{ + * bin/run-example ml.JavaTrainValidationSplitExample + * }}} + */ +public class JavaTrainValidationSplitExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + DataFrame data = jsql.createDataFrame( + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), + LabeledPoint.class); + + // Prepare training and test data. + DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid); + + // 80% of the data will be used for training and the remaining 20% for validation. + trainValidationSplit.setTrainRatio(0.8); + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. 
+ model.transform(test) + .select("features", "label", "prediction") + .show(); + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index dbf2ef02d7b7..99b63a2590ae 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -45,7 +45,7 @@ * Usage: JavaStatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. - *

    + *

    * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example @@ -85,7 +85,7 @@ public Optional call(List values, Optional state) { @SuppressWarnings("unchecked") List> tuples = Arrays.asList(new Tuple2("hello", 1), new Tuple2("world", 1)); - JavaPairRDD initialRDD = ssc.sc().parallelizePairs(tuples); + JavaPairRDD initialRDD = ssc.sparkContext().parallelizePairs(tuples); JavaReceiverInputDStream lines = ssc.socketTextStream( args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); @@ -107,7 +107,7 @@ public Tuple2 call(String s) { // This will give a Dstream made of state (which is the cumulative count of the words) JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, - new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD); + new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD); stateDstream.print(); ssc.start(); diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py new file mode 100644 index 000000000000..150dadd42f33 --- /dev/null +++ b/examples/src/main/python/ml/kmeans_example.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys +import re + +import numpy as np +from pyspark import SparkContext +from pyspark.ml.clustering import KMeans, KMeansModel +from pyspark.mllib.linalg import VectorUDT, _convert_to_vector +from pyspark.sql import SQLContext +from pyspark.sql.types import Row, StructField, StructType + +""" +A simple example demonstrating a k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/kmeans_example.py + +This example requires NumPy (http://www.numpy.org/). 
+""" + + +def parseVector(line): + array = np.array([float(x) for x in line.split(' ')]) + return _convert_to_vector(array) + + +if __name__ == "__main__": + + FEATURES_COL = "features" + + if len(sys.argv) != 3: + print("Usage: kmeans_example.py ", file=sys.stderr) + exit(-1) + path = sys.argv[1] + k = sys.argv[2] + + sc = SparkContext(appName="PythonKMeansExample") + sqlContext = SQLContext(sc) + + lines = sc.textFile(path) + data = lines.map(parseVector) + row_rdd = data.map(lambda x: Row(x)) + schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)]) + df = sqlContext.createDataFrame(row_rdd, schema) + + kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL) + model = kmeans.fit(df) + centers = model.clusterCenters() + + print("Cluster Centers: ") + for center in centers: + print(center) + + sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 000000000000..55afe1b207fe --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. 
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py index a9f29dab2d60..2d6d115d54d0 100644 --- a/examples/src/main/python/ml/simple_params_example.py +++ b/examples/src/main/python/ml/simple_params_example.py @@ -70,7 +70,7 @@ # We may alternatively specify parameters using a parameter map. # paramMap overrides all lr parameters set earlier. - paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"} # Now learn a new model using the new parameters. model2 = lr.fit(training, paramMap) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 6ef188a220c5..ea20678b9aca 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -23,8 +23,8 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ examples/src/main/python/streaming/direct_kafka_wordcount.py \ localhost:9092 test` """ @@ -37,7 +37,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 000000000000..d75bc6daac13 --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars \ + external/flume-assembly/target/scala-*/spark-streaming-flume-assembly-*.jar \ + examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index b178e7899b5e..8d697f620f46 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -23,8 +23,9 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` """ from __future__ import print_function diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py new file mode 100644 index 000000000000..abf9c0e21d30 --- /dev/null +++ b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" + A sample wordcount with MqttStream stream + Usage: mqtt_wordcount.py + + To run this on your local machine, you need to set up an MQTT broker and publisher first; + Mosquitto is one of the open source MQTT brokers, see + http://mosquitto.org/ + The Eclipse Paho project provides a number of clients and utilities for working with MQTT, see + http://www.eclipse.org/paho/#getting-started + + and then run the example + `$ bin/spark-submit --jars \ + external/mqtt-assembly/target/scala-*/spark-streaming-mqtt-assembly-*.jar \ + examples/src/main/python/streaming/mqtt_wordcount.py \ + tcp://localhost:1883 foo` +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.mqtt import MQTTUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: mqtt_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingMQTTWordCount") + ssc = StreamingContext(sc, 1) + + brokerUrl = sys.argv[1] + topic = sys.argv[2] + + lines = MQTTUtils.createStream(ssc, brokerUrl, topic) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py index dcd6a0fc6ff9..b3808907f74a 100644 --- a/examples/src/main/python/streaming/queue_stream.py +++ b/examples/src/main/python/streaming/queue_stream.py @@ -36,8 +36,8 @@ # Create the queue through which RDDs can be pushed to # a QueueInputDStream rddQueue = [] - for i in xrange(5): - rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + for i in range(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)] # Create the QueueInputDStream and use it do some processing inputStream = ssc.queueStream(rddQueue) diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R new file mode 100644 index 000000000000..aa2336e300a9 --- /dev/null +++ b/examples/src/main/r/data-manipulation.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns.
+ +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 4c129dbe2d12..d812262fd87d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -52,3 +53,4 @@ object BroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 023bb3ee2d10..d1b9b8d398dd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -15,13 +15,12 @@ * limitations under the License. */ + // scalastyle:off println + // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ListBuffer -import scala.collection.immutable.Map +import java.util.Collections import org.apache.cassandra.hadoop.ConfigHelper import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat @@ -31,7 +30,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* @@ -84,6 +82,7 @@ object CassandraCQLTest { val job = new Job() job.setInputFormatClass(classOf[CqlPagingInputFormat]) + val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) @@ -120,12 +119,9 @@ object CassandraCQLTest { val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { - val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) - val outKey: java.util.Map[String, ByteBuffer] = outColFamKey - var outColFamVal = new ListBuffer[ByteBuffer] - outColFamVal += ByteBufferUtil.bytes(saleCount) - val outVal: java.util.List[ByteBuffer] = outColFamVal - (outKey, outVal) + val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) + val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) + (outKey, outVal) } } @@ -140,3 +136,5 @@ object CassandraCQLTest { sc.stop() } } +// scalastyle:on println +// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ec689474aecb..1e679bfb5534 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -15,13 +15,14 @@ * 
limitations under the License. */ +// scalastyle:off println +// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer +import java.util.Arrays import java.util.SortedMap -import scala.collection.JavaConversions._ - import org.apache.cassandra.db.IColumn import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat import org.apache.cassandra.hadoop.ConfigHelper @@ -31,7 +32,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra @@ -117,7 +117,7 @@ object CassandraTest { val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) - val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + val mutations = Arrays.asList(new Mutation(), new Mutation()) mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(0).column_or_supercolumn.setColumn(colWord) mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) @@ -130,6 +130,8 @@ object CassandraTest { sc.stop() } } +// scalastyle:on println +// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 1f12034ce0f5..d651fe4d6ee7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.io.File @@ -136,3 +137,4 @@ object DFSReadWriteTest { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index e757283823fc..bec61f3cd429 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -15,9 +15,10 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils @@ -35,10 +36,10 @@ object DriverSubmissionTest { val properties = Utils.getSystemProperties println("Environment variables containing SPARK_TEST:") - env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) + env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println) println("System properties containing spark.test:") - properties.filter{case (k, v) => k.toString.contains("spark.test")}.foreach(println) + properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println) for (i <- 1 until numSecondsToSleep) { println(s"Alive for $i out of $numSecondsToSleep seconds") @@ -46,3 +47,4 @@ object DriverSubmissionTest { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 15f6678648b2..fa4a3afeecd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -53,3 +54,4 @@ object GroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 95c96111c9b1..244742327a90 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin @@ -62,3 +63,4 @@ object HBaseTest { admin.close() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index ed2b38e2ca6f..124dc9af6390 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark._ @@ -41,3 +42,4 @@ object HdfsTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3d5259463003..af5f216f28ba 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -15,6 +15,7 @@ * limitations under the License. 
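A minimal, self-contained sketch of the `JavaConversions` to `JavaConverters` migration applied in the DriverSubmissionTest hunk above: conversions between Java and Scala collections become explicit `.asScala` / `.asJava` calls rather than silent implicits.

```scala
import scala.collection.JavaConverters._

object JavaConvertersSketch {
  def main(args: Array[String]): Unit = {
    // Java map -> Scala view, filtered with an ordinary Scala closure.
    System.getenv().asScala
      .filter { case (k, _) => k.contains("SPARK_TEST") }
      .foreach(println)

    // Scala seq -> Java list, e.g. when handing a result back to a Java API.
    val javaList: java.util.List[Int] = Seq(1, 2, 3).asJava
    println(javaList)
  }
}
```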
*/ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -142,3 +143,4 @@ object LocalALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index ac2ea35bbd0e..9c8aae53cf48 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -73,3 +74,4 @@ object LocalFileLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 04fc0a033014..e7b28d38bdfc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -119,3 +120,4 @@ object LocalKMeans { println("Final centers: " + kPoints) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index c3fc74a116c0..4f6b092a59ca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -77,3 +78,4 @@ object LocalLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index ee6b3ee34aeb..3d923625f11b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -33,3 +34,4 @@ object LocalPi { println("Pi is roughly " + 4 * count / 100000.0) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 75c82117cbad..a80de10f4610 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -83,3 +84,4 @@ object LogQuery { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 2a5c0c0defe1..61ce9db914f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.rdd.RDD @@ -53,3 +54,4 @@ object MultiBroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 5291ab81f459..3b0b00fe4dd0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -67,3 +68,4 @@ object SimpleSkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 017d4e1e5ce1..719e2176fed3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -57,3 +58,4 @@ object SkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 30c426155183..69799b7c2bb3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -144,3 +145,4 @@ object SparkALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9099c2fcc90b..505ea5a4c7a8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -97,3 +98,4 @@ object SparkHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index b514d9123f5e..c56e1124ad41 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import breeze.linalg.{Vector, DenseVector, squaredDistance} @@ -100,3 +101,4 @@ object SparkKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 1e6b4fb0c751..d265c227f4ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -86,3 +87,4 @@ object SparkLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index bd7894f184c4..0fd79660dd19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.SparkContext._ @@ -74,3 +75,4 @@ object SparkPageRank { ctx.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 35b8dd6c29b6..818d4f2b81f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -37,3 +38,4 @@ object SparkPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 772cd897f514..95072071ccdd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.util.Random @@ -70,3 +71,4 @@ object SparkTC { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 4393b99e636b..cfbdae02212a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -94,3 +95,4 @@ object SparkTachyonHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 7743f7968b10..e46ac655beb5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -46,3 +47,4 @@ object SparkTachyonPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 409721b01c8f..8dd6c9706e7d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.graphx import scala.collection.mutable @@ -151,3 +152,4 @@ object Analytics extends Logging { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index f6f8d9f90c27..da3ffca1a6f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx /** @@ -42,3 +43,4 @@ object LiveJournalPageRank { Analytics.main(args.patch(0, List("pagerank"), 0)) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 3ec20d594b78..46e52aacd90b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ @@ -128,3 +129,4 @@ object SynthBenchmark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 6c0af20461d3..14b358d46f6a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -110,3 +111,4 @@ object CrossValidatorExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 54e407394105..f28671f7869f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -355,3 +356,4 @@ object DecisionTreeExample { println(s" Root mean squared error (RMSE): $RMSE") } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 7b8cc21ed898..340c3559b15e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -178,6 +179,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. 
*/ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 33905277c734..f4a15f806ea8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -236,3 +237,4 @@ object GBTExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala new file mode 100644 index 000000000000..5ce38462d118 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} + + +/** + * An example demonstrating a k-means clustering. 
+ * Run with + * {{{ + * bin/run-example ml.KMeansExample + * }}} + */ +object KMeansExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println("Usage: ml.KMeansExample ") + // scalastyle:on println + System.exit(1) + } + val input = args(0) + val k = args(1).toInt + + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a k-means model + val kmeans = new KMeans() + .setK(k) + .setFeaturesCol(FEATURES_COL) + val model = kmeans.fit(dataset) + + // Shows the result + // scalastyle:off println + println("Final Centers: ") + model.clusterCenters.foreach(println) + // scalastyle:on println + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b54466fd48bc..b73299fb12d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -140,3 +141,4 @@ object LinearRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 3cf193f353fb..8e3760ddb50a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -135,6 +136,7 @@ object LogisticRegressionExample { .setElasticNetParam(params.elasticNetParam) .setMaxIter(params.maxIter) .setTol(params.tol) + .setFitIntercept(params.fitIntercept) stages += lor val pipeline = new Pipeline().setStages(stages.toArray) @@ -157,3 +159,4 @@ object LogisticRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 25f21113bf62..3ae53e57dbdb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. 
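A hypothetical helper (not part of the patch) that writes a tiny input file in the whitespace-separated-doubles format the new `ml.KMeansExample` above parses with `_.split(" ").map(_.toDouble)`. The output path and cluster count below are illustrative choices, not values taken from the example.

```scala
import java.io.PrintWriter
import scala.util.Random

object KMeansInputSketch {
  def main(args: Array[String]): Unit = {
    val writer = new PrintWriter("/tmp/kmeans_points.txt")
    try {
      val rnd = new Random(42)
      // Two loose clusters centered around (0, 0, 0) and (5, 5, 5).
      for (center <- Seq(0.0, 5.0); _ <- 1 to 50) {
        writer.println(Seq.fill(3)(center + rnd.nextGaussian()).mkString(" "))
      }
    } finally {
      writer.close()
    }
  }
}
// e.g. bin/run-example ml.KMeansExample /tmp/kmeans_points.txt 2
```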
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scopt.OptionParser @@ -75,7 +76,7 @@ object MovieLensALS { .text("path to a MovieLens dataset of movies") .action((x, c) => c.copy(movies = x)) opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("maxIter") .text(s"max number of iterations, default: ${defaultParams.maxIter}") @@ -178,3 +179,4 @@ object MovieLensALS { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 6927eb8f275c..bab31f585b0e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} @@ -183,3 +184,4 @@ object OneVsRestExample { (NANO.toSeconds(t1 - t0), result) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 9f7cad68a459..109178f4137b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -244,3 +245,4 @@ object RandomForestExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index a0561e2573fc..f4d1fe57856a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -69,7 +70,7 @@ object SimpleParamsExample { // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. - paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params. // One can also combine ParamMaps. val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name @@ -100,3 +101,4 @@ object SimpleParamsExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 1324b066c30c..960280137cbf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -15,6 +15,7 @@ * limitations under the License. 
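A minimal sketch of the `ParamMap` usage touched in the SimpleParamsExample hunk above, assuming the 1.5-era ml API in which the per-class `thresholds` array is used in this call instead of the single binary `threshold` param.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

object ParamMapSketch {
  def main(args: Array[String]): Unit = {
    val lr = new LogisticRegression()
    val paramMap = ParamMap(lr.maxIter -> 20)
    // Overwrite a single param, then put several param pairs at once.
    paramMap.put(lr.maxIter, 30)
    paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55))
    println(paramMap)
  }
}
```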
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.beans.BeanInfo @@ -89,3 +90,4 @@ object SimpleTextClassificationPipeline { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala new file mode 100644 index 000000000000..1abdf219b1c0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on [[SimpleParamsExample]] using linear regression. + * Run with + * {{{ + * bin/run-example ml.TrainValidationSplitExample + * }}} + */ +object TrainValidationSplitExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TrainValidationSplitExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Prepare training and test data. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + + val lr = new LinearRegression() + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + + // 80% of the data will be used for training and the remaining 20% for validation. + trainValidationSplit.setTrainRatio(0.8) + + // Run train validation split, and choose the best set of parameters. + val model = trainValidationSplit.fit(training) + + // Make predictions on test data. 
model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a113653810b9..1a4016f76c2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -153,3 +154,4 @@ object BinaryClassification { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index e49129c4e784..026d4ecc6d10 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -91,3 +92,4 @@ object Correlations { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index cb1abbd18fd4..69988cc1b933 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -106,3 +107,4 @@ object CosineSimilarity { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 520893b26d59..dc13f82488af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.io.File @@ -119,3 +120,4 @@ object DatasetExample { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 3381941673db..cc6bce3cb7c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.language.reflectiveCalls @@ -99,7 +100,7 @@ object DecisionTreeRunner { .action((x, c) => c.copy(numTrees = x)) opt[String]("featureSubsetStrategy") .text(s"feature subset sampling strategy" + - s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}), " + s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") @@ -368,3 +369,4 @@ object DecisionTreeRunner { } // scalastyle:on structural.type } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index f8c71ccabc43..1fce4ba7efd6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -65,3 +66,4 @@ object DenseGaussianMixture { println() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 14cc5cbb679c..380d85d60e7b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -107,3 +108,4 @@ object DenseKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 13f24a1e5961..14b930550d55 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -80,3 +81,4 @@ object FPGrowthExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 7416fb5a4084..e16a6bf03357 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -145,3 +146,4 @@ object GradientBoostedTreesRunner { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 31d629f85316..75b0f69cf91a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.text.BreakIterator @@ -302,3 +303,4 @@ private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Se } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 6a456ba7ec07..8878061a0970 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -134,3 +135,4 @@ object LinearRegression { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 99588b0984ab..69691ae297f6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.collection.mutable @@ -54,7 +55,7 @@ object MovieLensALS { val parser = new OptionParser[Params]("MovieLensALS") { head("MovieLensALS: an example app for ALS on MovieLens data.") opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("numIterations") .text(s"number of iterations, default: ${defaultParams.numIterations}") @@ -189,3 +190,4 @@ object MovieLensALS { math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 6e4e2d07f284..5f839c75dd58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -97,3 +98,4 @@ object MultivariateSummarizer { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 6d8b806569df..072322395461 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -154,4 +155,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } - +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 924b586e3af9..bee85ba0f996 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.random.RandomRDDs @@ -58,3 +59,4 @@ object RandomRDDGeneration { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 663c12734af6..6963f43e082c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.util.MLUtils @@ -125,3 +126,4 @@ object SampledRDDs { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5..f81fc292a3bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -100,3 +101,4 @@ object SparseNaiveBayes { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 8bb12d2ee9ed..af03724a8ac6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.SparkConf @@ -75,3 +76,4 @@ object StreamingKMeansExample { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1a95048bbfe2..b4a5dca031ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -69,3 +70,4 @@ object StreamingLinearRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index e1998099c2d7..b42f4cb5f933 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -71,3 +72,4 @@ object StreamingLogisticRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 3cd9cb743e30..464fbd385ab5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnyPCA { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 4d6690318615..65b4bc46f026 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnySVD { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 3ebb112fc069..805184e740f0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.pythonconverters import java.util.{Collection => JCollection, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.avro.generic.{GenericFixed, IndexedRecord} import org.apache.avro.mapred.AvroWrapper @@ -58,7 +58,7 @@ object AvroConversionUtil extends Serializable { val map = new java.util.HashMap[String, Any] obj match { case record: IndexedRecord => - record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + record.getSchema.getFields.asScala.zipWithIndex.foreach { case (f, i) => map.put(f.name, fromAvro(record.get(i), f.schema)) } case other => throw new SparkException( @@ -68,9 +68,9 @@ object AvroConversionUtil extends Serializable { } def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { - obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + obj.asInstanceOf[JMap[_, _]].asScala.map { case (key, value) => (key.toString, fromAvro(value, schema.getValueType)) - } + }.asJava } def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { @@ -91,17 +91,17 @@ object AvroConversionUtil extends Serializable { def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { case c: JCollection[_] => - c.map(fromAvro(_, schema.getElementType)) + c.asScala.map(fromAvro(_, schema.getElementType)).toSeq.asJava case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => - arr.toSeq + arr.toSeq.asJava.asInstanceOf[JCollection[Any]] case arr: Array[_] => - arr.map(fromAvro(_, schema.getElementType)).toSeq + arr.map(fromAvro(_, schema.getElementType)).toSeq.asJava case other => throw new SparkException( s"Unknown ARRAY type ${other.getClass.getName}") } def unpackUnion(obj: Any, schema: Schema): Any = { - schema.getTypes.toList match { + schema.getTypes.asScala.toList match { case List(s) => fromAvro(obj, s) case List(n, s) if n.getType == NULL => fromAvro(obj, s) case List(s, n) if n.getType == NULL => fromAvro(obj, s) diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 83feb5703b90..00ce47af4813 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -17,11 +17,13 @@ package org.apache.spark.examples.pythonconverters -import org.apache.spark.api.python.Converter import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions._ +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra @@ -30,7 +32,7 @@ import collection.JavaConversions._ class CassandraCQLKeyConverter 
extends Converter[Any, java.util.Map[String, Int]] { override def convert(obj: Any): java.util.Map[String, Int] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.toInt(bb))) + result.asScala.mapValues(ByteBufferUtil.toInt).asJava } } @@ -41,7 +43,7 @@ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int] class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, String]] { override def convert(obj: Any): java.util.Map[String, String] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) + result.asScala.mapValues(ByteBufferUtil.string).asJava } } @@ -52,7 +54,7 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { val input = obj.asInstanceOf[java.util.Map[String, Int]] - mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + input.asScala.mapValues(ByteBufferUtil.bytes).asJava } } @@ -63,6 +65,6 @@ class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, By class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { override def convert(obj: Any): java.util.List[ByteBuffer] = { val input = obj.asInstanceOf[java.util.List[String]] - seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + input.asScala.map(ByteBufferUtil.bytes).asJava } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 90d48a64106c..0a25ee7ae56f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.pythonconverters -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter @@ -33,7 +33,6 @@ import org.apache.hadoop.hbase.CellUtil */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { - import collection.JavaConverters._ val result = obj.asInstanceOf[Result] val output = result.listCells.asScala.map(cell => Map( @@ -77,7 +76,7 @@ class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBy */ class StringListToPutConverter extends Converter[Any, Put] { override def convert(obj: Any): Put = { - val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val output = obj.asInstanceOf[java.util.ArrayList[String]].asScala.map(Bytes.toBytes).toArray val put = new Put(output(0)) put.add(output(1), output(2), output(3)) } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index b11e32047dc3..2cc56f04e5c1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -15,6 +15,7 @@ * limitations under the License. 
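A self-contained sketch of the converter pattern used in the pythonconverters hunks above: take the Java map with `.asScala`, transform its values, and return a Java map with `.asJava`, instead of relying on the implicit `mapAsJavaMap` conversion. A plain `ByteBuffer` stands in for Cassandra's `ByteBufferUtil.bytes`.

```scala
import java.nio.ByteBuffer

import scala.collection.JavaConverters._

object MapValueConverterSketch {
  // Java map in, Java map out; the conversion steps are explicit.
  def toByteBuffers(input: java.util.Map[String, Int]): java.util.Map[String, ByteBuffer] =
    input.asScala.mapValues(i => ByteBuffer.allocate(4).putInt(0, i)).asJava

  def main(args: Array[String]): Unit = {
    val javaInput = Map("prod_id" -> 42, "qty" -> 7).asJava
    println(toByteBuffers(javaInput))
  }
}
```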
*/ +// scalastyle:off println package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} @@ -73,3 +74,4 @@ object RDDRelation { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b7ba60ec2815..bf40bd1ef13d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql.hive import com.google.common.io.{ByteStreams, Files} @@ -77,3 +78,4 @@ object HiveFromSpark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 016de4c63d1d..e9c990719876 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import scala.collection.mutable.LinkedList @@ -170,3 +171,4 @@ object ActorWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 30269a7ccae9..28e9bf520e56 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.{InputStreamReader, BufferedReader, InputStream} @@ -100,3 +101,4 @@ class CustomReceiver(host: String, port: Int) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index fbe394de4a17..bd78526f8c29 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import kafka.serializer.StringDecoder @@ -70,3 +71,4 @@ object DirectKafkaWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 20e7df7c45b1..91e52e4eff5a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -66,3 +67,4 @@ object FlumeEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 1cc8c8d5c23b..2bdbc37e2a28 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -65,3 +66,4 @@ object FlumePollingEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 4b4667fec44e..1f282d437dc3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -53,3 +54,4 @@ object HdfsWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 60416ee34354..b40d17e9c2fa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.util.HashMap @@ -101,3 +102,4 @@ object KafkaWordCountProducer { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 813c8554f519..d772ae309f40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.eclipse.paho.client.mqttv3._ @@ -96,8 +97,10 @@ object MQTTWordCount { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println System.err.println( "Usage: MQTTWordCount ") + // scalastyle:on println System.exit(1) } @@ -113,3 +116,4 @@ object MQTTWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 2cd8073dada1..9a57fe286d1a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -57,3 +58,4 @@ object NetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a9aaa445bccb..5322929d177b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -58,3 +59,4 @@ object RawNetworkGrep { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 751b30ea1578..9916882e4f94 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.File @@ -108,3 +109,4 @@ object RecoverableNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 5a6b9216a3fb..ed617754cbf1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -99,3 +100,4 @@ object SQLContextSingleton { instance } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 345d0bc44135..02ba1c2eed0f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -78,3 +79,4 @@ object StatefulNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index c10de84a80ff..825c671a929b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird._ @@ -113,3 +114,4 @@ object TwitterAlgebirdCMS { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 62db5e663b8a..49826ede7041 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird.HyperLogLogMonoid @@ -90,3 +91,4 @@ object TwitterAlgebirdHLL { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f253d75b279f..49cee1b43c2d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -82,3 +83,4 @@ object TwitterPopularTags { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index e99d1baa72b9..6ac9a72c3794 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import akka.actor.ActorSystem @@ -97,3 +98,4 @@ object ZeroMQWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 889f052c7026..bea7a47cb285 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import java.net.ServerSocket @@ -108,3 +109,4 @@ object PageViewGenerator { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index fbacaee98690..ec7d39da8b2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ @@ -107,3 +108,4 @@ object PageViewStream { ssc.start() } } +// scalastyle:on println diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 000000000000..dceedcf23ed5 --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,168 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + provided + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-codec + commons-codec + provided + + + commons-lang + commons-lang + provided + + + commons-net + commons-net + provided + + + com.google.protobuf + protobuf-java + provided + + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.scala-lang + scala-library + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + + + flume-provided + + provided + + + + + diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 7a7dccc3d092..d7c2ac474a18 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -35,10 +35,6 @@ http://spark.apache.org/ - - org.apache.commons - commons-lang3 - org.apache.flume flume-ng-sdk diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index 17cbc6707b5e..d87b86932dd4 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -113,7 +113,9 @@ private[sink] object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. 
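// Illustrative sketch, not from the patch: the example-file hunks above all make the same
// mechanical change. The build's scalastyle check rejects bare println, but the examples
// legitimately print to the console, so the check is switched off at the top of each file
// and switched back on after the last line; the hunk below applies the same idea to the
// classforname rule around an intentional Class.forName. The object name here is a
// placeholder, not from the patch.
// scalastyle:off println
object PrintlnStyleSketch {
  def main(args: Array[String]): Unit = {
    println("examples are allowed to write directly to the console")
  }
}
// scalastyle:on println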
+ // scalastyle:off classforname val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + // scalastyle:on classforname bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index dc2a4ab138e1..719fca0938b3 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,13 +16,13 @@ */ package org.apache.spark.streaming.flume.sink +import java.util.UUID import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable import org.apache.flume.Channel -import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process @@ -53,7 +53,7 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha // Since the new txn may not have the same sequence number we must guard against accidentally // committing a new transaction. To reduce the probability of that happening a random string is // prepended to the sequence number. Does not change for life of sink - private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqBase = UUID.randomUUID().toString.substring(0, 8) private val seqCounter = new AtomicLong(0) // Protected by `sequenceNumberToProcessor` diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index fa43629d4977..d2654700ea72 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -20,7 +20,7 @@ import java.net.InetSocketAddress import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -166,7 +166,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("capacity", channelCapacity.toString) channelContext.put("transactionCapacity", 1000.toString) channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides) + channelContext.putAll(overrides.asJava) channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 14f7daaf417e..132062f94fb4 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index 65c49c131518..48df27b26867 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ 
b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.flume import java.io.{ObjectOutput, ObjectInput} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils import org.apache.spark.Logging @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k, v) <- headers) { + for ((k, v) <- headers.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala index 88cc2aa3bf02..b9d4e762ca05 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -155,7 +154,7 @@ private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends R val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) var j = 0 while (j < events.size()) { - val event = events(j) + val event = events.get(j) val sparkFlumeEvent = new SparkFlumeEvent() sparkFlumeEvent.event.setBody(event.getBody) sparkFlumeEvent.event.setHeaders(event.getHeaders) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 1e32a365a1ee..c8780aa83bdb 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -22,7 +22,7 @@ import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.flume.source.avro.AvroSourceProtocol @@ -43,7 +43,7 @@ import org.jboss.netty.handler.codec.compression._ private[streaming] class FlumeInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel, @@ -99,7 +99,7 @@ class SparkFlumeEvent() extends Externalizable { val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { + for ((k, v) <- event.getHeaders.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) @@ -127,8 +127,7 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { } override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) + events.asScala.foreach(event => receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 583e7dca317a..3b936d88abd3 100644 --- 
a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{LinkedBlockingQueue, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -46,7 +46,7 @@ import org.apache.spark.streaming.flume.sink._ * @tparam T Class type of the object of this stream */ private[streaming] class FlumePollingInputDStream[T: ClassTag]( - @transient _ssc: StreamingContext, + _ssc: StreamingContext, val addresses: Seq[InetSocketAddress], val maxBatchSize: Int, val parallelism: Int, @@ -94,9 +94,7 @@ private[streaming] class FlumePollingReceiver( override def onStop(): Unit = { logInfo("Shutting down Flume Polling Receiver") receiverExecutor.shutdownNow() - connections.foreach(connection => { - connection.transceiver.close() - }) + connections.asScala.foreach(_.transceiver.close()) channelFactory.releaseExternalResources() } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 000000000000..70018c86f92b --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.Collections + +import scala.collection.JavaConverters._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: Seq[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Collections.singletonMap("test", "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.asJava) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 44dec45c227c..c719b80aca7e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,10 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} +import 
scala.collection.JavaConverters._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -236,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private[flume] class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.size() == ports.size()) + val addresses = hosts.asScala.zip(ports.asScala).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.asScala.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 000000000000..a2ab320957db --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{Map => JMap, Collections} + +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): Seq[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. 
+ val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput(outputHeaders: Seq[JMap[String, String]], outputBodies: Seq[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Collections.singletonMap(s"test-$i", "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies(j) && + eventHeaderToVerify == outputHeaders(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + 
Collections.singletonMap(s"test-$t", "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d772b9ca9b57..ff2fb8eed204 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,47 +18,33 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder -import org.scalatest.concurrent.Eventually._ - +import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -89,146 +75,55 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log } private def testFlumePolling(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - writeAndVerify(Seq(sink), Seq(channel)) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + try { + val port = utils.startSingleSink() + + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. 
- val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() try { - writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() + utils.close() } } - def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. 
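// Illustrative sketch, not from the patch: the rewritten writeAndVerify above drives the
// same public entry point an application would use, FlumeUtils.createPollingStream. The
// port and app name below are placeholders, and a Spark sink is assumed to be listening
// on that port.
import java.net.InetSocketAddress
import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.flume.FlumeUtils
val ssc = new StreamingContext(
  new SparkConf().setMaster("local[2]").setAppName("FlumePollingSketch"), Seconds(2))
val addresses = Seq(new InetSocketAddress("localhost", 9999))
val flumeEvents = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK)
flumeEvents.map(e => new String(e.event.getBody.array(), "UTF-8")).print()
ssc.start()
ssc.awaitTermination()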
- eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 - } - } - assert(counter === totalEventsPerChannel * channels.size) - } - ssc.stop() - } - - def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach + try { + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. 
+ eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { + case (key, value) => (key.toString, value.toString) + }).map(_.asJava) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index c926359987d8..5ffb60bd602f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,20 +17,12 @@ package org.apache.spark.streaming.flume -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -41,22 +33,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.util.Utils class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,19 +49,29 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new 
String(event.getBody.array(), Charsets.UTF_8)) + output should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -98,58 +88,6 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) - output should be (input) - } - } - /** Class to create socket channel with compression */ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 8059c443827e..a9ed39ef8c9a 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -47,6 +47,90 @@ ${project.version} provided + + + commons-codec + commons-codec + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + net.jpountz.lz4 + lz4 + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index ded863bd985e..05abd9e2e681 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala index 
5a74febb4bd4..9159051ba06e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala @@ -20,11 +20,9 @@ package org.apache.spark.streaming.kafka import org.apache.spark.annotation.Experimental /** - * :: Experimental :: - * Represent the host and port info for a Kafka broker. - * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID + * Represents the host and port info for a Kafka broker. + * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID. */ -@Experimental final class Broker private( /** Broker's hostname */ val host: String, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 876456c96477..8a087474d316 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka import scala.annotation.tailrec import scala.collection.mutable -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata @@ -29,7 +29,8 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -57,11 +58,11 @@ class DirectKafkaInputDStream[ U <: Decoder[K]: ClassTag, T <: Decoder[V]: ClassTag, R: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R -) extends InputDStream[R](ssc_) with Logging { + ) extends InputDStream[R](ssc_) with Logging { val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) @@ -71,14 +72,40 @@ class DirectKafkaInputDStream[ protected[streaming] override val checkpointData = new DirectKafkaInputDStreamCheckpointData + + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. 
+ */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new DirectKafkaRateController(id, + RateEstimator.create(ssc.conf, context.graph.batchDuration))) + } else { + None + } + } + protected val kc = new KafkaCluster(kafkaParams) - protected val maxMessagesPerPartition: Option[Long] = { - val ratePerSec = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRatePerPartition", 0) - if (ratePerSec > 0) { + protected def maxMessagesPerPartition: Option[Long] = { + val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val numPartitions = currentOffsets.keys.size + + val effectiveRateLimitPerPartition = estimatedRateLimit + .filter(_ > 0) + .map { limit => + if (maxRateLimitPerPartition > 0) { + Math.min(maxRateLimitPerPartition, (limit / numPartitions)) + } else { + limit / numPartitions + } + }.getOrElse(maxRateLimitPerPartition) + + if (effectiveRateLimitPerPartition > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 - Some((secsPerBatch * ratePerSec).toLong) + Some((secsPerBatch * effectiveRateLimitPerPartition).toLong) } else { None } @@ -119,8 +146,23 @@ class DirectKafkaInputDStream[ val rdd = KafkaRDD[K, V, U, T, R]( context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) - // Report the record number of this batch interval to InputInfoTracker. - val inputInfo = InputInfo(id, rdd.count) + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. + offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) @@ -155,11 +197,18 @@ class DirectKafkaInputDStream[ val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => - logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") - generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( - context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) } } } + /** + * A RateController to retrieve the rate from RateEstimator. 
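// Illustrative sketch, not from the patch: the DirectKafkaRateController defined just below
// keeps publish() as a no-op because the direct stream has no receiver to push rates to;
// instead, maxMessagesPerPartition above reads the latest estimate synchronously when each
// batch is planned. maxRatePerPartition is read directly in this hunk; backpressure.enabled
// is the switch behind RateController.isBackPressureEnabled. The values are arbitrary examples.
import org.apache.spark.SparkConf
val sketchConf = new SparkConf()
  .set("spark.streaming.backpressure.enabled", "true")      // turns the RateController on
  .set("spark.streaming.kafka.maxRatePerPartition", "1000") // hard per-partition cap
// With both set and a positive estimate, a batch reads roughly
//   min(maxRatePerPartition, estimatedRate / numPartitions) * batchDurationInSeconds
// records per partition, matching the maxMessagesPerPartition logic above.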
+ */ + private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 3e6b937af57b..8465432c5850 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -410,7 +410,7 @@ object KafkaCluster { } Seq("zookeeper.connect", "group.id").foreach { s => - if (!props.contains(s)) { + if (!props.containsKey(s)) { props.setProperty(s, "") } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 04b2dc10d39e..38730fecf332 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -48,7 +48,7 @@ class KafkaInputDStream[ V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], useReliableReceiver: Boolean, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index c5cd2154772a..ea5f842c6caf 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -98,8 +98,7 @@ class KafkaRDD[ val res = context.runJob( this, (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, - parts.keys.toArray, - allowLocal = true) + parts.keys.toArray) res.foreach(buf ++= _) buf.toArray } @@ -198,7 +197,11 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close(): Unit = consumer.close() + override def close(): Unit = { + if (consumer != null) { + consumer.close() + } + } override def getNext(): R = { if (iter == null || !iter.hasNext) { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index b608b7595272..c9fd715d3d55 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -20,27 +20,26 @@ package org.apache.spark.streaming.kafka import java.io.File import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.{Map => JMap} -import java.util.Properties import java.util.concurrent.TimeoutException +import java.util.{Map => JMap, Properties} import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.language.postfixOps import scala.util.control.NonFatal import kafka.admin.AdminUtils import kafka.api.Request -import kafka.common.TopicAndPartition import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.StringEncoder import kafka.server.{KafkaConfig, KafkaServer} import kafka.utils.{ZKStringSerializer, ZkUtils} -import org.apache.zookeeper.server.{NIOServerCnxnFactory, 
ZooKeeperServer} import org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.apache.spark.{Logging, SparkConf} import org.apache.spark.streaming.Time import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -48,7 +47,7 @@ import org.apache.spark.util.Utils * * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. */ -private class KafkaTestUtils extends Logging { +private[kafka] class KafkaTestUtils extends Logging { // Zookeeper related configurations private val zkHost = "localhost" @@ -161,8 +160,7 @@ private class KafkaTestUtils extends Logging { /** Java-friendly function for sending messages to the Kafka broker */ def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { - import scala.collection.JavaConversions._ - sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) } /** Send the messages to the Kafka broker */ diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0e33362d34ac..312822207753 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,29 +17,25 @@ package org.apache.spark.streaming.kafka -import java.lang.{Integer => JInt} -import java.lang.{Long => JLong} -import java.util.{Map => JMap} -import java.util.{Set => JSet} -import java.util.{List => JList} +import java.lang.{Integer => JInt, Long => JLong} +import java.util.{List => JList, Map => JMap, Set => JSet} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import scala.collection.JavaConversions._ import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} +import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.{SparkContext, SparkException} object KafkaUtils { /** @@ -100,7 +96,7 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*)) } /** 
@@ -119,7 +115,7 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -153,7 +149,10 @@ object KafkaUtils { implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( - jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + jssc.ssc, + kafkaParams.asScala.toMap, + Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), + storageLevel) } /** get leaders for the given offset ranges, or throw an exception */ @@ -196,7 +195,6 @@ object KafkaUtils { * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition */ - @Experimental def createRDD[ K: ClassTag, V: ClassTag, @@ -214,7 +212,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. @@ -230,7 +227,6 @@ object KafkaUtils { * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createRDD[ K: ClassTag, V: ClassTag, @@ -268,7 +264,6 @@ object KafkaUtils { * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition */ - @Experimental def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jsc: JavaSparkContext, keyClass: Class[K], @@ -283,11 +278,10 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) new JavaPairRDD(createRDD[K, V, KD, VD]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges)) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges)) } /** - * :: Experimental :: * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. @@ -303,7 +297,6 @@ object KafkaUtils { * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jsc: JavaSparkContext, keyClass: Class[K], @@ -321,13 +314,12 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) - val leaderMap = Map(leaders.toSeq: _*) + val leaderMap = Map(leaders.asScala.toSeq: _*) createRDD[K, V, KD, VD, R]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges, leaderMap, messageHandler.call(_)) } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). 
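For context on the hunks above that drop @Experimental from the batch API: KafkaUtils.createRDD is driven entirely by explicit OffsetRange values. A minimal sketch of that usage, assuming a hypothetical local broker, topic name, and offset bounds (not taken from this patch):

    import kafka.serializer.StringDecoder

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

    object KafkaRDDSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("KafkaRDDSketch"))
        // Hypothetical broker and offsets; each OffsetRange covers [fromOffset, untilOffset).
        val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
        val offsetRanges = Array(OffsetRange.create("topic1", 0, 0L, 100L))
        val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
          sc, kafkaParams, offsetRanges)
        println(rdd.count())
        sc.stop()
      }
    }

The batch read is deterministic because the offsets are fixed up front, which is why no receiver or write-ahead log is involved.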
@@ -357,7 +349,6 @@ object KafkaUtils { * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createDirectStream[ K: ClassTag, V: ClassTag, @@ -375,7 +366,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -405,7 +395,6 @@ object KafkaUtils { * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume */ - @Experimental def createDirectStream[ K: ClassTag, V: ClassTag, @@ -437,7 +426,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -472,7 +460,6 @@ object KafkaUtils { * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jssc: JavaStreamingContext, keyClass: Class[K], @@ -492,14 +479,13 @@ object KafkaUtils { val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) createDirectStream[K, V, KD, VD, R]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), + Map(kafkaParams.asScala.toSeq: _*), + Map(fromOffsets.asScala.mapValues(_.longValue()).toSeq: _*), cleanedHandler ) } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). 
@@ -533,7 +519,6 @@ object KafkaUtils { * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume */ - @Experimental def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jssc: JavaStreamingContext, keyClass: Class[K], @@ -549,8 +534,8 @@ object KafkaUtils { implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) createDirectStream[K, V, KD, VD]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Set(topics.toSeq: _*) + Map(kafkaParams.asScala.toSeq: _*), + Set(topics.asScala.toSeq: _*) ) } } @@ -564,7 +549,7 @@ object KafkaUtils { * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() * takes care of known parameters instead of passing them from Python */ -private class KafkaUtilsPythonHelper { +private[kafka] class KafkaUtilsPythonHelper { def createStream( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], @@ -620,10 +605,10 @@ private class KafkaUtilsPythonHelper { ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { if (!fromOffsets.isEmpty) { - import scala.collection.JavaConversions._ - val topicsFromOffsets = fromOffsets.keySet().map(_.topic) - if (topicsFromOffsets != topics.toSet) { - throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic) + if (topicsFromOffsets != topics.asScala.toSet) { + throw new IllegalStateException( + s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") } } @@ -670,4 +655,17 @@ private class KafkaUtilsPythonHelper { TopicAndPartition(topic, partition) def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq.asJava + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 267504266630..8a5f37149451 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -19,10 +19,7 @@ package org.apache.spark.streaming.kafka import kafka.common.TopicAndPartition -import org.apache.spark.annotation.Experimental - /** - * :: Experimental :: * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the * offset ranges in RDDs generated by the direct Kafka DStream (see * [[KafkaUtils.createDirectStream()]]). @@ -33,25 +30,22 @@ import org.apache.spark.annotation.Experimental * } * }}} */ -@Experimental trait HasOffsetRanges { def offsetRanges: Array[OffsetRange] } /** - * :: Experimental :: * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class * can be created with `OffsetRange.create()`. 
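The HasOffsetRanges trait documented above is the hook that lets callers of the direct stream recover exactly which offsets each batch covered. A hedged sketch of that pattern, with a hypothetical broker and topic; note the cast only works on the RDDs produced directly by the stream, before any transformation:

    import kafka.serializer.StringDecoder

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils}

    object DirectKafkaSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("DirectKafkaSketch")
        val ssc = new StreamingContext(conf, Seconds(2))
        // Hypothetical broker and topic.
        val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
        val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
          ssc, kafkaParams, Set("topic1"))
        stream.foreachRDD { rdd =>
          // Each batch RDD is backed by a KafkaRDD, so the offsets it covers can be read back.
          val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
          ranges.foreach(o => println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}"))
        }
        ssc.start()
        ssc.awaitTermination()
      }
    }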
+ * @param topic Kafka topic name + * @param partition Kafka partition id + * @param fromOffset Inclusive starting offset + * @param untilOffset Exclusive ending offset */ -@Experimental final class OffsetRange private( - /** Kafka topic name */ val topic: String, - /** Kafka partition id */ val partition: Int, - /** inclusive starting offset */ val fromOffset: Long, - /** exclusive ending offset */ val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple @@ -75,7 +69,7 @@ final class OffsetRange private( } override def toString(): String = { - s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" } /** this is to avoid ClassNotFoundException during checkpoint restore */ @@ -84,10 +78,8 @@ final class OffsetRange private( } /** - * :: Experimental :: * Companion object the provides methods to create instances of [[OffsetRange]]. */ -@Experimental object OffsetRange { def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = new OffsetRange(topic, partition, fromOffset, untilOffset) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index 75f0dfc22b9d..764d170934aa 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -96,7 +96,7 @@ class ReliableKafkaReceiver[ blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() // Initialize the block generator for storing Kafka message. 
- blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 02cd24a35906..fbdfbf7e509b 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -70,16 +70,16 @@ public void testKafkaStream() throws InterruptedException { final String topic1 = "topic1"; final String topic2 = "topic2"; // hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference(); + final AtomicReference offsetRanges = new AtomicReference<>(); String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); - HashSet sent = new HashSet(); + Set sent = new HashSet<>(); sent.addAll(Arrays.asList(topic1data)); sent.addAll(Arrays.asList(topic2data)); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); @@ -95,17 +95,17 @@ public void testKafkaStream() throws InterruptedException { // Make sure you can get offset ranges from the rdd new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd) { OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); offsetRanges.set(offsets); - Assert.assertEquals(offsets[0].topic(), topic1); + Assert.assertEquals(topic1, offsets[0].topic()); return rdd; } } ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -119,10 +119,10 @@ public String call(Tuple2 kv) throws Exception { StringDecoder.class, String.class, kafkaParams, - topicOffsetToMap(topic2, (long) 0), + topicOffsetToMap(topic2, 0L), new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -133,7 +133,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception unifiedStream.foreachRDD( new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public Void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( @@ -155,14 +155,14 @@ public Void call(JavaRDD rdd) throws Exception { ssc.stop(); } - private HashSet topicToSet(String topic) { - HashSet topicSet = new HashSet(); + private static Set topicToSet(String topic) { + Set topicSet = new HashSet<>(); topicSet.add(topic); return topicSet; } - private HashMap topicOffsetToMap(String topic, Long offsetToStart) { - HashMap topicMap = new HashMap(); + private static Map topicOffsetToMap(String topic, Long offsetToStart) { + Map topicMap = new HashMap<>(); topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); return topicMap; } diff --git 
a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613c..afcc6cfccd39 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.HashMap; +import java.util.Map; import scala.Tuple2; @@ -66,10 +67,10 @@ public void testKafkaRDD() throws InterruptedException { String topic1 = "topic1"; String topic2 = "topic2"; - String[] topic1data = createTopicAndSendData(topic1); - String[] topic2data = createTopicAndSendData(topic2); + createTopicAndSendData(topic1); + createTopicAndSendData(topic2); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { @@ -77,8 +78,8 @@ public void testKafkaRDD() throws InterruptedException { OffsetRange.create(topic2, 0, 0, 1) }; - HashMap emptyLeaders = new HashMap(); - HashMap leaders = new HashMap(); + Map emptyLeaders = new HashMap<>(); + Map leaders = new HashMap<>(); String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); @@ -95,7 +96,7 @@ public void testKafkaRDD() throws InterruptedException { ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -113,7 +114,7 @@ public String call(Tuple2 kv) throws Exception { emptyLeaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -131,7 +132,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception leaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index e4c659215b76..1e69de46cd35 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -67,10 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { String topic = "topic1"; - HashMap topics = new HashMap(); + Map topics = new HashMap<>(); topics.put(topic, 1); - HashMap sent = new HashMap(); + Map sent = new HashMap<>(); sent.put("a", 5); sent.put("b", 3); sent.put("c", 10); @@ -78,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { kafkaTestUtils.createTopic(topic); kafkaTestUtils.sendMessages(topic, sent); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -97,7 +97,7 @@ public void testKafkaStream() throws InterruptedException { JavaDStream words = stream.map( new Function, String>() { 
@Override - public String call(Tuple2 tuple2) throws Exception { + public String call(Tuple2 tuple2) { return tuple2._2(); } } @@ -106,7 +106,7 @@ public String call(Tuple2 tuple2) throws Exception { words.countByValue().foreachRDD( new Function, Void>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public Void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -130,8 +130,8 @@ public Void call(JavaPairRDD rdd) throws Exception { Thread.sleep(200); } Assert.assertEquals(sent.size(), result.size()); - for (String k : sent.keySet()) { - Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + for (Map.Entry e : sent.entrySet()) { + Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 8e1715f6dbb9..02225d5aa7cc 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.streaming.kafka import java.io.File import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.scheduler.rate.RateEstimator + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -111,7 +114,7 @@ class DirectKafkaStreamSuite rdd }.foreachRDD { rdd => for (o <- offsetRanges) { - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, @@ -350,6 +353,77 @@ class DirectKafkaStreamSuite ssc.stop() } + test("using rate controller") { + val topic = "backpressure" + val topicPartition = TopicAndPartition(topic, 0) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 100 + val estimator = new ConstantEstimator(100) + val messageKeys = (1 to 200).map(_.toString) + val messages = messageKeys.map((_, 1)).toMap + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. 
+ .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(Set(topicPartition)) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, estimator)) + } + } + + val collectedData = + new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]] + + // Used for assertion failure messages. + def dataToString: String = + collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + collectedData += data + } + + ssc.start() + + // Try different rate limits. + // Send data to Kafka and wait for arrays of data to appear matching the rate. + Seq(100, 50, 20).foreach { rate => + collectedData.clear() // Empty this buffer on each pass. + estimator.updateRate(rate) // Set a new rate. + // Expect blocks of data equal to "rate", scaled by the interval length in secs. + val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001) + kafkaTestUtils.sendMessages(topic, messages) + eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) { + // Assert that rate estimator values are used to determine maxMessagesPerPartition. + // Funky "-" in message makes the complete assertion message read better. 
+ assert(collectedData.exists(_.size == expectedSize), + s" - No arrays of size $expectedSize for rate $rate found in $dataToString") + } + } + + ssc.stop() + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { @@ -381,3 +455,18 @@ object DirectKafkaStreamSuite { } } } + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} + diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml new file mode 100644 index 000000000000..89713a28ca6a --- /dev/null +++ b/external/mqtt-assembly/pom.xml @@ -0,0 +1,175 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-mqtt-assembly_2.10 + jar + Spark Project External MQTT Assembly + http://spark.apache.org/ + + + streaming-mqtt-assembly + + + + + org.apache.spark + spark-streaming-mqtt_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e5781784..05e6338a08b0 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -78,5 +78,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + test-jar-with-dependencies + package + + single + + + + spark-streaming-mqtt-test-${project.version} + ${project.build.directory}/scala-${scala.binary.version}/ + false + + false + + src/main/assembly/assembly.xml + + + + + + diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 000000000000..ecab5b360eb3 --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ + + + test-jar-with-dependencies + + jar + + false + + + + ${project.build.directory}/scala-${scala.binary.version}/test-classes + / + + + + + + true + test + true + + org.apache.hadoop:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar + + + + + diff --git 
a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 7c2f18cb35bd..116c170489e9 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class MQTTInputDStream( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, brokerUrl: String, topic: String, storageLevel: StorageLevel diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56ba3..7b8d56d6faf2 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -21,8 +21,8 @@ import scala.reflect.ClassTag import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object MQTTUtils { /** @@ -74,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. 
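As a usage note for the Python helper described above: it simply delegates to the Scala-side MQTTUtils.createStream entry point. A minimal sketch of that call, assuming a hypothetical local MQTT broker and topic:

    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.mqtt.MQTTUtils

    object MQTTSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("MQTTSketch")
        val ssc = new StreamingContext(conf, Seconds(1))
        // Hypothetical broker URL and topic; each MQTT payload arrives as one String element.
        val lines = MQTTUtils.createStream(
          ssc, "tcp://localhost:1883", "sensors/temperature", StorageLevel.MEMORY_AND_DISK_SER_2)
        lines.print()
        ssc.start()
        ssc.awaitTermination()
      }
    }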
+ */ +private[mqtt] class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index c4bf5aa7869b..a6a9249db8ed 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,46 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { @@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. 
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. - */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 000000000000..1618e2c088b7 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private[mqtt] class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 178ae8de13b5..244ad58ae959 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 7cf02d85d73d..d7de74b35054 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class TwitterInputDStream( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, twitterAuth: Option[Authorization], filters: Seq[String], storageLevel: StorageLevel diff --git 
a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java index e46b4e5c7531..26ec8af455bc 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.streaming.twitter; -import java.util.Arrays; - import org.junit.Test; import twitter4j.Status; import twitter4j.auth.Authorization; @@ -30,7 +28,7 @@ public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { @Test public void testTwitterStream() { - String[] filters = (String[])Arrays.asList("filter1", "filter2").toArray(); + String[] filters = { "filter1", "filter2" }; Authorization auth = NullAuthorization.getInstance(); // tests the API, does not actually test data receiving diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 37bfd10d4366..171df8682c84 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 0469d0af8864..4ea218eaa4de 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -18,15 +18,17 @@ package org.apache.spark.streaming.zeromq import scala.reflect.ClassTag -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ + import akka.actor.{Props, SupervisorStrategy} import akka.util.ByteString import akka.zeromq.Subscribe + import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream} +import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.ActorSupervisorStrategy object ZeroMQUtils { @@ -75,7 +77,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) } @@ -99,7 +102,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) } @@ -122,7 +126,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => 
bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn) } } diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index f138251748c9..81794a853631 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 729bc0459ce5..14975265ab2c 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -77,7 +77,7 @@ public void call(String s) { public void foreach() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach((x) -> foreachCalls++); + rdd.foreach(x -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } @@ -180,7 +180,7 @@ public void map() { JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) .cache(); pairs.collect(); - JavaRDD strings = rdd.map(x -> x.toString()).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -195,7 +195,9 @@ public void flatMap() { JavaPairRDD pairs = rdd.flatMapToPair(s -> { List> pairs2 = new LinkedList<>(); - for (String word : s.split(" ")) pairs2.add(new Tuple2<>(word, word)); + for (String word : s.split(" ")) { + pairs2.add(new Tuple2<>(word, word)); + } return pairs2; }); @@ -204,11 +206,12 @@ public void flatMap() { JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { List lengths = new LinkedList<>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; }); - Double x = doubles.first(); Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } @@ -228,7 +231,7 @@ public void mapsFromPairsToPairs() { swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.map(item -> item.swap()).collect(); + pairRDD.map(Tuple2::swap).collect(); } @Test @@ -282,11 +285,11 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = (Iterator i, Iterator s) -> { int sizeI = 0; - int sizeS = 0; while (i.hasNext()) { sizeI += 1; i.next(); } + int sizeS = 0; while (s.hasNext()) { sizeS += 1; s.next(); @@ -301,30 +304,31 @@ public void zipPartitions() { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(x -> intAccum.add(x)); + Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(intAccum::add); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(x -> doubleAccum.add((double) x)); Assert.assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override public Float addInPlace(Float r, Float t) { return r 
+ t; } - + @Override public Float addAccumulator(Float r, Float t) { return r + t; } - + @Override public Float zero(Float initialValue) { return 0.0f; } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(x -> floatAccum.add((float) x)); Assert.assertEquals((Float) 25.0f, floatAccum.value()); @@ -336,7 +340,7 @@ public Float zero(Float initialValue) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(x -> x.toString()).collect(); + List> s = rdd.keyBy(Object::toString).collect(); Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -349,7 +353,7 @@ public void mapOnPairRDD() { JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), + new Tuple2<>(1, 1), new Tuple2<>(0, 2), new Tuple2<>(1, 3), new Tuple2<>(0, 4)), rdd3.collect()); @@ -361,7 +365,7 @@ public void collectPartitions() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - List[] parts = rdd1.collectPartitions(new int[]{0}); + List[] parts = rdd1.collectPartitions(new int[]{0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[]{1, 2}); @@ -371,19 +375,19 @@ public void collectPartitions() { Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[]{0})[0]); - parts = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts[0]); + List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), - parts[1]); + parts2[1]); } @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine - Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + pairRDD.collectAsMap(); // Used to crash with ClassCastException } } diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml new file mode 100644 index 000000000000..61ba4787fbf9 --- /dev/null +++ b/extras/kinesis-asl-assembly/pom.xml @@ -0,0 +1,181 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kinesis-asl-assembly_2.10 + jar + Spark Project Kinesis Assembly + http://spark.apache.org/ + + + streaming-kinesis-asl-assembly + + + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + com.fasterxml.jackson.core + jackson-databind + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.avro + avro-mapred + 
${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c6f60bc90743..6dd8ff69c294 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -31,7 +31,7 @@ Spark Kinesis Integration - kinesis-asl + streaming-kinesis-asl @@ -66,7 +66,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py new file mode 100644 index 000000000000..f428f64da3c4 --- /dev/null +++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Consumes messages from a Amazon Kinesis streams and does wordcount. + + This example spins up 1 Kinesis Receiver per shard for the given stream. + It then starts pulling from the last checkpointed sequence number of the given stream. + + Usage: kinesis_wordcount_asl.py + is the name of the consumer app, used to track the read data in DynamoDB + name of the Kinesis stream (ie. mySparkStream) + endpoint of the Kinesis service + (e.g. https://kinesis.us-east-1.amazonaws.com) + + + Example: + # export AWS keys if necessary + $ export AWS_ACCESS_KEY_ID= + $ export AWS_SECRET_KEY= + + # run the example + $ bin/spark-submit -jar extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com + + There is a companion helper class called KinesisWordProducerASL which puts dummy data + onto the Kinesis stream. 
+ + This code uses the DefaultAWSCredentialsProviderChain to find credentials + in the following order: + Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + Java System Properties - aws.accessKeyId and aws.secretKey + Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + Instance profile credentials - delivered through the Amazon EC2 metadata service + For more information, see + http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + + See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + the Kinesis Spark Streaming integration. +""" +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + +if __name__ == "__main__": + if len(sys.argv) != 5: + print( + "Usage: kinesis_wordcount_asl.py ", + file=sys.stderr) + sys.exit(-1) + + sc = SparkContext(appName="PythonStreamingKinesisWordCountAsl") + ssc = StreamingContext(sc, 1) + appName, streamName, endpointUrl, regionName = sys.argv[1:] + lines = KinesisUtils.createStream( + ssc, appName, streamName, endpointUrl, regionName, InitialPositionInStream.LATEST, 2) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index be8b62d3cc6b..de749626ec09 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.nio.ByteBuffer @@ -272,3 +273,4 @@ private[streaming] object StreamingExamples extends Logging { } } } +// scalastyle:on println diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala new file mode 100644 index 000000000000..5d32fa699ae5 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark._ +import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.NextIterator + + +/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */ +private[kinesis] +case class SequenceNumberRange( + streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + +/** Class representing an array of Kinesis sequence number ranges */ +private[kinesis] +case class SequenceNumberRanges(ranges: Seq[SequenceNumberRange]) { + def isEmpty(): Boolean = ranges.isEmpty + + def nonEmpty(): Boolean = ranges.nonEmpty + + override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")") +} + +private[kinesis] +object SequenceNumberRanges { + def apply(range: SequenceNumberRange): SequenceNumberRanges = { + new SequenceNumberRanges(Seq(range)) + } +} + + +/** Partition storing the information of the ranges of Kinesis sequence numbers to read */ +private[kinesis] +class KinesisBackedBlockRDDPartition( + idx: Int, + blockId: BlockId, + val isBlockIdValid: Boolean, + val seqNumberRanges: SequenceNumberRanges + ) extends BlockRDDPartition(blockId, idx) + +/** + * A BlockRDD where the block data is backed by Kinesis, which can accessed using the + * sequence numbers of the corresponding blocks. + */ +private[kinesis] +class KinesisBackedBlockRDD( + @transient sc: SparkContext, + val regionName: String, + val endpointUrl: String, + @transient blockIds: Array[BlockId], + @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], + @transient isBlockIdValid: Array[Boolean] = Array.empty, + val retryTimeoutMs: Int = 10000, + val awsCredentialsOption: Option[SerializableAWSCredentials] = None + ) extends BlockRDD[Array[Byte]](sc, blockIds) { + + require(blockIds.length == arrayOfseqNumberRanges.length, + "Number of blockIds is not equal to the number of sequence number ranges") + + override def isValid(): Boolean = true + + override def getPartitions: Array[Partition] = { + Array.tabulate(blockIds.length) { i => + val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) + new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val blockManager = SparkEnv.get.blockManager + val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition] + val blockId = partition.blockId + + def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = { + logDebug(s"Read partition data of $this from block manager, block $blockId") + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]]) + } + + def getBlockFromKinesis(): Iterator[Array[Byte]] = { + val credenentials = awsCredentialsOption.getOrElse { + new DefaultAWSCredentialsProviderChain().getCredentials() + } + partition.seqNumberRanges.ranges.iterator.flatMap { range => + new KinesisSequenceRangeIterator( + credenentials, endpointUrl, regionName, range, retryTimeoutMs) + } + } + if (partition.isBlockIdValid) { + getBlockFromBlockManager().getOrElse { getBlockFromKinesis() } + } else { + getBlockFromKinesis() + } + } +} 
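To make the relationship between the partition metadata and the RDD above concrete, here is a hedged sketch of the wiring. The stream name, shard id, sequence numbers, region, and endpoint are all made up, and since the class is private[kinesis] this is an illustration of intent rather than a public API example:

    package org.apache.spark.streaming.kinesis

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.storage.{BlockId, StreamBlockId}

    object KinesisBackedBlockRDDSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[*]").setAppName("KinesisRDDSketch")
        val sc = new SparkContext(conf)

        // One received block, tracked both by its BlockId in the BlockManager and by the
        // Kinesis sequence-number range it covers (all identifiers below are hypothetical).
        val blockIds = Array[BlockId](StreamBlockId(0, 1L))
        val ranges = Array(SequenceNumberRanges(
          SequenceNumberRange("myStream", "shardId-000000000000", "seq-0001", "seq-0100")))

        val rdd = new KinesisBackedBlockRDD(
          sc,
          regionName = "us-east-1",
          endpointUrl = "https://kinesis.us-east-1.amazonaws.com",
          blockIds = blockIds,
          arrayOfseqNumberRanges = ranges)

        // Each partition is served from the BlockManager while the block is still present;
        // otherwise the same records are re-read from Kinesis via KinesisSequenceRangeIterator.
        println(rdd.map(_.length).sum())
        sc.stop()
      }
    }

Carrying both the block id and the sequence-number range is what keeps the data recomputable after the in-memory block is evicted, which appears to be the point of the BlockManager-then-Kinesis fallback in compute().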
+ + +/** + * An iterator that return the Kinesis data based on the given range of sequence numbers. + * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber, + * until the endSequenceNumber is reached. + */ +private[kinesis] +class KinesisSequenceRangeIterator( + credentials: AWSCredentials, + endpointUrl: String, + regionId: String, + range: SequenceNumberRange, + retryTimeoutMs: Int + ) extends NextIterator[Array[Byte]] with Logging { + + private val client = new AmazonKinesisClient(credentials) + private val streamName = range.streamName + private val shardId = range.shardId + + private var toSeqNumberReceived = false + private var lastSeqNumber: String = null + private var internalIterator: Iterator[Record] = null + + client.setEndpoint(endpointUrl, "kinesis", regionId) + + override protected def getNext(): Array[Byte] = { + var nextBytes: Array[Byte] = null + if (toSeqNumberReceived) { + finished = true + } else { + + if (internalIterator == null) { + + // If the internal iterator has not been initialized, + // then fetch records from starting sequence number + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + } else if (!internalIterator.hasNext) { + + // If the internal iterator does not have any more records, + // then fetch more records after the last consumed sequence number + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + } + + if (!internalIterator.hasNext) { + + // If the internal iterator still does not have any data, then throw exception + // and terminate this iterator + finished = true + throw new SparkException( + s"Could not read until the end sequence number of the range: $range") + } else { + + // Get the record, copy the data into a byte array and remember its sequence number + val nextRecord: Record = internalIterator.next() + val byteBuffer = nextRecord.getData() + nextBytes = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(nextBytes) + lastSeqNumber = nextRecord.getSequenceNumber() + + // If the this record's sequence number matches the stopping sequence number, then make sure + // the iterator is marked finished next time getNext() is called + if (nextRecord.getSequenceNumber == range.toSeqNumber) { + toSeqNumberReceived = true + } + } + + } + nextBytes + } + + override protected def close(): Unit = { + client.shutdown() + } + + /** + * Get records starting from or after the given sequence number. + */ + private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + val shardIterator = getKinesisIterator(iteratorType, seqNum) + val result = getRecordsAndNextKinesisIterator(shardIterator) + result._1 + } + + /** + * Get the records starting from using a Kinesis shard iterator (which is a progress handle + * to get records from Kinesis), and get the next shard iterator for next consumption. + */ + private def getRecordsAndNextKinesisIterator( + shardIterator: String): (Iterator[Record], String) = { + val getRecordsRequest = new GetRecordsRequest + getRecordsRequest.setRequestCredentials(credentials) + getRecordsRequest.setShardIterator(shardIterator) + val getRecordsResult = retryOrTimeout[GetRecordsResult]( + s"getting records using shard iterator") { + client.getRecords(getRecordsRequest) + } + (getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) + } + + /** + * Get the Kinesis shard iterator for getting records starting from or after the given + * sequence number. 
+ */ + private def getKinesisIterator( + iteratorType: ShardIteratorType, + sequenceNumber: String): String = { + val getShardIteratorRequest = new GetShardIteratorRequest + getShardIteratorRequest.setRequestCredentials(credentials) + getShardIteratorRequest.setStreamName(streamName) + getShardIteratorRequest.setShardId(shardId) + getShardIteratorRequest.setShardIteratorType(iteratorType.toString) + getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) + val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult]( + s"getting shard iterator from sequence number $sequenceNumber") { + client.getShardIterator(getShardIteratorRequest) + } + getShardIteratorResult.getShardIterator + } + + /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ + private def retryOrTimeout[T](message: String)(body: => T): T = { + import KinesisSequenceRangeIterator._ + + var startTimeMs = System.currentTimeMillis() + var retryCount = 0 + var waitTimeMs = MIN_RETRY_WAIT_TIME_MS + var result: Option[T] = None + var lastError: Throwable = null + + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs + def isMaxRetryDone = retryCount >= MAX_RETRIES + + while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { + if (retryCount > 0) { // wait only if this is a retry + Thread.sleep(waitTimeMs) + waitTimeMs *= 2 // if you have waited, then double wait time for next round + } + try { + result = Some(body) + } catch { + case NonFatal(t) => + lastError = t + t match { + case ptee: ProvisionedThroughputExceededException => + logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) + case e: Throwable => + throw new SparkException(s"Error while $message", e) + } + } + retryCount += 1 + } + result.getOrElse { + if (isTimedOut) { + throw new SparkException( + s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + } else { + throw new SparkException( + s"Gave up after $retryCount retries while $message, last exception: ", lastError) + } + } + } +} + +private[streaming] +object KinesisSequenceRangeIterator { + val MAX_RETRIES = 3 + val MIN_RETRY_WAIT_TIME_MS = 100 +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala new file mode 100644 index 000000000000..2e4204dcb6f1 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.{Duration, StreamingContext, Time} + +private[kinesis] class KinesisInputDStream( + @transient _ssc: StreamingContext, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointAppName: String, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends ReceiverInputDStream[Array[Byte]](_ssc) { + + private[streaming] + override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = { + + // This returns true even for when blockInfos is empty + val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty) + + if (allBlocksHaveRanges) { + // Create a KinesisBackedBlockRDD, even when there are no blocks + val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray + val seqNumRanges = blockInfos.map { + _.metadataOption.get.asInstanceOf[SequenceNumberRanges] }.toArray + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + logDebug(s"Creating KinesisBackedBlockRDD for $time with ${seqNumRanges.length} " + + s"seq number ranges: ${seqNumRanges.mkString(", ")} ") + new KinesisBackedBlockRDD( + context.sc, regionName, endpointUrl, blockIds, seqNumRanges, + isBlockIdValid = isBlockIdValid, + retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, + awsCredentialsOption = awsCredentialsOption) + } else { + logWarning("Kinesis sequence number information was not present with some block metadata," + + " it may not be possible to recover from failures") + super.createBlockRDD(time, blockInfos) + } + } + + override def getReceiver(): Receiver[Array[Byte]] = { + new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, + checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1a8a4cecc114..6e0988c1af8a 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -18,17 +18,20 @@ package org.apache.spark.streaming.kinesis import java.util.UUID +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} +import com.amazonaws.services.kinesis.model.Record -import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel +import 
org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkEnv} private[kinesis] @@ -42,38 +45,47 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: * https://github.com/awslabs/amazon-kinesis-client - * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: - * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers to run within a - * Spark Executor. * - * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams - * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing - * DynamoDB table with the same name this Kinesis application. + * The way this Receiver works is as follows: + * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * KinesisRecordProcessor + * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is + * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. + * - When the block generator defines a block, then the recorded sequence number ranges that were + * inserted into the block are recorded separately for being used later. + * - When the block is ready to be pushed, the block is pushed and the ranges are reported as + * metadata of the block. In addition, the ranges are used to find out the latest sequence + * number for each shard that can be checkpointed through the DynamoDB. + * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence + * number for it own shard. + * * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Region name used by the Kinesis Client Library for * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointAppName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams + * by the Kinesis Client Library. If you change the App name or Stream name, + * the KCL will throw errors. This usually requires deleting the backing + * DynamoDB table with the same name this Kinesis application. + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. 
* @param storageLevel Storage level to use for storing the received objects * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies * the credentials */ private[kinesis] class KinesisReceiver( - appName: String, - streamName: String, + val streamName: String, endpointUrl: String, regionName: String, initialPositionInStream: InitialPositionInStream, + checkpointAppName: String, checkpointInterval: Duration, storageLevel: StorageLevel, awsCredentialsOption: Option[SerializableAWSCredentials] @@ -90,7 +102,7 @@ private[kinesis] class KinesisReceiver( * workerId is used by the KCL should be based on the ip address of the actual Spark Worker * where this code runs (not the driver's IP address.) */ - private var workerId: String = null + @volatile private var workerId: String = null /** * Worker is the core client abstraction from the Kinesis Client Library (KCL). @@ -98,22 +110,40 @@ private[kinesis] class KinesisReceiver( * Each shard is assigned its own IRecordProcessor and the worker run multiple such * processors. */ - private var worker: Worker = null + @volatile private var worker: Worker = null + @volatile private var workerThread: Thread = null - /** Thread running the worker */ - private var workerThread: Thread = null + /** BlockGenerator used to generates blocks out of Kinesis data */ + @volatile private var blockGenerator: BlockGenerator = null + /** + * Sequence number ranges added to the current block being generated. + * Accessing and updating of this map is synchronized by locks in BlockGenerator. + */ + private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange] + + /** Sequence number ranges of data added to each generated block */ + private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] + with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + + /** + * Latest sequence number ranges that have been stored successfully. + * This is used for checkpointing through KCL */ + private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String] + with mutable.SynchronizedMap[String, String] /** * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). 
*/ override def onStart() { + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) + workerId = Utils.localHostName() + ":" + UUID.randomUUID() // KCL config instance val awsCredProvider = resolveAWSCredentialsProvider() val kinesisClientLibConfiguration = - new KinesisClientLibConfiguration(appName, streamName, awsCredProvider, workerId) + new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId) .withKinesisEndpoint(endpointUrl) .withInitialPositionInStream(initialPositionInStream) .withTaskBackoffTimeMillis(500) @@ -141,6 +171,10 @@ private[kinesis] class KinesisReceiver( } } } + + blockIdToSeqNumRanges.clear() + blockGenerator.start() + workerThread.setName(s"Kinesis Receiver ${streamId}") workerThread.setDaemon(true) workerThread.start() @@ -165,6 +199,81 @@ private[kinesis] class KinesisReceiver( workerId = null } + /** Add records of the given shard to the current block being generated */ + private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { + if (records.size > 0) { + val dataIterator = records.iterator().asScala.map { record => + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } + val metadata = SequenceNumberRange(streamName, shardId, + records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) + blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) + + } + } + + /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ + private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { + shardIdToLatestStoredSeqNum.get(shardId) + } + + /** + * Remember the range of sequence numbers that was added to the currently active block. + * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`. + */ + private def rememberAddedRange(range: SequenceNumberRange): Unit = { + seqNumRangesInCurrentBlock += range + } + + /** + * Finalize the ranges added to the block that was active and prepare the ranges buffer + * for next block. Internally, this is synchronized with `rememberAddedRange()`. 
+ */ + private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { + blockIdToSeqNumRanges(blockId) = SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray) + seqNumRangesInCurrentBlock.clear() + logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") + } + + /** Store the block along with its associated ranges */ + private def storeBlockWithRanges( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = { + val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId) + if (rangesToReportOption.isEmpty) { + stop("Error while storing block into Spark, could not find sequence number ranges " + + s"for block $blockId") + return + } + + val rangesToReport = rangesToReportOption.get + var attempt = 0 + var stored = false + var throwable: Throwable = null + while (!stored && attempt <= 3) { + try { + store(arrayBuffer, rangesToReport) + stored = true + } catch { + case NonFatal(th) => + attempt += 1 + throwable = th + } + } + if (!stored) { + stop("Error while storing block into Spark", throwable) + } + + // Update the latest sequence number that have been successfully stored for each shard + // Note that we are doing this sequentially because the array of sequence number ranges + // is assumed to be + rangesToReport.ranges.foreach { range => + shardIdToLatestStoredSeqNum(range.shardId) = range.toSeqNumber + } + } + /** * If AWS credential is provided, return a AWSCredentialProvider returning that credential. * Otherwise, return the DefaultAWSCredentialsProviderChain. @@ -182,4 +291,46 @@ private[kinesis] class KinesisReceiver( new DefaultAWSCredentialsProviderChain() } } + + + /** + * Class to handle blocks generated by this receiver's block generator. Specifically, in + * the context of the Kinesis Receiver, this handler does the following. + * + * - When an array of records is added to the current active block in the block generator, + * this handler keeps track of the corresponding sequence number range. + * - When the currently active block is ready to sealed (not more records), this handler + * keep track of the list of ranges added into this block in another H + */ + private class GeneratedBlockHandler extends BlockGeneratorListener { + + /** + * Callback method called after a data item is added into the BlockGenerator. + * The data addition, block generation, and calls to onAddData and onGenerateBlock + * are all synchronized through the same lock. + */ + def onAddData(data: Any, metadata: Any): Unit = { + rememberAddedRange(metadata.asInstanceOf[SequenceNumberRange]) + } + + /** + * Callback method called after a block has been generated. + * The data addition, block generation, and calls to onAddData and onGenerateBlock + * are all synchronized through the same lock. + */ + def onGenerateBlock(blockId: StreamBlockId): Unit = { + finalizeRangesForCurrentBlock(blockId) + } + + /** Callback method called when a block is ready to be pushed / stored. 
*/ + def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + storeBlockWithRanges(blockId, + arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]]) + } + + /** Callback called in case of any error in internal of the BlockGenerator */ + def onError(message: String, throwable: Throwable): Unit = { + reportError(message, throwable) + } + } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index fe9e3a0c793e..b2405123321e 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -18,20 +18,16 @@ package org.apache.spark.streaming.kinesis import java.util.List -import scala.collection.JavaConversions.asScalaBuffer import scala.util.Random +import scala.util.control.NonFatal -import org.apache.spark.Logging - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.Logging + /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. @@ -51,6 +47,7 @@ private[kinesis] class KinesisRecordProcessor( checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { // shardId to be populated during initialize() + @volatile private var shardId: String = _ /** @@ -75,47 +72,38 @@ private[kinesis] class KinesisRecordProcessor( override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { - /* - * Notes: - * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming - * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the - * internally-configured Spark serializer (kryo, etc). - * 2) This is not desirable, so we instead store a raw Array[Byte] and decouple - * ourselves from Spark's internal serialization strategy. - * 3) For performance, the BlockGenerator is asynchronously queuing elements within its - * memory before creating blocks. This prevents the small block scenario, but requires - * that you register callbacks to know when a block has been generated and stored - * (WAL is sufficient for storage) before can checkpoint back to the source. 
- */ - batch.foreach(record => receiver.store(record.getData().array())) - - logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + receiver.addRecords(shardId, batch) + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored - * in the batch. - * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the - * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. - * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). - * However, if the worker dies unexpectedly, a checkpoint may not happen. - * This could lead to records being processed more than once. + * + * Checkpoint the sequence number of the last record successfully stored. + * Note that in this current implementation, the checkpointing occurs only when after + * checkpointIntervalMillis from the last checkpoint, AND when there is new record + * to process. This leads to the checkpointing lagging behind what records have been + * stored by the receiver. Ofcourse, this can lead records processed more than once, + * under failures and restarts. + * + * TODO: Instead of checkpointing here, run a separate timer task to perform + * checkpointing so that it checkpoints in a timely manner independent of whether + * new records are available or not. */ if (checkpointState.shouldCheckpoint()) { - /* Perform the checkpoint */ - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) - /* Update the next checkpoint time */ - checkpointState.advanceCheckpoint() + /* Update the next checkpoint time */ + checkpointState.advanceCheckpoint() - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + s" records for shardId $shardId") - logDebug(s"Checkpoint: Next checkpoint is at " + + logDebug(s"Checkpoint: Next checkpoint is at " + s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") + } } } catch { - case e: Throwable => { + case NonFatal(e) => { /* * If there is a failure within the batch, the batch will not be checkpointed. * This will potentially cause records since the last checkpoint to be processed @@ -130,7 +118,7 @@ private[kinesis] class KinesisRecordProcessor( } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } @@ -154,7 +142,11 @@ private[kinesis] class KinesisRecordProcessor( * It's now OK to read from the new shards that resulted from a resharding event. 
*/ case ShutdownReason.TERMINATE => - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId) + if (latestSeqNumToCheckpointOption.nonEmpty) { + KinesisRecordProcessor.retryRandom( + checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100) + } /* * ZOMBIE Use Case. NoOp. diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala new file mode 100644 index 000000000000..634bf9452107 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Random, Success, Try} + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient +import com.amazonaws.services.dynamodbv2.document.DynamoDB +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark.Logging + +/** + * Shared utility methods for performing Kinesis tests that actually transfer data + */ +private[kinesis] class KinesisTestUtils extends Logging { + + val endpointUrl = KinesisTestUtils.endpointUrl + val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val streamShardCount = 2 + + private val createStreamTimeoutSeconds = 300 + private val describeStreamPollTimeSeconds = 1 + + @volatile + private var streamCreated = false + + @volatile + private var _streamName: String = _ + + private lazy val kinesisClient = { + val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) + client.setEndpoint(endpointUrl) + client + } + + private lazy val dynamoDB = { + val dynamoDBClient = new AmazonDynamoDBClient(new DefaultAWSCredentialsProviderChain()) + dynamoDBClient.setRegion(RegionUtils.getRegion(regionName)) + new DynamoDB(dynamoDBClient) + } + + def streamName: String = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + _streamName + } + + def createStream(): Unit = { + require(!streamCreated, "Stream already created") + _streamName = findNonExistentStreamName() + + // Create a stream. The number of shards determines the provisioned throughput. 
+ logInfo(s"Creating stream ${_streamName}") + val createStreamRequest = new CreateStreamRequest() + createStreamRequest.setStreamName(_streamName) + createStreamRequest.setShardCount(2) + kinesisClient.createStream(createStreamRequest) + + // The stream is now being created. Wait for it to become active. + waitForStreamToBeActive(_streamName) + streamCreated = true + logInfo(s"Created stream ${_streamName}") + } + + /** + * Push data to Kinesis stream and return a map of + * shardId -> seq of (data, seq number) pushed to corresponding shard + */ + def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + + testData.foreach { num => + val str = num.toString + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(ByteBuffer.wrap(str.getBytes())) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") + shardIdToSeqNumbers.toMap + } + + /** + * Expose a Python friendly API. + */ + def pushData(testData: java.util.List[Int]): Unit = { + pushData(testData.asScala) + } + + def deleteStream(): Unit = { + try { + if (streamCreated) { + kinesisClient.deleteStream(streamName) + } + } catch { + case e: Exception => + logWarning(s"Could not delete stream $streamName") + } + } + + def deleteDynamoDBTable(tableName: String): Unit = { + try { + val table = dynamoDB.getTable(tableName) + table.delete() + table.waitForDelete() + } catch { + case e: Exception => + logWarning(s"Could not delete DynamoDB table $tableName") + } + } + + private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + + private def findNonExistentStreamName(): String = { + var testStreamName: String = null + do { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + testStreamName = s"KinesisTestUtils-${math.abs(Random.nextLong())}" + } while (describeStream(testStreamName).nonEmpty) + testStreamName + } + + private def waitForStreamToBeActive(streamNameToWaitFor: String): Unit = { + val startTime = System.currentTimeMillis() + val endTime = startTime + TimeUnit.SECONDS.toMillis(createStreamTimeoutSeconds) + while (System.currentTimeMillis() < endTime) { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + describeStream(streamNameToWaitFor).foreach { description => + val streamStatus = description.getStreamStatus() + logDebug(s"\t- current state: $streamStatus\n") + if ("ACTIVE".equals(streamStatus)) { + return + } + } + } + require(false, s"Stream $streamName never became active") + } +} + +private[kinesis] object KinesisTestUtils { + + val envVarNameForEnablingTests = "ENABLE_KINESIS_TESTS" + val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL" + val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + + lazy 
val shouldRunTests = { + val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1") + if (isEnvSet) { + // scalastyle:off println + // Print this so that they are easily visible on the console and not hidden in the log4j logs. + println( + s""" + |Kinesis tests that actually send data has been enabled by setting the environment + |variable $envVarNameForEnablingTests to 1. This will create Kinesis Streams and + |DynamoDB tables in AWS. Please be aware that this may incur some AWS costs. + |By default, the tests use the endpoint URL $defaultEndpointUrl to create Kinesis streams. + |To change this endpoint URL to a different region, you can set the environment variable + |$endVarNameForEndpoint to the desired endpoint URL + |(e.g. $endVarNameForEndpoint="https://kinesis.us-west-2.amazonaws.com"). + """.stripMargin) + // scalastyle:on println + } + isEnvSet + } + + lazy val endpointUrl = { + val url = sys.env.getOrElse(endVarNameForEndpoint, defaultEndpointUrl) + // scalastyle:off println + // Print this so that they are easily visible on the console and not hidden in the log4j logs. + println(s"Using endpoint URL $url for creating Kinesis streams for tests.") + // scalastyle:on println + url + } + + def isAWSCredentialsPresent: Boolean = { + Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess + } + + def getAWSCredentials(): AWSCredentials = { + assert(shouldRunTests, + "Kinesis test not enabled, should not attempt to get AWS credentials") + Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match { + case Success(cred) => cred + case Failure(e) => + throw new Exception( + s""" + |Kinesis tests enabled using environment variable $envVarNameForEnablingTests + |but could not find AWS credentials. Please follow instructions in AWS documentation + |to set the credentials in your system such that the DefaultAWSCredentialsProviderChain + |can find the credentials. + """.stripMargin) + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index e5acab50181e..c799fadf2d5c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -65,9 +65,8 @@ object KinesisUtils { ): ReceiverInputDStream[Array[Byte]] = { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { - ssc.receiverStream( - new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, checkpointInterval, storageLevel, None)) + new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None) } } @@ -86,19 +85,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. 
- * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( ssc: StreamingContext, @@ -112,10 +111,11 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String ): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream( - new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, checkpointInterval, storageLevel, - Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))) + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + } } /** @@ -130,7 +130,7 @@ object KinesisUtils { * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in * [[org.apache.spark.SparkConf]]. * - * @param ssc Java StreamingContext object + * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Endpoint url of Kinesis service * (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -155,9 +155,10 @@ object KinesisUtils { initialPositionInStream: InitialPositionInStream, storageLevel: StorageLevel ): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream( - new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), - initialPositionInStream, checkpointInterval, storageLevel, None)) + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), + initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None) + } } /** @@ -175,15 +176,15 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. 
+ * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ @@ -206,8 +207,8 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library @@ -216,19 +217,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( jssc: JavaStreamingContext, @@ -297,3 +298,49 @@ object KinesisUtils { } } } + +/** + * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's KinesisUtils. + */ +private class KinesisUtilsPythonHelper { + + def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = { + initialPositionInStream match { + case 0 => InitialPositionInStream.LATEST + case 1 => InitialPositionInStream.TRIM_HORIZON + case _ => throw new IllegalArgumentException( + "Illegal InitialPositionInStream. 
Please use " + + "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") + } + } + + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: Int, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + if (awsAccessKeyId == null && awsSecretKey != null) { + throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") + } + if (awsAccessKeyId != null && awsSecretKey == null) { + throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") + } + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala new file mode 100644 index 000000000000..a89e5627e014 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.{SparkConf, SparkContext, SparkException} + +class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { + + private val testData = 1 to 8 + + private var testUtils: KinesisTestUtils = null + private var shardIds: Seq[String] = null + private var shardIdToData: Map[String, Seq[Int]] = null + private var shardIdToSeqNumbers: Map[String, Seq[String]] = null + private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null + private var shardIdToRange: Map[String, SequenceNumberRange] = null + private var allRanges: Seq[SequenceNumberRange] = null + + private var sc: SparkContext = null + private var blockManager: BlockManager = null + + + override def beforeAll(): Unit = { + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils() + testUtils.createStream() + + shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") + + shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq + shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }} + shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} + shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => + val seqNumRange = SequenceNumberRange( + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + (shardId, seqNumRange) + } + allRanges = shardIdToRange.values.toSeq + + val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") + sc = new SparkContext(conf) + blockManager = sc.env.blockManager + } + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.deleteStream() + } + if (sc != null) { + sc.stop() + } + } + + testIfEnabled("Basic reading from Kinesis") { + // Verify all data using multiple ranges in a single RDD partition + val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, + fakeBlockIds(1), + Array(SequenceNumberRanges(allRanges.toArray)) + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData1.toSet === testData.toSet) + + // Verify all data using one range in each of the multiple RDD partitions + val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData2.toSet === testData.toSet) + + // Verify ordering within each partition + val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collectPartitions() + assert(receivedData3.length === allRanges.size) + for (i <- 0 until allRanges.size) { + assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId)) + } + } + + testIfEnabled("Read data available in both block manager and Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available only in block manager, not in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0) + } + + 
testIfEnabled("Read data available only in Kinesis, not in block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available partially in block manager, rest in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1) + } + + testIfEnabled("Test isBlockValid skips block fetching from block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0, + testIsBlockValid = true) + } + + testIfEnabled("Test whether RDD is valid after removing blocks from block anager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2, + testBlockRemove = true) + } + + /** + * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager + * and the rest to a write ahead log, and then reading reading it all back using the RDD. + * It can also test if the partitions that were read from the log were again stored in + * block manager. + * + * + * + * @param numPartitions Number of partitions in RDD + * @param numPartitionsInBM Number of partitions to write to the BlockManager. + * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager + * @param numPartitionsInKinesis Number of partitions to write to the Kinesis. + * Partitions (numPartitions - 1 - numPartitionsInKinesis) to + * (numPartitions - 1) will be written to Kinesis + * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching + * @param testBlockRemove Test whether calling rdd.removeBlock() makes the RDD still usable with + * reads falling back to the WAL + * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInWAL = 4 + * + * numPartitionsInBM = 3 + * |------------------| + * | | + * 0 1 2 3 4 + * | | + * |-------------------------| + * numPartitionsInKinesis = 4 + */ + private def testRDD( + numPartitions: Int, + numPartitionsInBM: Int, + numPartitionsInKinesis: Int, + testIsBlockValid: Boolean = false, + testBlockRemove: Boolean = false + ): Unit = { + require(shardIds.size > 1, "Need at least 2 shards to test") + require(numPartitionsInBM <= shardIds.size , + "Number of partitions in BlockManager cannot be more than the Kinesis test shards available") + require(numPartitionsInKinesis <= shardIds.size , + "Number of partitions in Kinesis cannot be more than the Kinesis test shards available") + require(numPartitionsInBM <= numPartitions, + "Number of partitions in BlockManager cannot be more than that in RDD") + require(numPartitionsInKinesis <= numPartitions, + "Number of partitions in Kinesis cannot be more than that in RDD") + + // Put necessary blocks in the block manager + val blockIds = fakeBlockIds(numPartitions) + blockIds.foreach(blockManager.removeBlock(_)) + (0 until numPartitionsInBM).foreach { i => + val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() } + blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY) + } + + // Create the necessary ranges to use in the RDD + val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + val realRanges = Array.tabulate(numPartitionsInKinesis) { i => + val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) + SequenceNumberRanges(Array(range)) + } + val ranges = (fakeRanges ++ realRanges) + + + // Make sure that the left `numPartitionsInBM` blocks are in block manager, 
and others are not + require( + blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty), + "Expected blocks not in BlockManager" + ) + + require( + blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty), + "Unexpected blocks in BlockManager" + ) + + // Make sure that the right sequence `numPartitionsInKinesis` are configured, and others are not + require( + ranges.takeRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName == testUtils.streamName } + }, "Incorrect configuration of RDD, expected ranges not set: " + ) + + require( + ranges.dropRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName != testUtils.streamName } + }, "Incorrect configuration of RDD, unexpected ranges set" + ) + + val rdd = new KinesisBackedBlockRDD( + sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges) + val collectedData = rdd.map { bytes => + new String(bytes).toInt + }.collect() + assert(collectedData.toSet === testData.toSet) + + // Verify that the block fetching is skipped when isBlockValid is set to false. + // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // Using that RDD will throw exception, as it skips block fetching even if the blocks are in + // in BlockManager. + if (testIsBlockValid) { + require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") + require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") + val rdd2 = new KinesisBackedBlockRDD( + sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges, + isBlockIdValid = Array.fill(blockIds.length)(false)) + intercept[SparkException] { + rdd2.collect() + } + } + + // Verify that the RDD is not invalid after the blocks are removed and can still read data + // from write ahead log + if (testBlockRemove) { + require(numPartitions === numPartitionsInKinesis, + "All partitions must be in WAL for this test") + require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test") + rdd.removeBlocks() + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet) + } + } + + /** Generate fake block ids */ + private def fakeBlockIds(num: Int): Array[BlockId] = { + Array.tabulate(num) { i => new StreamBlockId(0, i) } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala new file mode 100644 index 000000000000..ee428f31d6ce --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.SparkFunSuite + +/** + * Helper trait that runs Kinesis real data transfer tests, or + * ignores them, depending on whether the enabling environment variable is set. + */ +trait KinesisFunSuite extends SparkFunSuite { + import KinesisTestUtils._ + + /** Run the test if the environment variable is set, otherwise ignore it */ + def testIfEnabled(testName: String)(testBody: => Unit) { + if (shouldRunTests) { + test(testName)(testBody) + } else { + ignore(s"$testName [enable by setting env var $envVarNameForEnablingTests=1]")(testBody) + } + } + + /** Run the given body of code only if Kinesis tests are enabled */ + def runIfTestsEnabled(message: String)(body: => Unit): Unit = { + if (shouldRunTests) { + body + } else { + ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")() + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 2103dca6b766..3d136aec2e70 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Arrays -import scala.collection.JavaConversions.seqAsJavaList - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfter, Matchers} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{Milliseconds, TestSuiteBase} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -44,12 +43,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val endpoint = "endpoint-url" val workerId = "dummyWorkerId" val shardId = "dummyShardId" + val seqNum = "dummySeqNum" + val someSeqNum = Some(seqNum) val record1 = new Record() - record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) val record2 = new Record() - record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) - val batch = List[Record](record1, record2) + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + val batch = Arrays.asList(record1, record2) var receiverMock: KinesisReceiver = _ var checkpointerMock: IRecordProcessorCheckpointer = _ @@ -73,23 +74,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointStateMock, currentClockMock) } - test("KinesisUtils API") { - val ssc = new StreamingContext(master, framework, batchDuration) -
// Tests the API, does not actually test data receiving - val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") - - ssc.stop() - } - test("check serializability of SerializableAWSCredentials") { Utils.deserialize[SerializableAWSCredentials]( Utils.serialize(new SerializableAWSCredentials("x", "y"))) @@ -97,16 +81,18 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft test("process records including store and checkpoint") { when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() - verify(receiverMock, times(1)).store(record1.getData().array()) - verify(receiverMock, times(1)).store(record2.getData().array()) + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) verify(checkpointStateMock, times(1)).shouldCheckpoint() - verify(checkpointerMock, times(1)).checkpoint() + verify(checkpointerMock, times(1)).checkpoint(anyString) verify(checkpointStateMock, times(1)).advanceCheckpoint() } @@ -117,19 +103,25 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() + verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) + verify(checkpointerMock, never).checkpoint(anyString) } test("shouldn't checkpoint when exception occurs during store") { when(receiverMock.isStopped()).thenReturn(false) - when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException()) + when( + receiverMock.addRecords(anyString, anyListOf(classOf[Record])) + ).thenThrow(new RuntimeException()) intercept[RuntimeException] { val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) } verify(receiverMock, times(1)).isStopped() - verify(receiverMock, times(1)).store(record1.getData().array()) + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(checkpointerMock, never).checkpoint(anyString) } test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { @@ -175,19 +167,25 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft } test("shutdown should checkpoint if the reason is TERMINATE") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) - val reason = 
ShutdownReason.TERMINATE - recordProcessor.shutdown(checkpointerMock, reason) + recordProcessor.initialize(shardId) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) - verify(checkpointerMock, times(1)).checkpoint() + verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) + verify(checkpointerMock, times(1)).checkpoint(anyString) } test("shutdown should not checkpoint if the reason is something other than TERMINATE") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) recordProcessor.shutdown(checkpointerMock, null) - verify(checkpointerMock, never()).checkpoint() + verify(checkpointerMock, never).checkpoint(anyString) } test("retry success on first attempt") { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala new file mode 100644 index 000000000000..1177dc758100 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming._ +import org.apache.spark.streaming.kinesis.KinesisTestUtils._ +import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} + +class KinesisStreamSuite extends KinesisFunSuite + with Eventually with BeforeAndAfter with BeforeAndAfterAll { + + // This is the name that KCL will use to save metadata to DynamoDB + private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" + private val batchDuration = Seconds(1) + + // Dummy parameters for API testing + private val dummyEndpointUrl = defaultEndpointUrl + private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName() + private val dummyAWSAccessKey = "dummyAccessKey" + private val dummyAWSSecretKey = "dummySecretKey" + + private var testUtils: KinesisTestUtils = null + private var ssc: StreamingContext = null + private var sc: SparkContext = null + + override def beforeAll(): Unit = { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name + sc = new SparkContext(conf) + + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils() + testUtils.createStream() + } + } + + override def afterAll(): Unit = { + if (ssc != null) { + ssc.stop() + } + if (sc != null) { + sc.stop() + } + if (testUtils != null) { + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + testUtils.deleteStream() + testUtils.deleteDynamoDBTable(appName) + } + } + + before { + ssc = new StreamingContext(sc, batchDuration) + } + + after { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + ssc = null + } + if (testUtils != null) { + testUtils.deleteDynamoDBTable(appName) + } + } + + test("KinesisUtils API") { + // Tests the API, does not actually test data receiving + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", + dummyEndpointUrl, Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + dummyEndpointUrl, dummyRegionName, + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + dummyEndpointUrl, dummyRegionName, + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + dummyAWSAccessKey, dummyAWSSecretKey) + } + + test("RDD generation") { + val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), + StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey) + assert(inputStream.isInstanceOf[KinesisInputDStream]) + + val kinesisStream = 
inputStream.asInstanceOf[KinesisInputDStream] + val time = Time(1000) + + // Generate block info data for testing + val seqNumRanges1 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")) + val blockId1 = StreamBlockId(kinesisStream.id, 123) + val blockInfo1 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None)) + + val seqNumRanges2 = SequenceNumberRanges( + SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb")) + val blockId2 = StreamBlockId(kinesisStream.id, 345) + val blockInfo2 = ReceivedBlockInfo( + 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None)) + + // Verify that the generated KinesisBackedBlockRDD has the all the right information + val blockInfos = Seq(blockInfo1, blockInfo2) + val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos) + nonEmptyRDD shouldBe a [KinesisBackedBlockRDD] + val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD] + assert(kinesisRDD.regionName === dummyRegionName) + assert(kinesisRDD.endpointUrl === dummyEndpointUrl) + assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) + assert(kinesisRDD.awsCredentialsOption === + Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey))) + assert(nonEmptyRDD.partitions.size === blockInfos.size) + nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } + val partitions = nonEmptyRDD.partitions.map { + _.asInstanceOf[KinesisBackedBlockRDDPartition] }.toSeq + assert(partitions.map { _.seqNumberRanges } === Seq(seqNumRanges1, seqNumRanges2)) + assert(partitions.map { _.blockId } === Seq(blockId1, blockId2)) + assert(partitions.forall { _.isBlockIdValid === true }) + + // Verify that KinesisBackedBlockRDD is generated even when there are no blocks + val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty) + emptyRDD shouldBe a [KinesisBackedBlockRDD] + emptyRDD.partitions shouldBe empty + + // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid + blockInfos.foreach { _.setBlockIdInvalid() } + kinesisStream.createBlockRDD(time, blockInfos).partitions.foreach { partition => + assert(partition.asInstanceOf[KinesisBackedBlockRDDPartition].isBlockIdValid === false) + } + } + + + /** + * Test the stream by sending data to a Kinesis stream and receiving from it. + * This test is not run by default as it requires AWS credentials that the test + * environment may not have. Even if there is AWS credentials available, the user + * may not want to run these tests to avoid the Kinesis costs. To enable this test, + * you must have AWS credentials available through the default AWS provider chain, + * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . 
+ */ + testIfEnabled("basic operation") { + val awsCredentials = KinesisTestUtils.getAWSCredentials() + val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + } + ssc.start() + + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + testUtils.pushData(testData) + assert(collected === testData.toSet, "\nData received does not match data sent") + } + ssc.stop(stopSparkContext = false) + } + + testIfEnabled("failure recovery") { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + val checkpointDir = Utils.createTempDir().getAbsolutePath + + ssc = new StreamingContext(sc, Milliseconds(1000)) + ssc.checkpoint(checkpointDir) + + val awsCredentials = KinesisTestUtils.getAWSCredentials() + val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] + with mutable.SynchronizedMap[Time, (Array[SequenceNumberRanges], Seq[Int])] + + val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch + kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD] + val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq + collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) + }) + + ssc.remember(Minutes(60)) // remember all the batches so that they are all saved in checkpoint + ssc.start() + + def numBatchesWithData: Int = collectedData.count(_._2._2.nonEmpty) + + def isCheckpointPresent: Boolean = Checkpoint.getCheckpointFiles(checkpointDir).nonEmpty + + // Run until there are at least 10 batches with some data in them + // If this times out because numBatchesWithData is empty, then its likely that foreachRDD + // function failed with exceptions, and nothing got added to `collectedData` + eventually(timeout(2 minutes), interval(1 seconds)) { + testUtils.pushData(1 to 5) + assert(isCheckpointPresent && numBatchesWithData > 10) + } + ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused + + // Restart the context from checkpoint and verify whether the + logInfo("Restarting from checkpoint") + ssc = new StreamingContext(checkpointDir) + ssc.start() + val recoveredKinesisStream = ssc.graph.getInputStreams().head + + // Verify that the recomputed RDDs are KinesisBackedBlockRDDs with the same sequence ranges + // and return the same data + val times = collectedData.keySet + times.foreach { time => + val (arrayOfSeqNumRanges, data) = collectedData(time) + val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] + rdd shouldBe a [KinesisBackedBlockRDD] + + // Verify the recovered sequence ranges + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD] + assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) + 
arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => + assert(expected.ranges.toSeq === found.ranges.toSeq) + } + + // Verify the recovered data + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data) + } + ssc.stop() + } + +} diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 478d0019a25f..87a4f05a0596 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 853dea9a7795..202fc19002d1 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/java/org/apache/spark/graphx/TripletFields.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java rename to graphx/src/main/java/org/apache/spark/graphx/TripletFields.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java rename to graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 4611a3ace219..ee7302a1edbf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -38,8 +38,8 @@ import org.apache.spark.graphx.impl.EdgeRDDImpl * `impl.ReplicatedVertexView`. */ abstract class EdgeRDD[ED]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index db73a8abc573..869caa340f52 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -46,7 +46,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * @note vertex ids are unique. * @return an RDD containing the vertices in this graph */ - @transient val vertices: VertexRDD[VD] + val vertices: VertexRDD[VD] /** * An RDD containing the edges and their associated attributes. The entries in the RDD contain @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. 
* */ - @transient val edges: EdgeRDD[ED] + val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with @@ -77,7 +77,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * val numInvalid = graph.triplets.map(e => if (e.src.data == e.dst.data) 1 else 0).sum * }}} */ - @transient val triplets: RDD[EdgeTriplet[VD, ED]] + val triplets: RDD[EdgeTriplet[VD, ED]] /** * Caches the vertices and edges associated with this graph at the specified storage level, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 7372dfbd9fe9..70a7592da8ae 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -32,7 +32,7 @@ trait PartitionStrategy extends Serializable { object PartitionStrategy { /** * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix, - * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication. + * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication. * * Suppose we have a graph with 12 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: @@ -61,26 +61,36 @@ object PartitionStrategy { * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, * P6)` or the last * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be - * replicated to at most `2 * sqrt(numParts) - 1` machines. + * replicated to at most `2 * sqrt(numParts)` machines. * * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the * vertex locations. * - * One of the limitations of this approach is that the number of machines must either be a - * perfect square. We partially address this limitation by computing the machine assignment to - * the next - * largest perfect square and then mapping back down to the actual number of machines. - * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect - * square is used. + * When the number of partitions requested is not a perfect square we use a slightly different + * method where the last column can have a different number of rows than the others while still + * maintaining the same size per block. 
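To make the non-perfect-square case concrete, here is a small worked sketch of the arithmetic introduced in the hunk below; `numParts = 7` and the sample edge are arbitrary illustrative choices, not values taken from the patch.

```scala
// Illustrative only: mirrors the col/row computation added to EdgePartition2D.
val numParts = 7
val mixingPrime = 1125899906842597L
val ceilSqrtNumParts = math.ceil(math.sqrt(numParts)).toInt // 3
val cols = ceilSqrtNumParts                                 // 3 columns
val rows = (numParts + cols - 1) / cols                     // 3 rows in the full columns
val lastColRows = numParts - rows * (cols - 1)              // the last column keeps only 1 row

// Hypothetical edge (src = 11, dst = 5):
val (src, dst) = (11L, 5L)
val col = (math.abs(src * mixingPrime) % numParts / rows).toInt
val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt
val partition = col * rows + row // 3 + 3 + 1 blocks cover all 7 partitions
```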
*/ case object EdgePartition2D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt val mixingPrime: VertexId = 1125899906842597L - val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt - val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt - (col * ceilSqrtNumParts + row) % numParts + if (numParts == ceilSqrtNumParts * ceilSqrtNumParts) { + // Use old method for perfect squared to ensure we get same results + val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt + val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt + (col * ceilSqrtNumParts + row) % numParts + + } else { + // Otherwise use new method + val cols = ceilSqrtNumParts + val rows = (numParts + cols - 1) / cols + val lastColRows = numParts - rows * (cols - 1) + val col = (math.abs(src * mixingPrime) % numParts / rows).toInt + val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt + col * rows + row + + } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index cfcf7244eaed..2ca60d51f833 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -127,28 +127,25 @@ object Pregel extends Logging { var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { - // Receive the messages. Vertices that didn't get any messages do not appear in newVerts. - val newVerts = g.vertices.innerJoin(messages)(vprog).cache() - // Update the graph with the new vertices. + // Receive the messages and update the vertices. prevG = g - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } - g.cache() + g = g.joinVertices(messages)(vprog).cache() val oldMessages = messages - // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't - // get to send messages. We must cache messages so it can be materialized on the next line, - // allowing us to uncache the previous iteration. - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache() - // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This - // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the - // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages + // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages + // and the vertices of g). 
activeMessages = messages.count() logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking = false) - newVerts.unpersist(blocking = false) prevG.unpersistVertices(blocking = false) prevG.edges.unpersist(blocking = false) // count the iteration diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index a9f04b559c3d..1ef7a78fbcd0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -55,8 +55,8 @@ import org.apache.spark.graphx.impl.VertexRDDImpl * @tparam VD the vertex attribute associated with each vertex in the set. */ abstract class VertexRDD[VD]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { implicit protected def vdTag: ClassTag[VD] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 90a74d23a26c..da95314440d8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -332,9 +332,9 @@ object GraphImpl { edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { val edgeRDD = EdgeRDD.fromEdges(edges)(classTag[ED], classTag[VD]) - .withTargetStorageLevel(edgeStorageLevel).cache() + .withTargetStorageLevel(edgeStorageLevel) val vertexRDD = VertexRDD(vertices, edgeRDD, defaultVertexAttr) - .withTargetStorageLevel(vertexStorageLevel).cache() + .withTargetStorageLevel(vertexStorageLevel) GraphImpl(vertexRDD, edgeRDD) } @@ -346,9 +346,14 @@ object GraphImpl { def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + + vertices.cache() + // Convert the vertex partitions in edges to the correct type val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) + .cache() + GraphImpl.fromExistingRDDs(vertices, newEdges) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index eb3c997e0f3c..4f1260a5a67b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -34,7 +34,7 @@ object RoutingTablePartition { /** * A message from an edge partition to a vertex specifying the position in which the edge * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower - * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int. + * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int. 
*/ type RoutingTableMessage = (VertexId, Int) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 33ac7b0ed609..7f4e7e9d79d6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -87,7 +87,7 @@ class VertexRDDImpl[VD] private[graphx] ( /** The number of vertices in the RDD. */ override def count(): Long = { - partitionsRDD.map(_.size).reduce(_ + _) + partitionsRDD.map(_.size.toLong).reduce(_ + _) } override private[graphx] def mapVertexPartitions[VD2: ClassTag]( diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 2bcf8684b8b8..a3ad6bed1c99 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index be6b9047d932..74a7de18d416 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -66,7 +66,6 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { @@ -122,7 +121,7 @@ private[graphx] object BytecodeUtils { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { - methodsInvoked.add((Class.forName(owner.replace("/", ".")), name)) + methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 9591c4e9b8f4..989e22630526 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl /** A collection of graph generating functions. 
*/ -object GraphGenerators { +object GraphGenerators extends Logging { val RMATa = 0.45 val RMATb = 0.15 @@ -142,7 +142,7 @@ object GraphGenerators { var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { - println(edges.size + " edges") + logDebug(edges.size + " edges") } edges += addEdge(numVertices) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 186d0cc2a977..61e44dcab578 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +// scalastyle:off println class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass @@ -102,6 +103,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { private val c = {e: TestClass => println(e.baz)} } +// scalastyle:on println object BytecodeUtilsSuite { class TestClass(val foo: Int, val bar: Long) { diff --git a/launcher/pom.xml b/launcher/pom.xml index 48dd0d5f9106..ed38e66aa246 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -49,7 +49,7 @@ org.mockito - mockito-all + mockito-core test @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 33d65d13f0d2..0a237ee73b67 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -136,7 +136,7 @@ void addPermGenSizeOpt(List cmd) { } } - cmd.add("-XX:MaxPermSize=128m"); + cmd.add("-XX:MaxPermSize=256m"); } void addOptionString(List cmd, String options) { @@ -169,9 +169,11 @@ List buildClassPath(String appClassPath) throws IOException { "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", "yarn", "launcher"); if (prependClasses) { - System.err.println( - "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + - "assembly."); + if (!isTesting) { + System.err.println( + "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + + "assembly."); + } for (String project : projects) { addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, scala)); @@ -200,7 +202,7 @@ List buildClassPath(String appClassPath) throws IOException { // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. // That duplicates some of the code in the shell scripts that look for the assembly, though. 
String assembly = getenv(ENV_SPARK_ASSEMBLY); - if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + if (assembly == null && !isTesting) { assembly = findAssembly(); } addToClassPath(cp, assembly); @@ -215,12 +217,14 @@ List buildClassPath(String appClassPath) throws IOException { libdir = new File(sparkHome, "lib_managed/jars"); } - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); + if (libdir.isDirectory()) { + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); + } } + } else { + checkState(isTesting, "Library directory '%s' does not exist.", libdir.getAbsolutePath()); } addToClassPath(cp, getenv("HADOOP_CONF_DIR")); @@ -256,15 +260,15 @@ String getScalaVersion() { return scala; } String sparkHome = getSparkHome(); - File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); - File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); + File scala210 = new File(sparkHome, "launcher/target/scala-2.10"); + File scala211 = new File(sparkHome, "launcher/target/scala-2.11"); checkState(!scala210.isDirectory() || !scala211.isDirectory(), "Presence of build for both scala versions (2.10 and 2.11) detected.\n" + "Either clean one of them or set SPARK_SCALA_VERSION in your environment."); if (scala210.isDirectory()) { return "2.10"; } else { - checkState(scala211.isDirectory(), "Cannot find any assembly build directories."); + checkState(scala211.isDirectory(), "Cannot find any build directories."); return "2.11"; } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 2665a700fe1f..a16c0d2b5ca0 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -27,7 +27,7 @@ */ class CommandBuilderUtils { - static final String DEFAULT_MEM = "512m"; + static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 62492f9baf3b..a4e3acc674f3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -32,7 +32,7 @@ class Main { /** * Usage: Main [class] [class args] - *

    + *

    * This CLI works in two different modes: *

      *
    • "spark-submit": if class is "org.apache.spark.deploy.SparkSubmit", the @@ -42,7 +42,7 @@ class Main { * * This class works in tandem with the "bin/spark-class" script on Unix-like systems, and * "bin/spark-class2.cmd" batch script on Windows to execute the final command. - *

      + *

      * On Unix-like systems, the output is a list of command arguments, separated by the NULL * character. On Windows, the output is a command line suitable for direct execution from the * script. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index de85720febf2..931a24cfd4b1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -28,7 +28,7 @@ /** * Command builder for internal Spark classes. - *

      + *

      * This class handles building the command to launch all internal Spark classes except for * SparkSubmit (which is handled by {@link SparkSubmitCommandBuilder} class. */ @@ -69,7 +69,8 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; - } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") || + className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); memKey = "SPARK_DAEMON_MEMORY"; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index d4cfeacb6ef1..57993405e47b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -27,9 +28,10 @@ /** * Launcher for Spark applications. - *

      + *

      * Use this class to start Spark applications programmatically. The class uses a builder pattern * to allow clients to configure the Spark application and launch it as a child process. + *

      */ public class SparkLauncher { @@ -56,7 +58,8 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; - private final SparkSubmitCommandBuilder builder; + // Visible for testing. + final SparkSubmitCommandBuilder builder; public SparkLauncher() { this(null); @@ -186,6 +189,73 @@ public SparkLauncher setMainClass(String mainClass) { return this; } + /** + * Adds a no-value argument to the Spark invocation. If the argument is known, this method + * validates whether the argument is indeed a no-value argument, and throws an exception + * otherwise. + *

      + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param arg Argument to add. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String arg) { + SparkSubmitOptionParser validator = new ArgumentValidator(false); + validator.parse(Arrays.asList(arg)); + builder.sparkArgs.add(arg); + return this; + } + + /** + * Adds an argument with a value to the Spark invocation. If the argument name corresponds to + * a known argument, the code validates that the argument actually expects a value, and throws + * an exception otherwise. + *

      + * It is safe to add arguments modified by other methods in this class (such as + * {@link #setMaster(String)} - the last invocation will be the one to take effect. + *

      + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param name Name of argument to add. + * @param value Value of the argument. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String name, String value) { + SparkSubmitOptionParser validator = new ArgumentValidator(true); + if (validator.MASTER.equals(name)) { + setMaster(value); + } else if (validator.PROPERTIES_FILE.equals(name)) { + setPropertiesFile(value); + } else if (validator.CONF.equals(name)) { + String[] vals = value.split("=", 2); + setConf(vals[0], vals[1]); + } else if (validator.CLASS.equals(name)) { + setMainClass(value); + } else if (validator.JARS.equals(name)) { + builder.jars.clear(); + for (String jar : value.split(",")) { + addJar(jar); + } + } else if (validator.FILES.equals(name)) { + builder.files.clear(); + for (String file : value.split(",")) { + addFile(file); + } + } else if (validator.PY_FILES.equals(name)) { + builder.pyFiles.clear(); + for (String file : value.split(",")) { + addPyFile(file); + } + } else { + validator.parse(Arrays.asList(name, value)); + builder.sparkArgs.add(name); + builder.sparkArgs.add(value); + } + return this; + } + /** * Adds command line arguments for the application. * @@ -276,4 +346,32 @@ public Process launch() throws IOException { return pb.start(); } + private static class ArgumentValidator extends SparkSubmitOptionParser { + + private final boolean hasValue; + + ArgumentValidator(boolean hasValue) { + this.hasValue = hasValue; + } + + @Override + protected boolean handle(String opt, String value) { + if (value == null && hasValue) { + throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // Do not fail on unknown arguments, to support future arguments added to SparkSubmit. + return true; + } + + protected void handleExtraArgs(List extra) { + // No op. + } + + }; + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 3e5a2820b6c1..fc87814a59ed 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -25,11 +25,11 @@ /** * Special command builder for handling a CLI invocation of SparkSubmit. - *

      + *

      * This builder adds command line parsing compatible with SparkSubmit. It handles setting * driver-side options and special parsing behavior needed for the special-casing certain internal * Spark applications. - *

      + *

      * This class has also some special features to aid launching pyspark. */ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { @@ -76,7 +76,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { "spark-internal"); } - private final List sparkArgs; + final List sparkArgs; private final boolean printHelp; /** @@ -208,7 +208,7 @@ private List buildSparkSubmitCommand(Map env) throws IOE // - properties file. // - SPARK_DRIVER_MEMORY env variable // - SPARK_MEM env variable - // - default value (512m) + // - default value (1g) // Take Thrift Server as daemon String tsMemory = isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index b88bba883ac6..6767cc507964 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -23,7 +23,7 @@ /** * Parser for spark-submit command line options. - *

      + *

      * This class encapsulates the parsing code for spark-submit command line options, so that there * is a single list of options that needs to be maintained (well, sort of, but it makes it harder * to break things). @@ -51,6 +51,7 @@ class SparkSubmitOptionParser { protected final String MASTER = "--master"; protected final String NAME = "--name"; protected final String PACKAGES = "--packages"; + protected final String PACKAGES_EXCLUDE = "--exclude-packages"; protected final String PROPERTIES_FILE = "--properties-file"; protected final String PROXY_USER = "--proxy-user"; protected final String PY_FILES = "--py-files"; @@ -79,10 +80,10 @@ class SparkSubmitOptionParser { * This is the canonical list of spark-submit options. Each entry in the array contains the * different aliases for the same option; the first element of each entry is the "official" * name of the option, passed to {@link #handle(String, String)}. - *

      + *

      * Options not listed here nor in the "switch" list below will result in a call to * {@link $#handleUnknown(String)}. - *

      + *

      * These two arrays are visible for tests. */ final String[][] opts = { @@ -105,6 +106,7 @@ class SparkSubmitOptionParser { { NAME }, { NUM_EXECUTORS }, { PACKAGES }, + { PACKAGES_EXCLUDE }, { PRINCIPAL }, { PROPERTIES_FILE }, { PROXY_USER }, @@ -128,7 +130,7 @@ class SparkSubmitOptionParser { /** * Parse a list of spark-submit command line options. - *

      + *

      * See SparkSubmitArguments.scala for a more formal description of available options. * * @throws IllegalArgumentException If an error is found during parsing. diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java index 7ed756f4b859..7c97dba511b2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -17,13 +17,17 @@ /** * Library for launching Spark applications. - *

      + * + *

      * This library allows applications to launch Spark programmatically. There's only one entry * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class. - *

      + *

      + * + *

      * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher} * and configure the application to run. For example: - * + *

      + * *
        * {@code
        *   import org.apache.spark.launcher.SparkLauncher;
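In the same spirit as the javadoc example above, a rough Scala sketch of driving the launcher, including the `addSparkArg` variants added by this patch; every path, class name, and argument below is hypothetical.

```scala
import org.apache.spark.launcher.SparkLauncher

// Hypothetical application and locations; only the API shape is the point here.
val app = new SparkLauncher()
  .setSparkHome("/opt/spark")
  .setAppResource("/opt/jobs/my-app.jar")
  .setMainClass("com.example.MyApp")
  .setMaster("local[2]")
  .addSparkArg("--verbose")                        // validated as a no-value argument
  .addSparkArg("--conf", "spark.ui.enabled=false") // routed through setConf(...)
  .launch()                                        // returns a java.lang.Process
app.waitFor()
```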
      diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
      index 97043a76cc61..7329ac9f7fb8 100644
      --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
      +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
      @@ -194,7 +194,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception {
               if (isDriver) {
                 assertEquals("-XX:MaxPermSize=256m", arg);
               } else {
      -          assertEquals("-XX:MaxPermSize=128m", arg);
      +          assertEquals("-XX:MaxPermSize=256m", arg);
               }
             }
           }
      diff --git a/make-distribution.sh b/make-distribution.sh
      index 9f063da3a16c..04ad0052eb24 100755
      --- a/make-distribution.sh
      +++ b/make-distribution.sh
      @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)"
       DISTDIR="$SPARK_HOME/dist"
       
       SPARK_TACHYON=false
      -TACHYON_VERSION="0.6.4"
      +TACHYON_VERSION="0.7.1"
       TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz"
       TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}"
       
      diff --git a/mllib/pom.xml b/mllib/pom.xml
      index b16058ddc203..22c0c6008ba3 100644
      --- a/mllib/pom.xml
      +++ b/mllib/pom.xml
      @@ -21,7 +21,7 @@
         
           org.apache.spark
           spark-parent_2.10
      -    1.5.0-SNAPSHOT
      +    1.6.0-SNAPSHOT
           ../pom.xml
         
       
      @@ -106,7 +106,7 @@
           
           
             org.mockito
      -      mockito-all
      +      mockito-core
             test
           
           
      diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
      new file mode 100644
      index 000000000000..f632dd603c44
      --- /dev/null
      +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
      @@ -0,0 +1 @@
      +org.apache.spark.ml.source.libsvm.DefaultSource
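Registering `DefaultSource` in this service file is what lets the LIBSVM relation be resolved by a short format name. A minimal usage sketch, assuming the source registers the short name "libsvm"; the input path is hypothetical.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("libsvm-example"))
val sqlContext = new SQLContext(sc)

// Loads a DataFrame with "label" and "features" columns via the registered short name.
val df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
df.printSchema()
```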
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
      index a1f3851d804f..a3e59401c5cf 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
      @@ -95,6 +95,8 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
         /** @group setParam */
         def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
       
      +  // Below, we clone stages so that modifications to the list of stages will not change
      +  // the Param value in the Pipeline.
         /** @group getParam */
         def getStages: Array[PipelineStage] = $(stages).clone()
       
      @@ -196,6 +198,6 @@ class PipelineModel private[ml] (
         }
       
         override def copy(extra: ParamMap): PipelineModel = {
      -    new PipelineModel(uid, stages.map(_.copy(extra)))
      +    new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
         }
       }
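A tiny sketch of what the new comment about cloning stages guarantees; the `Tokenizer` stage and column names are placeholder choices.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.Tokenizer

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val pipeline = new Pipeline().setStages(Array(tokenizer))

val stages = pipeline.getStages // returns $(stages).clone()
stages(0) = null                // mutates only the returned copy
assert(pipeline.getStages(0) == tokenizer) // the Pipeline's Param value is untouched
```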
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
      index edaa2afb790e..19fe039b8fd0 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
      @@ -122,9 +122,7 @@ abstract class Predictor[
          */
         protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
           dataset.select($(labelCol), $(featuresCol))
      -      .map { case Row(label: Double, features: Vector) =>
      -      LabeledPoint(label, features)
      -    }
      +      .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }
         }
       }
       
      @@ -171,7 +169,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
         override def transform(dataset: DataFrame): DataFrame = {
           transformSchema(dataset.schema, logging = true)
           if ($(predictionCol).nonEmpty) {
      -      dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
      +      transformImpl(dataset)
           } else {
             this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
               " since no output columns were set.")
      @@ -179,6 +177,13 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
           }
         }
       
      +  protected def transformImpl(dataset: DataFrame): DataFrame = {
      +    val predictUDF = udf { (features: Any) =>
      +      predict(features.asInstanceOf[FeaturesType])
      +    }
      +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      +  }
      +
         /**
          * Predict label for the given features.
          * This internal method is used to implement [[transform()]] and output [[predictionCol]].
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala
      new file mode 100644
      index 000000000000..7429f9d652ac
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala
      @@ -0,0 +1,63 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.ann
      +
      +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
      +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
      +
      +/**
      + * In-place DGEMM and DGEMV for Breeze
      + */
      +private[ann] object BreezeUtil {
      +
      +  // TODO: switch to MLlib BLAS interface
      +  private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N"
      +
      +  /**
      +   * DGEMM: C := alpha * A * B + beta * C
      +   * @param alpha alpha
      +   * @param a A
      +   * @param b B
      +   * @param beta beta
      +   * @param c C
      +   */
      +  def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = {
+    // TODO: handle the case where the input matrices are transposed
+    require(a.cols == b.rows, "A & B Dimension mismatch!")
+    require(a.rows == c.rows, "A & C Dimension mismatch!")
+    require(b.cols == c.cols, "B & C Dimension mismatch!")
      +    NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols,
      +      alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride,
      +      beta, c.data, c.offset, c.rows)
      +  }
      +
      +  /**
      +   * DGEMV: y := alpha * A * x + beta * y
      +   * @param alpha alpha
      +   * @param a A
      +   * @param x x
      +   * @param beta beta
      +   * @param y y
      +   */
      +  def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = {
+    require(a.cols == x.length, "A & x Dimension mismatch!")
      +    NativeBLAS.dgemv(transposeString(a), a.rows, a.cols,
      +      alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride,
      +      beta, y.data, y.offset, y.stride)
      +  }
      +}
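For orientation, a small usage sketch of these in-place wrappers (illustrative values only; `BreezeUtil` is `private[ann]`, so the calls below only compile inside that package):

```scala
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}

val a = new BDM(2, 3, Array(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) // 2x3, column-major
val b = new BDM(3, 2, Array(1.0, 0.0, 0.0, 0.0, 1.0, 0.0)) // 3x2, column-major
val c = BDM.zeros[Double](2, 2)
BreezeUtil.dgemm(1.0, a, b, 0.0, c)  // c is overwritten in place with a * b

val x = BDV(1.0, 1.0, 1.0)
val y = BDV.zeros[Double](2)
BreezeUtil.dgemv(1.0, a, x, 0.0, y)  // y is overwritten in place with a * x
```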
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
      new file mode 100644
      index 000000000000..b5258ff34847
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
      @@ -0,0 +1,882 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.ann
      +
      +import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy,
      +  sum => Bsum}
      +import breeze.numerics.{log => Blog, sigmoid => Bsigmoid}
      +
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.optimization._
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.util.random.XORShiftRandom
      +
      +/**
+ * Trait that holds the layer properties needed to instantiate it.
+ * Implements layer instantiation.
+ */
      +private[ann] trait Layer extends Serializable {
      +  /**
      +   * Returns the instance of the layer based on weights provided
      +   * @param weights vector with layer weights
      +   * @param position position of weights in the vector
      +   * @return the layer model
      +   */
      +  def getInstance(weights: Vector, position: Int): LayerModel
      +
      +  /**
      +   * Returns the instance of the layer with random generated weights
      +   * @param seed seed
      +   * @return the layer model
      +   */
      +  def getInstance(seed: Long): LayerModel
      +}
      +
      +/**
      + * Trait that holds Layer weights (or parameters).
      + * Implements functions needed for forward propagation, computing delta and gradient.
      + * Can return weights in Vector format.
      + */
      +private[ann] trait LayerModel extends Serializable {
      +  /**
      +   * number of weights
      +   */
      +  val size: Int
      +
      +  /**
      +   * Evaluates the data (process the data through the layer)
      +   * @param data data
      +   * @return processed data
      +   */
      +  def eval(data: BDM[Double]): BDM[Double]
      +
      +  /**
      +   * Computes the delta for back propagation
      +   * @param nextDelta delta of the next layer
      +   * @param input input data
      +   * @return delta
      +   */
      +  def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double]
      +
      +  /**
      +   * Computes the gradient
      +   * @param delta delta for this layer
      +   * @param input input data
      +   * @return gradient
      +   */
      +  def grad(delta: BDM[Double], input: BDM[Double]): Array[Double]
      +
      +  /**
      +   * Returns weights for the layer in a single vector
      +   * @return layer weights
      +   */
      +  def weights(): Vector
      +}
      +
      +/**
+ * Layer properties of affine transformations, that is, y = A * x + b
      + * @param numIn number of inputs
      + * @param numOut number of outputs
      + */
      +private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer {
      +
      +  override def getInstance(weights: Vector, position: Int): LayerModel = {
      +    AffineLayerModel(this, weights, position)
      +  }
      +
      +  override def getInstance(seed: Long = 11L): LayerModel = {
      +    AffineLayerModel(this, seed)
      +  }
      +}
      +
      +/**
+ * Model of Affine layer, y = A * x + b
      + * @param w weights (matrix A)
      + * @param b bias (vector b)
      + */
      +private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel {
      +  val size = w.size + b.length
      +  val gwb = new Array[Double](size)
      +  private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb)
      +  private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size)
      +  private var z: BDM[Double] = null
      +  private var d: BDM[Double] = null
      +  private var ones: BDV[Double] = null
      +
      +  override def eval(data: BDM[Double]): BDM[Double] = {
      +    if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols)
      +    z(::, *) := b
      +    BreezeUtil.dgemm(1.0, w, data, 1.0, z)
      +    z
      +  }
      +
      +  override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
      +    if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols)
      +    BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d)
      +    d
      +  }
      +
      +  override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = {
      +    BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw)
      +    if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols)
      +    BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb)
      +    gwb
      +  }
      +
      +  override def weights(): Vector = AffineLayerModel.roll(w, b)
      +}
      +
      +/**
+ * Factory for Affine layer models
      + */
      +private[ann] object AffineLayerModel {
      +
      +  /**
      +   * Creates a model of Affine layer
      +   * @param layer layer properties
      +   * @param weights vector with weights
      +   * @param position position of weights in the vector
      +   * @return model of Affine layer
      +   */
      +  def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = {
      +    val (w, b) = unroll(weights, position, layer.numIn, layer.numOut)
      +    new AffineLayerModel(w, b)
      +  }
      +
      +  /**
      +   * Creates a model of Affine layer
      +   * @param layer layer properties
      +   * @param seed seed
      +   * @return model of Affine layer
      +   */
      +  def apply(layer: AffineLayer, seed: Long): AffineLayerModel = {
      +    val (w, b) = randomWeights(layer.numIn, layer.numOut, seed)
      +    new AffineLayerModel(w, b)
      +  }
      +
      +  /**
      +   * Unrolls the weights from the vector
      +   * @param weights vector with weights
      +   * @param position position of weights for this layer
      +   * @param numIn number of layer inputs
      +   * @param numOut number of layer outputs
      +   * @return matrix A and vector b
      +   */
      +  def unroll(
      +    weights: Vector,
      +    position: Int,
      +    numIn: Int,
      +    numOut: Int): (BDM[Double], BDV[Double]) = {
      +    val weightsCopy = weights.toArray
      +    // TODO: the array is not copied to BDMs, make sure this is OK!
      +    val a = new BDM[Double](numOut, numIn, weightsCopy, position)
      +    val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut)
      +    (a, b)
      +  }
      +
      +  /**
      +   * Roll the layer weights into a vector
      +   * @param a matrix A
      +   * @param b vector b
      +   * @return vector of weights
      +   */
      +  def roll(a: BDM[Double], b: BDV[Double]): Vector = {
      +    val result = new Array[Double](a.size + b.length)
      +    // TODO: make sure that we need to copy!
      +    System.arraycopy(a.toArray, 0, result, 0, a.size)
      +    System.arraycopy(b.toArray, 0, result, a.size, b.length)
      +    Vectors.dense(result)
      +  }
      +
      +  /**
      +   * Generate random weights for the layer
      +   * @param numIn number of inputs
      +   * @param numOut number of outputs
      +   * @param seed seed
      +   * @return (matrix A, vector b)
      +   */
      +  def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = {
      +    val rand: XORShiftRandom = new XORShiftRandom(seed)
      +    val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn }
      +    val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn }
      +    (weights, bias)
      +  }
      +}
      +
      +/**
      + * Trait for functions and their derivatives for functional layers
      + */
      +private[ann] trait ActivationFunction extends Serializable {
      +
      +  /**
      +   * Implements a function
      +   * @param x input data
      +   * @param y output data
      +   */
      +  def eval(x: BDM[Double], y: BDM[Double]): Unit
      +
      +  /**
      +   * Implements a derivative of a function (needed for the back propagation)
      +   * @param x input data
      +   * @param y output data
      +   */
      +  def derivative(x: BDM[Double], y: BDM[Double]): Unit
      +
      +  /**
      +   * Implements a cross entropy error of a function.
      +   * Needed if the functional layer that contains this function is the output layer
      +   * of the network.
+   * @param output computed output
+   * @param target target output
+   * @param result intermediate result
+   * @return cross-entropy
+   */
+  def crossEntropy(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double
      +
      +  /**
      +   * Implements a mean squared error of a function
+   * @param output computed output
+   * @param target target output
+   * @param result intermediate result
+   * @return mean squared error
+   */
+  def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double
      +}
      +
      +/**
      + * Implements in-place application of functions
      + */
      +private[ann] object ActivationFunction {
      +
      +  def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = {
      +    var i = 0
      +    while (i < x.rows) {
      +      var j = 0
      +      while (j < x.cols) {
      +        y(i, j) = func(x(i, j))
      +        j += 1
      +      }
      +      i += 1
      +    }
      +  }
      +
      +  def apply(
      +    x1: BDM[Double],
      +    x2: BDM[Double],
      +    y: BDM[Double],
      +    func: (Double, Double) => Double): Unit = {
      +    var i = 0
      +    while (i < x1.rows) {
      +      var j = 0
      +      while (j < x1.cols) {
      +        y(i, j) = func(x1(i, j), x2(i, j))
      +        j += 1
      +      }
      +      i += 1
      +    }
      +  }
      +}
      +
      +/**
      + * Implements SoftMax activation function
      + */
      +private[ann] class SoftmaxFunction extends ActivationFunction {
      +  override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
      +    var j = 0
+    // find the max value so that the exponentials computed below do not overflow
      +    while (j < x.cols) {
      +      var i = 0
      +      var max = Double.MinValue
      +      while (i < x.rows) {
      +        if (x(i, j) > max) {
      +          max = x(i, j)
      +        }
      +        i += 1
      +      }
      +      var sum = 0.0
      +      i = 0
      +      while (i < x.rows) {
      +        val res = Math.exp(x(i, j) - max)
      +        y(i, j) = res
      +        sum += res
      +        i += 1
      +      }
      +      i = 0
      +      while (i < x.rows) {
      +        y(i, j) /= sum
      +        i += 1
      +      }
      +      j += 1
      +    }
      +  }
      +
      +  override def crossEntropy(
      +    output: BDM[Double],
      +    target: BDM[Double],
      +    result: BDM[Double]): Double = {
      +    def m(o: Double, t: Double): Double = o - t
      +    ActivationFunction(output, target, result, m)
+    -Bsum(target :* Blog(output)) / output.cols
      +  }
      +
      +  override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
      +    def sd(z: Double): Double = (1 - z) * z
      +    ActivationFunction(x, y, sd)
      +  }
      +
      +  override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
      +    throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.")
      +  }
      +}
      +
      +/**
      + * Implements Sigmoid activation function
      + */
      +private[ann] class SigmoidFunction extends ActivationFunction {
      +  override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
      +    def s(z: Double): Double = Bsigmoid(z)
      +    ActivationFunction(x, y, s)
      +  }
      +
      +  override def crossEntropy(
      +    output: BDM[Double],
      +    target: BDM[Double],
      +    result: BDM[Double]): Double = {
      +    def m(o: Double, t: Double): Double = o - t
      +    ActivationFunction(output, target, result, m)
      +    -Bsum(target :* Blog(output)) / output.cols
      +  }
      +
      +  override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
      +    def sd(z: Double): Double = (1 - z) * z
      +    ActivationFunction(x, y, sd)
      +  }
      +
      +  override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
      +    // TODO: make it readable
      +    def m(o: Double, t: Double): Double = (o - t)
      +    ActivationFunction(output, target, result, m)
      +    val e = Bsum(result :* result) / 2 / output.cols
      +    def m2(x: Double, o: Double) = x * (o - o * o)
      +    ActivationFunction(result, output, result, m2)
      +    e
      +  }
      +}
      +
      +/**
      + * Functional layer properties, y = f(x)
      + * @param activationFunction activation function
      + */
+private[ann] class FunctionalLayer(val activationFunction: ActivationFunction) extends Layer {
      +  override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L)
      +
      +  override def getInstance(seed: Long): LayerModel =
      +    FunctionalLayerModel(this)
      +}
      +
      +/**
      + * Functional layer model. Holds no weights.
      + * @param activationFunction activation function
      + */
      +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction)
      +  extends LayerModel {
      +  val size = 0
      +  // matrices for in-place computations
      +  // outputs
      +  private var f: BDM[Double] = null
      +  // delta
      +  private var d: BDM[Double] = null
      +  // matrix for error computation
      +  private var e: BDM[Double] = null
      +  // delta gradient
      +  private lazy val dg = new Array[Double](0)
      +
      +  override def eval(data: BDM[Double]): BDM[Double] = {
      +    if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols)
      +    activationFunction.eval(data, f)
      +    f
      +  }
      +
      +  override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
      +    if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols)
      +    activationFunction.derivative(input, d)
      +    d :*= nextDelta
      +    d
      +  }
      +
      +  override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg
      +
      +  override def weights(): Vector = Vectors.dense(new Array[Double](0))
      +
      +  def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
      +    if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
      +    val error = activationFunction.crossEntropy(output, target, e)
      +    (e, error)
      +  }
      +
      +  def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
      +    if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
      +    val error = activationFunction.squared(output, target, e)
      +    (e, error)
      +  }
      +
      +  def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
      +    // TODO: allow user pick error
      +    activationFunction match {
      +      case sigmoid: SigmoidFunction => squared(output, target)
      +      case softmax: SoftmaxFunction => crossEntropy(output, target)
      +    }
      +  }
      +}
      +
      +/**
+ * Factory for functional layer models
      + */
      +private[ann] object FunctionalLayerModel {
      +  def apply(layer: FunctionalLayer): FunctionalLayerModel =
      +    new FunctionalLayerModel(layer.activationFunction)
      +}
      +
      +/**
      + * Trait for the artificial neural network (ANN) topology properties
      + */
+private[ann] trait Topology extends Serializable {
      +  def getInstance(weights: Vector): TopologyModel
      +  def getInstance(seed: Long): TopologyModel
      +}
      +
      +/**
      + * Trait for ANN topology model
      + */
+private[ann] trait TopologyModel extends Serializable {
      +  /**
      +   * Forward propagation
      +   * @param data input data
      +   * @return array of outputs for each of the layers
      +   */
      +  def forward(data: BDM[Double]): Array[BDM[Double]]
      +
      +  /**
      +   * Prediction of the model
      +   * @param data input data
      +   * @return prediction
      +   */
      +  def predict(data: Vector): Vector
      +
      +  /**
      +   * Computes gradient for the network
      +   * @param data input data
      +   * @param target target output
      +   * @param cumGradient cumulative gradient
      +   * @param blockSize block size
      +   * @return error
      +   */
      +  def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector,
      +                      blockSize: Int): Double
      +
      +  /**
      +   * Returns the weights of the ANN
      +   * @return weights
      +   */
      +  def weights(): Vector
      +}
      +
      +/**
      + * Feed forward ANN
+ * @param layers array of layers
      + */
      +private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology {
      +  override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
      +
      +  override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed)
      +}
      +
      +/**
      + * Factory for some of the frequently-used topologies
      + */
      +private[ml] object FeedForwardTopology {
      +  /**
      +   * Creates a feed forward topology from the array of layers
      +   * @param layers array of layers
      +   * @return feed forward topology
      +   */
      +  def apply(layers: Array[Layer]): FeedForwardTopology = {
      +    new FeedForwardTopology(layers)
      +  }
      +
      +  /**
      +   * Creates a multi-layer perceptron
      +   * @param layerSizes sizes of layers including input and output size
+   * @param softmax whether to use SoftMax or Sigmoid function for the output layer.
+   *                SoftMax is the default.
      +   * @return multilayer perceptron topology
      +   */
      +  def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = {
      +    val layers = new Array[Layer]((layerSizes.length - 1) * 2)
+    for (i <- 0 until layerSizes.length - 1) {
      +      layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1))
      +      layers(i * 2 + 1) =
      +        if (softmax && i == layerSizes.length - 2) {
      +          new FunctionalLayer(new SoftmaxFunction())
      +        } else {
      +          new FunctionalLayer(new SigmoidFunction())
      +        }
      +    }
      +    FeedForwardTopology(layers)
      +  }
      +}
      +
      +/**
      + * Model of Feed Forward Neural Network.
      + * Implements forward, gradient computation and can return weights in vector format.
      + * @param layerModels models of layers
      + * @param topology topology of the network
      + */
      +private[ml] class FeedForwardModel private(
      +    val layerModels: Array[LayerModel],
      +    val topology: FeedForwardTopology) extends TopologyModel {
      +  override def forward(data: BDM[Double]): Array[BDM[Double]] = {
      +    val outputs = new Array[BDM[Double]](layerModels.length)
      +    outputs(0) = layerModels(0).eval(data)
      +    for (i <- 1 until layerModels.length) {
+      outputs(i) = layerModels(i).eval(outputs(i - 1))
      +    }
      +    outputs
      +  }
      +
      +  override def computeGradient(
      +    data: BDM[Double],
      +    target: BDM[Double],
      +    cumGradient: Vector,
      +    realBatchSize: Int): Double = {
      +    val outputs = forward(data)
      +    val deltas = new Array[BDM[Double]](layerModels.length)
      +    val L = layerModels.length - 1
      +    val (newE, newError) = layerModels.last match {
      +      case flm: FunctionalLayerModel => flm.error(outputs.last, target)
      +      case _ =>
      +        throw new UnsupportedOperationException("Non-functional layer not supported at the top")
      +    }
      +    deltas(L) = new BDM[Double](0, 0)
      +    deltas(L - 1) = newE
      +    for (i <- (L - 2) to (0, -1)) {
      +      deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1))
      +    }
      +    val grads = new Array[Array[Double]](layerModels.length)
      +    for (i <- 0 until layerModels.length) {
+      val input = if (i == 0) data else outputs(i - 1)
      +      grads(i) = layerModels(i).grad(deltas(i), input)
      +    }
      +    // update cumGradient
      +    val cumGradientArray = cumGradient.toArray
      +    var offset = 0
      +    // TODO: extract roll
      +    for (i <- 0 until grads.length) {
      +      val gradArray = grads(i)
      +      var k = 0
      +      while (k < gradArray.length) {
      +        cumGradientArray(offset + k) += gradArray(k)
      +        k += 1
      +      }
      +      offset += gradArray.length
      +    }
      +    newError
      +  }
      +
      +  // TODO: do we really need to copy the weights? they should be read-only
      +  override def weights(): Vector = {
      +    // TODO: extract roll
      +    var size = 0
      +    for (i <- 0 until layerModels.length) {
      +      size += layerModels(i).size
      +    }
      +    val array = new Array[Double](size)
      +    var offset = 0
      +    for (i <- 0 until layerModels.length) {
      +      val layerWeights = layerModels(i).weights().toArray
      +      System.arraycopy(layerWeights, 0, array, offset, layerWeights.length)
      +      offset += layerWeights.length
      +    }
      +    Vectors.dense(array)
      +  }
      +
      +  override def predict(data: Vector): Vector = {
      +    val size = data.size
      +    val result = forward(new BDM[Double](size, 1, data.toArray))
      +    Vectors.dense(result.last.toArray)
      +  }
      +}
      +
      +/**
+ * Factory for feed forward ANN models
      + */
      +private[ann] object FeedForwardModel {
      +
      +  /**
      +   * Creates a model from a topology and weights
      +   * @param topology topology
      +   * @param weights weights
      +   * @return model
      +   */
      +  def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
      +    val layers = topology.layers
      +    val layerModels = new Array[LayerModel](layers.length)
      +    var offset = 0
      +    for (i <- 0 until layers.length) {
      +      layerModels(i) = layers(i).getInstance(weights, offset)
      +      offset += layerModels(i).size
      +    }
      +    new FeedForwardModel(layerModels, topology)
      +  }
      +
      +  /**
      +   * Creates a model given a topology and seed
      +   * @param topology topology
      +   * @param seed seed for generating the weights
      +   * @return model
      +   */
      +  def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
      +    val layers = topology.layers
      +    val layerModels = new Array[LayerModel](layers.length)
      +    var offset = 0
+    for (i <- 0 until layers.length) {
      +      layerModels(i) = layers(i).getInstance(seed)
      +      offset += layerModels(i).size
      +    }
      +    new FeedForwardModel(layerModels, topology)
      +  }
      +}
      +
      +/**
+ * Neural network gradient. Does nothing but delegate to the model's gradient computation.
      + * @param topology topology
      + * @param dataStacker data stacker
      + */
      +private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient {
      +
      +  override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
      +    val gradient = Vectors.zeros(weights.size)
      +    val loss = compute(data, label, weights, gradient)
      +    (gradient, loss)
      +  }
      +
      +  override def compute(
      +    data: Vector,
      +    label: Double,
      +    weights: Vector,
      +    cumGradient: Vector): Double = {
      +    val (input, target, realBatchSize) = dataStacker.unstack(data)
      +    val model = topology.getInstance(weights)
      +    model.computeGradient(input, target, cumGradient, realBatchSize)
      +  }
      +}
      +
      +/**
+ * Stacks pairs of training samples (input, output) into one vector, allowing them to pass
+ * through the Optimizer/Gradient interfaces. If stackSize is greater than one, it forms
+ * blocks (matrices) of inputs and outputs and then stacks them into one vector.
+ * This can be used for batch computations after unstacking.
      + * @param stackSize stack size
      + * @param inputSize size of the input vectors
      + * @param outputSize size of the output vectors
      + */
      +private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
      +  extends Serializable {
      +
      +  /**
      +   * Stacks the data
      +   * @param data RDD of vector pairs
      +   * @return RDD of double (always zero) and vector that contains the stacked vectors
      +   */
      +  def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = {
      +    val stackedData = if (stackSize == 1) {
      +      data.map { v =>
      +        (0.0,
      +          Vectors.fromBreeze(BDV.vertcat(
      +            v._1.toBreeze.toDenseVector,
      +            v._2.toBreeze.toDenseVector))
      +          ) }
      +    } else {
      +      data.mapPartitions { it =>
      +        it.grouped(stackSize).map { seq =>
      +          val size = seq.size
      +          val bigVector = new Array[Double](inputSize * size + outputSize * size)
      +          var i = 0
      +          seq.foreach { case (in, out) =>
      +            System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize)
      +            System.arraycopy(out.toArray, 0, bigVector,
      +              inputSize * size + i * outputSize, outputSize)
      +            i += 1
      +          }
      +          (0.0, Vectors.dense(bigVector))
      +        }
      +      }
      +    }
      +    stackedData
      +  }
      +
      +  /**
+   * Unstacks the stacked vectors into matrices for batch operations
      +   * @param data stacked vector
      +   * @return pair of matrices holding input and output data and the real stack size
      +   */
      +  def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = {
      +    val arrData = data.toArray
      +    val realStackSize = arrData.length / (inputSize + outputSize)
      +    val input = new BDM(inputSize, realStackSize, arrData)
      +    val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize)
      +    (input, target, realStackSize)
      +  }
      +}
      +
      +/**
      + * Simple updater
      + */
      +private[ann] class ANNUpdater extends Updater {
      +
      +  override def compute(
      +    weightsOld: Vector,
      +    gradient: Vector,
      +    stepSize: Double,
      +    iter: Int,
      +    regParam: Double): (Vector, Double) = {
      +    val thisIterStepSize = stepSize
      +    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
      +    Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
      +    (Vectors.fromBreeze(brzWeights), 0)
      +  }
      +}
      +
      +/**
      + * MLlib-style trainer class that trains a network given the data and topology
      + * @param topology topology of ANN
      + * @param inputSize input size
      + * @param outputSize output size
      + */
      +private[ml] class FeedForwardTrainer(
      +    topology: Topology,
      +    val inputSize: Int,
      +    val outputSize: Int) extends Serializable {
      +
      +  // TODO: what if we need to pass random seed?
      +  private var _weights = topology.getInstance(11L).weights()
      +  private var _stackSize = 128
      +  private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize)
      +  private var _gradient: Gradient = new ANNGradient(topology, dataStacker)
      +  private var _updater: Updater = new ANNUpdater()
      +  private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100)
      +
      +  /**
      +   * Returns weights
      +   * @return weights
      +   */
      +  def getWeights: Vector = _weights
      +
      +  /**
      +   * Sets weights
      +   * @param value weights
      +   * @return trainer
      +   */
      +  def setWeights(value: Vector): FeedForwardTrainer = {
      +    _weights = value
      +    this
      +  }
      +
      +  /**
      +   * Sets the stack size
      +   * @param value stack size
      +   * @return trainer
      +   */
      +  def setStackSize(value: Int): FeedForwardTrainer = {
      +    _stackSize = value
      +    dataStacker = new DataStacker(value, inputSize, outputSize)
      +    this
      +  }
      +
      +  /**
      +   * Sets the SGD optimizer
      +   * @return SGD optimizer
      +   */
      +  def SGDOptimizer: GradientDescent = {
      +    val sgd = new GradientDescent(_gradient, _updater)
      +    optimizer = sgd
      +    sgd
      +  }
      +
      +  /**
      +   * Sets the LBFGS optimizer
+   * @return LBFGS optimizer
      +   */
      +  def LBFGSOptimizer: LBFGS = {
      +    val lbfgs = new LBFGS(_gradient, _updater)
      +    optimizer = lbfgs
      +    lbfgs
      +  }
      +
      +  /**
      +   * Sets the updater
      +   * @param value updater
      +   * @return trainer
      +   */
      +  def setUpdater(value: Updater): FeedForwardTrainer = {
      +    _updater = value
      +    updateUpdater(value)
      +    this
      +  }
      +
      +  /**
      +   * Sets the gradient
      +   * @param value gradient
      +   * @return trainer
      +   */
      +  def setGradient(value: Gradient): FeedForwardTrainer = {
      +    _gradient = value
      +    updateGradient(value)
      +    this
      +  }
      +
      +  private[this] def updateGradient(gradient: Gradient): Unit = {
      +    optimizer match {
      +      case lbfgs: LBFGS => lbfgs.setGradient(gradient)
      +      case sgd: GradientDescent => sgd.setGradient(gradient)
      +      case other => throw new UnsupportedOperationException(
      +        s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.")
      +    }
      +  }
      +
      +  private[this] def updateUpdater(updater: Updater): Unit = {
      +    optimizer match {
      +      case lbfgs: LBFGS => lbfgs.setUpdater(updater)
      +      case sgd: GradientDescent => sgd.setUpdater(updater)
      +      case other => throw new UnsupportedOperationException(
      +        s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.")
      +    }
      +  }
      +
      +  /**
      +   * Trains the ANN
      +   * @param data RDD of input and output vector pairs
      +   * @return model
      +   */
      +  def train(data: RDD[(Vector, Vector)]): TopologyModel = {
      +    val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights)
      +    topology.getInstance(newWeights)
      +  }
      +
      +}
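Putting the pieces of Layer.scala together, a hedged end-to-end sketch of the topology builder, the trainer, and prediction. These classes are `private[ml]`/`private[ann]`, so the sketch compiles only inside the corresponding Spark packages, and `data` with its 4-input/3-class shape is an assumed example input:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

def trainSketch(data: RDD[(Vector, Vector)]): Vector = {
  // 4 inputs -> 5 hidden units -> 3 outputs, with SoftMax on the output layer
  val topology = FeedForwardTopology.multiLayerPerceptron(Array(4, 5, 3), softmax = true)
  val trainer = new FeedForwardTrainer(topology, inputSize = 4, outputSize = 3)
  trainer.LBFGSOptimizer.setNumIterations(100).setConvergenceTol(1e-4)
  trainer.setStackSize(64)                          // stack 64 samples per optimizer data point
  val model = trainer.train(data)
  model.predict(Vectors.dense(0.5, 0.1, 0.2, 0.7))  // output-layer activations
}
```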
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
      index ce43a450daad..e479f169021d 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
      @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute
       import scala.annotation.varargs
       
       import org.apache.spark.annotation.DeveloperApi
      -import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
      +import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}
       
       /**
        * :: DeveloperApi ::
      @@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory {
          * Creates an [[Attribute]] from a [[StructField]] instance.
          */
         def fromStructField(field: StructField): Attribute = {
      -    require(field.dataType == DoubleType)
      +    require(field.dataType.isInstanceOf[NumericType])
           val metadata = field.metadata
           val mlAttr = AttributeKeys.ML_ATTR
           if (metadata.contains(mlAttr)) {
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
      index 14c285dbfc54..45df557a8990 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
      @@ -18,14 +18,13 @@
       package org.apache.spark.ml.classification
       
       import org.apache.spark.annotation.DeveloperApi
      -import org.apache.spark.ml.param.ParamMap
       import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
       import org.apache.spark.ml.param.shared.HasRawPredictionCol
       import org.apache.spark.ml.util.SchemaUtils
       import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
       import org.apache.spark.sql.DataFrame
       import org.apache.spark.sql.functions._
      -import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
      +import org.apache.spark.sql.types.{DataType, StructType}
       
       
       /**
      @@ -102,15 +101,20 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
           var outputData = dataset
           var numColsOutput = 0
           if (getRawPredictionCol != "") {
      -      outputData = outputData.withColumn(getRawPredictionCol,
      -        callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
      +      val predictRawUDF = udf { (features: Any) =>
      +        predictRaw(features.asInstanceOf[FeaturesType])
      +      }
      +      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
             numColsOutput += 1
           }
           if (getPredictionCol != "") {
             val predUDF = if (getRawPredictionCol != "") {
      -        callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol))
      +        udf(raw2prediction _).apply(col(getRawPredictionCol))
             } else {
      -        callUDF(predict _, DoubleType, col(getFeaturesCol))
      +        val predictUDF = udf { (features: Any) =>
      +          predict(features.asInstanceOf[FeaturesType])
      +        }
      +        predictUDF(col(getFeaturesCol))
             }
             outputData = outputData.withColumn(getPredictionCol, predUDF)
             numColsOutput += 1
      @@ -151,5 +155,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
          * This may be overridden to support thresholds which favor particular labels.
          * @return  predicted label
          */
      -  protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax
      +  protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax
       }
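The simplified `raw2prediction` above relies on `Vector.argmax`, so the `toDense` copy is no longer needed: the predicted label is simply the index of the largest raw score. A tiny illustration of the default decision rule:

```scala
import org.apache.spark.mllib.linalg.Vectors

val raw = Vectors.dense(0.1, 2.5, 0.4)
val predicted = raw.argmax.toDouble  // 1.0: class 1 has the largest raw score
```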
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
      index 2dc1824964a4..b8eb49f9bdb4 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
      @@ -18,13 +18,13 @@
       package org.apache.spark.ml.classification
       
       import org.apache.spark.annotation.Experimental
      -import org.apache.spark.ml.{PredictionModel, Predictor}
       import org.apache.spark.ml.param.ParamMap
      +import org.apache.spark.ml.param.shared.HasCheckpointInterval
       import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
      +import org.apache.spark.ml.tree.impl.RandomForest
       import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
      -import org.apache.spark.mllib.linalg.Vector
      +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
       import org.apache.spark.mllib.regression.LabeledPoint
      -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
       import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
       import org.apache.spark.rdd.RDD
      @@ -39,7 +39,7 @@ import org.apache.spark.sql.DataFrame
        */
       @Experimental
       final class DecisionTreeClassifier(override val uid: String)
      -  extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
      +  extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
         with DecisionTreeParams with TreeClassifierParams {
       
         def this() = this(Identifiable.randomUID("dtc"))
      @@ -75,8 +75,9 @@ final class DecisionTreeClassifier(override val uid: String)
           }
           val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
           val strategy = getOldStrategy(categoricalFeatures, numClasses)
      -    val oldModel = OldDecisionTree.train(oldDataset, strategy)
      -    DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures)
      +    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
      +      seed = 0L, parentUID = Some(uid))
      +    trees.head.asInstanceOf[DecisionTreeClassificationModel]
         }
       
         /** (private[ml]) Create a Strategy instance to use with the old API. */
      @@ -105,23 +106,47 @@ object DecisionTreeClassifier {
       @Experimental
       final class DecisionTreeClassificationModel private[ml] (
           override val uid: String,
      -    override val rootNode: Node)
      -  extends PredictionModel[Vector, DecisionTreeClassificationModel]
      +    override val rootNode: Node,
      +    override val numClasses: Int)
      +  extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
         with DecisionTreeModel with Serializable {
       
         require(rootNode != null,
           "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
       
      +  /**
      +   * Construct a decision tree classification model.
      +   * @param rootNode  Root node of tree, with other nodes attached.
      +   */
      +  private[ml] def this(rootNode: Node, numClasses: Int) =
      +    this(Identifiable.randomUID("dtc"), rootNode, numClasses)
      +
         override protected def predict(features: Vector): Double = {
      -    rootNode.predict(features)
      +    rootNode.predictImpl(features).prediction
      +  }
      +
      +  override protected def predictRaw(features: Vector): Vector = {
      +    Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
      +  }
      +
      +  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
      +    rawPrediction match {
      +      case dv: DenseVector =>
      +        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
      +        dv
      +      case sv: SparseVector =>
      +        throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
      +          " raw2probabilityInPlace encountered SparseVector")
      +    }
         }
       
         override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
      -    copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
      +    copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
      +      .setParent(parent)
         }
       
         override def toString: String = {
      -    s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes"
      +    s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
         }
       
         /** (private[ml]) Convert to a model in the old API */
      @@ -142,6 +167,6 @@ private[ml] object DecisionTreeClassificationModel {
               s" DecisionTreeClassificationModel (new API).  Algo is: ${oldModel.algo}")
           val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
           val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
      -    new DecisionTreeClassificationModel(uid, rootNode)
      +    new DecisionTreeClassificationModel(uid, rootNode, -1)
         }
       }
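With the move to `ProbabilisticClassificationModel`, `predictRaw` returns the class counts stored at the predicted leaf and `raw2probabilityInPlace` normalizes them into a distribution. A hedged sketch of that normalization with made-up counts for a 3-class problem:

```scala
// Class counts at the predicted leaf (impurityStats.stats); the values are assumed.
val leafCounts = Array(6.0, 2.0, 0.0)
val total = leafCounts.sum
val probabilities = leafCounts.map(_ / total)  // Array(0.75, 0.25, 0.0)
```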
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
      index 554e3b8e052b..ad8683648b97 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
      @@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
       import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.functions._
      +import org.apache.spark.sql.types.DoubleType
       
       /**
        * :: Experimental ::
      @@ -177,21 +179,28 @@ final class GBTClassificationModel(
       
         override def treeWeights: Array[Double] = _treeWeights
       
      +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
      +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
      +    val predictUDF = udf { (features: Any) =>
      +      bcastModel.value.predict(features.asInstanceOf[Vector])
      +    }
      +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      +  }
      +
         override protected def predict(features: Vector): Double = {
      -    // TODO: Override transform() to broadcast model: SPARK-7127
           // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
           // Classifies by thresholding sum of weighted tree predictions
      -    val treePredictions = _trees.map(_.rootNode.predict(features))
      +    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
           val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
           if (prediction > 0.0) 1.0 else 0.0
         }
       
         override def copy(extra: ParamMap): GBTClassificationModel = {
      -    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra)
      +    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent)
         }
       
         override def toString: String = {
      -    s"GBTClassificationModel with $numTrees trees"
      +    s"GBTClassificationModel (uid=$uid) with $numTrees trees"
         }
       
         /** (private[ml]) Convert to a model in the old API */
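The new `transformImpl` above broadcasts the fitted model once instead of capturing it in every task closure (the SPARK-7127 TODO it removes). A minimal sketch of the same broadcast-then-udf idiom, with the model reduced to an assumed scoring function and assumed column names:

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

def scoreWithBroadcast(dataset: DataFrame, predict: Vector => Double): DataFrame = {
  // Broadcast the scorer once; executors read it through bcastPredict.value.
  val bcastPredict = dataset.sqlContext.sparkContext.broadcast(predict)
  val predictUDF = udf { (features: Any) =>
    bcastPredict.value(features.asInstanceOf[Vector])
  }
  dataset.withColumn("prediction", predictUDF(col("features")))
}
```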
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
      index 2e6eedd45ab0..bd96e8d000ff 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
       
       import scala.collection.mutable
       
      -import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
      +import breeze.linalg.{DenseVector => BDV}
       import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
       
       import org.apache.spark.{Logging, SparkException}
      @@ -29,11 +29,12 @@ import org.apache.spark.ml.param.shared._
       import org.apache.spark.ml.util.Identifiable
       import org.apache.spark.mllib.linalg._
       import org.apache.spark.mllib.linalg.BLAS._
      -import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
       import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
       import org.apache.spark.mllib.util.MLUtils
       import org.apache.spark.rdd.RDD
      -import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.{DataFrame, Row}
      +import org.apache.spark.sql.functions.{col, lit}
       import org.apache.spark.storage.StorageLevel
       
       /**
      @@ -41,12 +42,126 @@ import org.apache.spark.storage.StorageLevel
        */
       private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
         with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
      -  with HasThreshold
      +  with HasStandardization with HasWeightCol with HasThreshold {
      +
      +  /**
      +   * Set threshold in binary classification, in range [0, 1].
      +   *
      +   * If the estimated probability of class label 1 is > threshold, then predict 1, else 0.
      +   * A high threshold encourages the model to predict 0 more often;
      +   * a low threshold encourages the model to predict 1 more often.
      +   *
      +   * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`.
      +   *       When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared.
      +   *       If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
      +   *       equivalent.
      +   *
      +   * Default is 0.5.
      +   * @group setParam
      +   */
      +  def setThreshold(value: Double): this.type = {
      +    if (isSet(thresholds)) clear(thresholds)
      +    set(threshold, value)
      +  }
      +
      +  /**
      +   * Get threshold for binary classification.
      +   *
      +   * If [[threshold]] is set, returns that value.
      +   * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification),
      +   * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
+   * Otherwise, returns the default value of [[threshold]].
      +   *
      +   * @group getParam
      +   * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2.
      +   */
      +  override def getThreshold: Double = {
      +    checkThresholdConsistency()
      +    if (isSet(thresholds)) {
      +      val ts = $(thresholds)
      +      require(ts.length == 2, "Logistic Regression getThreshold only applies to" +
      +        " binary classification, but thresholds has length != 2.  thresholds: " + ts.mkString(","))
      +      1.0 / (1.0 + ts(0) / ts(1))
      +    } else {
      +      $(threshold)
      +    }
      +  }
      +
      +  /**
      +   * Set thresholds in multiclass (or binary) classification to adjust the probability of
      +   * predicting each class. Array must have length equal to the number of classes, with values >= 0.
      +   * The class with largest value p/t is predicted, where p is the original probability of that
      +   * class and t is the class' threshold.
      +   *
      +   * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
      +   *       If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
      +   *       equivalent.
      +   *
      +   * @group setParam
      +   */
      +  def setThresholds(value: Array[Double]): this.type = {
      +    if (isSet(threshold)) clear(threshold)
      +    set(thresholds, value)
      +  }
      +
      +  /**
      +   * Get thresholds for binary or multiclass classification.
      +   *
      +   * If [[thresholds]] is set, return its value.
      +   * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary
      +   * classification: (1-threshold, threshold).
      +   * If neither are set, throw an exception.
      +   *
      +   * @group getParam
      +   */
      +  override def getThresholds: Array[Double] = {
      +    checkThresholdConsistency()
      +    if (!isSet(thresholds) && isSet(threshold)) {
      +      val t = $(threshold)
+      Array(1 - t, t)
      +    } else {
      +      $(thresholds)
      +    }
      +  }
      +
      +  /**
      +   * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
      +   * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
      +   */
      +  protected def checkThresholdConsistency(): Unit = {
      +    if (isSet(threshold) && isSet(thresholds)) {
      +      val ts = $(thresholds)
      +      require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" +
      +        s" thresholds.  Param threshold is set (${$(threshold)}), indicating binary" +
      +        s" classification, but Param thresholds is set with length ${ts.length}." +
      +        " Clear one Param value to fix this problem.")
      +      val t = 1.0 / (1.0 + ts(0) / ts(1))
      +      require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" +
      +        s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)")
      +    }
      +  }
      +
      +  override def validateParams(): Unit = {
      +    checkThresholdConsistency()
      +  }
      +}
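+
+// Worked example of the threshold <-> thresholds equivalence documented above (values are
+// illustrative): setThresholds(Array(0.4, 0.6)) makes getThreshold return
+// 1.0 / (1.0 + 0.4 / 0.6) = 0.6, while setThreshold(0.6) makes getThresholds return
+// Array(0.4, 0.6), so the two parameterizations describe the same decision rule.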
      +
      +/**
      + * Class that represents an instance of weighted data point with label and features.
      + *
      + * TODO: Refactor this class to proper place.
      + *
      + * @param label Label for this data point.
      + * @param weight The weight of this instance.
      + * @param features The vector of features for this data point.
      + */
      +private[classification] case class Instance(label: Double, weight: Double, features: Vector)
       
       /**
        * :: Experimental ::
        * Logistic regression.
      - * Currently, this class only supports binary classification.
      + * Currently, this class only supports binary classification.  It will support multiclass
      + * in the future.
        */
       @Experimental
       class LogisticRegression(override val uid: String)
      @@ -94,35 +209,62 @@ class LogisticRegression(override val uid: String)
          * Whether to fit an intercept term.
          * Default is true.
          * @group setParam
      -   * */
      +   */
         def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
         setDefault(fitIntercept -> true)
       
      -  /** @group setParam */
      -  def setThreshold(value: Double): this.type = set(threshold, value)
      -  setDefault(threshold -> 0.5)
      +  /**
      +   * Whether to standardize the training features before fitting the model.
+   * The coefficients of models will always be returned on the original scale,
+   * so it will be transparent for users. Note that with or without standardization,
+   * the models should always converge to the same solution when no regularization
      +   * is applied. In R's GLMNET package, the default behavior is true as well.
      +   * Default is true.
      +   * @group setParam
      +   */
      +  def setStandardization(value: Boolean): this.type = set(standardization, value)
      +  setDefault(standardization -> true)
      +
      +  override def setThreshold(value: Double): this.type = super.setThreshold(value)
      +
      +  override def getThreshold: Double = super.getThreshold
      +
      +  /**
      +   * Whether to over-/under-sample training instances according to the given weights in weightCol.
      +   * If empty, all instances are treated equally (weight 1.0).
      +   * Default is empty, so all instances have weight one.
      +   * @group setParam
      +   */
      +  def setWeightCol(value: String): this.type = set(weightCol, value)
      +  setDefault(weightCol -> "")
      +
      +  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
      +
      +  override def getThresholds: Array[Double] = super.getThresholds
       
         override protected def train(dataset: DataFrame): LogisticRegressionModel = {
           // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
      -    val instances = extractLabeledPoints(dataset).map {
      -      case LabeledPoint(label: Double, features: Vector) => (label, features)
      +    val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
      +    val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
      +      case Row(label: Double, weight: Double, features: Vector) =>
      +        Instance(label, weight, features)
           }
      +
           val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
           if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
       
      -    val (summarizer, labelSummarizer) = instances.treeAggregate(
      -      (new MultivariateOnlineSummarizer, new MultiClassSummarizer))(
      -        seqOp = (c, v) => (c, v) match {
      -          case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
      -          (label: Double, features: Vector)) =>
      -            (summarizer.add(features), labelSummarizer.add(label))
      -      },
      -        combOp = (c1, c2) => (c1, c2) match {
      -          case ((summarizer1: MultivariateOnlineSummarizer,
      -          classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
      -          classSummarizer2: MultiClassSummarizer)) =>
      -            (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2))
      -      })
      +    val (summarizer, labelSummarizer) = {
      +      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
      +        instance: Instance) =>
      +          (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
      +
      +      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
      +        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
      +          (c1._1.merge(c2._1), c1._2.merge(c2._2))
      +
      +      instances.treeAggregate(
      +        new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp)
      +    }
       
           val histogram = labelSummarizer.histogram
           val numInvalid = labelSummarizer.countInvalid
      @@ -149,76 +291,105 @@ class LogisticRegression(override val uid: String)
           val regParamL1 = $(elasticNetParam) * $(regParam)
           val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
       
      -    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
      +    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
             featuresStd, featuresMean, regParamL2)
       
           val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
             new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
           } else {
      -      // Remove the L1 penalization on the intercept
             def regParamL1Fun = (index: Int) => {
      -        if (index == numFeatures) 0.0 else regParamL1
      +        // Remove the L1 penalization on the intercept
      +        if (index == numFeatures) {
      +          0.0
      +        } else {
      +          if ($(standardization)) {
      +            regParamL1
      +          } else {
      +            // If `standardization` is false, we still standardize the data
      +            // to improve the rate of convergence; as a result, we have to
      +            // perform this reverse standardization by penalizing each component
      +            // differently to get effectively the same objective function when
      +            // the training dataset is not standardized.
      +            if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
      +          }
      +        }
             }
             new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
           }
       
      -    val initialWeightsWithIntercept =
      +    val initialCoefficientsWithIntercept =
             Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
       
           if ($(fitIntercept)) {
      -      /**
      -       * For binary logistic regression, when we initialize the weights as zeros,
      -       * it will converge faster if we initialize the intercept such that
      -       * it follows the distribution of the labels.
      -       *
      -       * {{{
      -       * P(0) = 1 / (1 + \exp(b)), and
      -       * P(1) = \exp(b) / (1 + \exp(b))
      -       * }}}, hence
      -       * {{{
      -       * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
      -       * }}}
      +      /*
      +         For binary logistic regression, when we initialize the weights as zeros,
      +         it will converge faster if we initialize the intercept such that
      +         it follows the distribution of the labels.
      +
      +         {{{
      +         P(0) = 1 / (1 + \exp(b)), and
      +         P(1) = \exp(b) / (1 + \exp(b))
      +         }}}, hence
      +         {{{
      +         b = \log{P(1) / P(0)} = \log{count_1 / count_0}
      +         }}}
              */
      -      initialWeightsWithIntercept.toArray(numFeatures)
      -        = math.log(histogram(1).toDouble / histogram(0).toDouble)
      +      initialCoefficientsWithIntercept.toArray(numFeatures)
      +        = math.log(histogram(1) / histogram(0))
           }
       
           val states = optimizer.iterations(new CachedDiffFunction(costFun),
      -      initialWeightsWithIntercept.toBreeze.toDenseVector)
      +      initialCoefficientsWithIntercept.toBreeze.toDenseVector)
       
      -    var state = states.next()
      -    val lossHistory = mutable.ArrayBuilder.make[Double]
      +    val (coefficients, intercept, objectiveHistory) = {
      +      /*
      +         Note that in Logistic Regression, the objective history (loss + regularization)
      +         is the log-likelihood, which is invariant under feature standardization. As a result,
      +         the objective history from the optimizer is the same as the one in the original space.
      +       */
      +      val arrayBuilder = mutable.ArrayBuilder.make[Double]
      +      var state: optimizer.State = null
      +      while (states.hasNext) {
      +        state = states.next()
      +        arrayBuilder += state.adjustedValue
      +      }
       
      -    while (states.hasNext) {
      -      lossHistory += state.value
      -      state = states.next()
      -    }
      -    lossHistory += state.value
      +      if (state == null) {
      +        val msg = s"${optimizer.getClass.getName} failed."
      +        logError(msg)
      +        throw new SparkException(msg)
      +      }
       
      -    // The weights are trained in the scaled space; we're converting them back to
      -    // the original space.
      -    val weightsWithIntercept = {
      -      val rawWeights = state.x.toArray.clone()
      +      /*
      +         The coefficients are trained in the scaled space; we're converting them back to
      +         the original space.
      +         Note that the intercept in scaled space and original space is the same;
      +         as a result, no scaling is needed.
      +       */
      +      val rawCoefficients = state.x.toArray.clone()
             var i = 0
      -      // Note that the intercept in scaled space and original space is the same;
      -      // as a result, no scaling is needed.
             while (i < numFeatures) {
      -        rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
      +        rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
               i += 1
             }
      -      Vectors.dense(rawWeights)
      +
      +      if ($(fitIntercept)) {
      +        (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
      +          arrayBuilder.result())
      +      } else {
      +        (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result())
      +      }
           }
       
           if (handlePersistence) instances.unpersist()
       
      -    val (weights, intercept) = if ($(fitIntercept)) {
      -      (Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)),
      -        weightsWithIntercept(weightsWithIntercept.size - 1))
      -    } else {
      -      (weightsWithIntercept, 0.0)
      -    }
      -
      -    new LogisticRegressionModel(uid, weights.compressed, intercept)
      +    val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
      +    val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
      +      model.transform(dataset),
      +      $(probabilityCol),
      +      $(labelCol),
      +      objectiveHistory)
      +    model.setSummary(logRegSummary)
         }
       
         override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
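
The intercept initialization derived in the comment block inside `train` (b = log(count_1 / count_0)) is easy to check on its own. A standalone sketch, assuming a binary weighted label histogram; all names here are illustrative:

```scala
object InterceptInitSketch {
  // With zero initial coefficients, convergence is faster if the intercept matches the
  // label distribution: P(1) = exp(b) / (1 + exp(b))  =>  b = log(count_1 / count_0).
  def initialIntercept(histogram: Array[Double]): Double =
    math.log(histogram(1) / histogram(0))

  def main(args: Array[String]): Unit = {
    val histogram = Array(30.0, 70.0)          // weighted counts of labels 0 and 1
    val b = initialIntercept(histogram)        // ~0.8473
    val p1 = math.exp(b) / (1.0 + math.exp(b))
    println(f"intercept = $b%.4f, implied P(1) = $p1%.2f")   // implied P(1) = 0.70
  }
}
```
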
      @@ -236,8 +407,13 @@ class LogisticRegressionModel private[ml] (
         extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
         with LogisticRegressionParams {
       
      -  /** @group setParam */
      -  def setThreshold(value: Double): this.type = set(threshold, value)
      +  override def setThreshold(value: Double): this.type = super.setThreshold(value)
      +
      +  override def getThreshold: Double = super.getThreshold
      +
      +  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
      +
      +  override def getThresholds: Array[Double] = super.getThresholds
       
         /** Margin (rawPrediction) for class label 1.  For binary classification only. */
         private val margin: Vector => Double = (features) => {
      @@ -252,11 +428,44 @@ class LogisticRegressionModel private[ml] (
       
         override val numClasses: Int = 2
       
      +  private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
      +
      +  /**
      +   * Gets summary of model on training set. An exception is
      +   * thrown if `trainingSummary == None`.
      +   */
      +  def summary: LogisticRegressionTrainingSummary = trainingSummary match {
      +    case Some(summ) => summ
      +    case None =>
      +      throw new SparkException(
      +        "No training summary available for this LogisticRegressionModel",
      +        new NullPointerException())
      +  }
      +
      +  private[classification] def setSummary(
      +      summary: LogisticRegressionTrainingSummary): this.type = {
      +    this.trainingSummary = Some(summary)
      +    this
      +  }
      +
      +  /** Indicates whether a training summary exists for this model instance. */
      +  def hasSummary: Boolean = trainingSummary.isDefined
      +
      +  /**
      +   * Evaluates the model on a test set.
      +   * @param dataset Test dataset to evaluate model on.
      +   */
      +  // TODO: decide on a good name before exposing to public API
      +  private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
      +    new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
      +  }
      +
         /**
          * Predict label for the given feature vector.
      -   * The behavior of this can be adjusted using [[threshold]].
      +   * The behavior of this can be adjusted using [[thresholds]].
          */
         override protected def predict(features: Vector): Double = {
      +    // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
           if (score(features) > getThreshold) 1 else 0
         }
       
      @@ -282,10 +491,13 @@ class LogisticRegressionModel private[ml] (
         }
       
         override def copy(extra: ParamMap): LogisticRegressionModel = {
      -    copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
      +    val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
      +    if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
      +    newModel.setParent(parent)
         }
       
         override protected def raw2prediction(rawPrediction: Vector): Double = {
      +    // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
           val t = getThreshold
           val rawThreshold = if (t == 0.0) {
             Double.NegativeInfinity
      @@ -298,6 +510,7 @@ class LogisticRegressionModel private[ml] (
         }
       
         override protected def probability2prediction(probability: Vector): Double = {
      +    // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
           if (probability(1) > getThreshold) 1 else 0
         }
       }
      @@ -311,22 +524,29 @@ class LogisticRegressionModel private[ml] (
        * corresponding joint dataset.
        */
       private[classification] class MultiClassSummarizer extends Serializable {
      -  private val distinctMap = new mutable.HashMap[Int, Long]
      +  // The first element of the value in distinctMap is the actual number of instances,
      +  // and the second element of the value is the sum of the weights.
      +  private val distinctMap = new mutable.HashMap[Int, (Long, Double)]
         private var totalInvalidCnt: Long = 0L
       
         /**
          * Add a new label into this MultilabelSummarizer, and update the distinct map.
          * @param label The label for this data point.
      +   * @param weight The weight of this instance.
          * @return This MultilabelSummarizer
          */
      -  def add(label: Double): this.type = {
      +  def add(label: Double, weight: Double = 1.0): this.type = {
      +    require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
      +
      +    if (weight == 0.0) return this
      +
           if (label - label.toInt != 0.0 || label < 0) {
             totalInvalidCnt += 1
             this
           }
           else {
      -      val counts: Long = distinctMap.getOrElse(label.toInt, 0L)
      -      distinctMap.put(label.toInt, counts + 1)
      +      val (counts: Long, weightSum: Double) = distinctMap.getOrElse(label.toInt, (0L, 0.0))
      +      distinctMap.put(label.toInt, (counts + 1L, weightSum + weight))
             this
           }
         }
      @@ -347,8 +567,8 @@ private[classification] class MultiClassSummarizer extends Serializable {
           }
           smallMap.distinctMap.foreach {
             case (key, value) =>
      -        val counts = largeMap.distinctMap.getOrElse(key, 0L)
      -        largeMap.distinctMap.put(key, counts + value)
      +        val (counts: Long, weightSum: Double) = largeMap.distinctMap.getOrElse(key, (0L, 0.0))
      +        largeMap.distinctMap.put(key, (counts + value._1, weightSum + value._2))
           }
           largeMap.totalInvalidCnt += smallMap.totalInvalidCnt
           largeMap
      @@ -360,29 +580,153 @@ private[classification] class MultiClassSummarizer extends Serializable {
         /** @return The number of distinct labels in the input dataset. */
         def numClasses: Int = distinctMap.keySet.max + 1
       
      -  /** @return The counts of each label in the input dataset. */
      -  def histogram: Array[Long] = {
      -    val result = Array.ofDim[Long](numClasses)
      +  /** @return The weightSum of each label in the input dataset. */
      +  def histogram: Array[Double] = {
      +    val result = Array.ofDim[Double](numClasses)
           var i = 0
           val len = result.length
           while (i < len) {
      -      result(i) = distinctMap.getOrElse(i, 0L)
      +      result(i) = distinctMap.getOrElse(i, (0L, 0.0))._2
             i += 1
           }
           result
         }
       }
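
The weighted bookkeeping that `MultiClassSummarizer` now performs, with each label mapped to an (instance count, weight sum) pair and the histogram reporting weight sums, can be condensed into a standalone sketch (plain Scala; the class and method names below are illustrative, not the patch's API):

```scala
import scala.collection.mutable

// Each label maps to (instance count, weight sum); the histogram reports weight sums.
class WeightedLabelHistogram {
  private val distinctMap = new mutable.HashMap[Int, (Long, Double)]

  def add(label: Double, weight: Double = 1.0): this.type = {
    require(weight >= 0.0, s"instance weight $weight has to be >= 0.0")
    if (weight == 0.0) return this
    val (counts, weightSum) = distinctMap.getOrElse(label.toInt, (0L, 0.0))
    distinctMap.put(label.toInt, (counts + 1L, weightSum + weight))
    this
  }

  def histogram: Array[Double] = {
    val numClasses = distinctMap.keySet.max + 1
    Array.tabulate(numClasses)(i => distinctMap.getOrElse(i, (0L, 0.0))._2)
  }
}

object WeightedLabelHistogram {
  def main(args: Array[String]): Unit = {
    val h = new WeightedLabelHistogram
    h.add(0.0, 2.0).add(1.0, 0.5).add(1.0, 1.5)
    println(h.histogram.mkString("[", ", ", "]"))   // [2.0, 2.0]
  }
}
```
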
       
      +/**
      + * Abstraction for multinomial Logistic Regression Training results.
      + * Currently, the training summary ignores the training weights except
      + * for the objective trace.
      + */
      +sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
      +
      +  /** objective function (scaled loss + regularization) at each iteration. */
      +  def objectiveHistory: Array[Double]
      +
      +  /** Number of training iterations until termination */
      +  def totalIterations: Int = objectiveHistory.length
      +
      +}
      +
      +/**
      + * Abstraction for Logistic Regression Results for a given model.
      + */
      +sealed trait LogisticRegressionSummary extends Serializable {
      +
      +  /** DataFrame output by the model's `transform` method. */
      +  def predictions: DataFrame
      +
      +  /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
      +  def probabilityCol: String
      +
      +  /** Field in "predictions" which gives the the true label of each instance. */
      +  def labelCol: String
      +
      +}
      +
      +/**
      + * :: Experimental ::
      + * Logistic regression training results.
      + * @param predictions DataFrame output by the model's `transform` method.
      + * @param probabilityCol field in "predictions" which gives the calibrated probability of
      + *                       each instance as a vector.
      + * @param labelCol field in "predictions" which gives the true label of each instance.
      + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
      + */
      +@Experimental
      +class BinaryLogisticRegressionTrainingSummary private[classification] (
      +    predictions: DataFrame,
      +    probabilityCol: String,
      +    labelCol: String,
      +    val objectiveHistory: Array[Double])
      +  extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
      +  with LogisticRegressionTrainingSummary {
      +
      +}
      +
      +/**
      + * :: Experimental ::
      + * Binary Logistic regression results for a given model.
      + * @param predictions DataFrame output by the model's `transform` method.
      + * @param probabilityCol field in "predictions" which gives the calibrated probability of
      + *                       each instance.
      + * @param labelCol field in "predictions" which gives the true label of each instance.
      + */
      +@Experimental
      +class BinaryLogisticRegressionSummary private[classification] (
      +    @transient override val predictions: DataFrame,
      +    override val probabilityCol: String,
      +    override val labelCol: String) extends LogisticRegressionSummary {
      +
      +  private val sqlContext = predictions.sqlContext
      +  import sqlContext.implicits._
      +
      +  /**
      +   * Returns a BinaryClassificationMetrics object.
      +   */
      +  // TODO: Allow the user to vary the number of bins using a setBins method in
      +  // BinaryClassificationMetrics. For now the default is set to 100.
      +  @transient private val binaryMetrics = new BinaryClassificationMetrics(
      +    predictions.select(probabilityCol, labelCol).map {
      +      case Row(score: Vector, label: Double) => (score(1), label)
      +    }, 100
      +  )
      +
      +  /**
      +   * Returns the receiver operating characteristic (ROC) curve,
      +   * which is a DataFrame having two fields (FPR, TPR)
      +   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
      +   * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
      +   */
      +  @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
      +
      +  /**
      +   * Computes the area under the receiver operating characteristic (ROC) curve.
      +   */
      +  lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
      +
      +  /**
      +   * Returns the precision-recall curve, which is a DataFrame containing
      +   * two fields (recall, precision), with (0.0, 1.0) prepended to it.
      +   */
      +  @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")
      +
      +  /**
      +   * Returns the (threshold, F-Measure) curve with beta = 1.0, as a DataFrame with two fields.
      +   */
      +  @transient lazy val fMeasureByThreshold: DataFrame = {
      +    binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
      +  }
      +
      +  /**
      +   * Returns the (threshold, precision) curve as a DataFrame with two fields.
      +   * Every possible probability obtained in transforming the dataset is used
      +   * as a threshold in calculating the precision.
      +   */
      +  @transient lazy val precisionByThreshold: DataFrame = {
      +    binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
      +  }
      +
      +  /**
      +   * Returns the (threshold, recall) curve as a DataFrame with two fields.
      +   * Every possible probability obtained in transforming the dataset is used
      +   * as a threshold in calculating the recall.
      +   */
      +  @transient lazy val recallByThreshold: DataFrame = {
      +    binaryMetrics.recallByThreshold().toDF("threshold", "recall")
      +  }
      +}
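
Once the patch is applied, the training summary added above can be consumed roughly as follows. This is a hedged, self-contained usage sketch, not part of the patch: the Spark setup, the toy dataset, and the object name are assumptions; only the summary accessors (`summary`, `totalIterations`, `areaUnderROC`, `roc`) come from the code above.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.{BinaryLogisticRegressionTrainingSummary, LogisticRegression}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

object LogisticRegressionSummaryExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lr-summary").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Tiny two-class toy dataset; "label" and "features" are the default column names.
    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.1)),
      LabeledPoint(1.0, Vectors.dense(2.0, 1.0)),
      LabeledPoint(0.0, Vectors.dense(0.1, 1.2)),
      LabeledPoint(1.0, Vectors.dense(2.2, 0.9)))).toDF()

    val model = new LogisticRegression().setMaxIter(20).setRegParam(0.01).fit(training)

    val summary = model.summary
    println(s"iterations: ${summary.totalIterations}")
    summary match {
      case b: BinaryLogisticRegressionTrainingSummary =>
        println(s"area under ROC: ${b.areaUnderROC}")
        b.roc.show()   // DataFrame with columns FPR, TPR
      case _ =>        // other summary types may be added in the future
    }
    sc.stop()
  }
}
```
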
      +
       /**
        * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
      - * in binary classification for samples in sparse or dense vector in a online fashion.
      + * in binary classification for instances in sparse or dense vector format in an online fashion.
        *
        * Note that multinomial logistic loss is not supported yet!
        *
        * Two LogisticAggregator can be merged together to have a summary of loss and gradient of
        * the corresponding joint dataset.
        *
      - * @param weights The weights/coefficients corresponding to the features.
      + * @param coefficients The coefficients corresponding to the features.
        * @param numClasses the number of possible outcomes for k classes classification problem in
        *                   Multinomial Logistic Regression.
        * @param fitIntercept Whether to fit an intercept term.
      @@ -390,25 +734,25 @@ private[classification] class MultiClassSummarizer extends Serializable {
        * @param featuresMean The mean values of the features.
        */
       private class LogisticAggregator(
      -    weights: Vector,
      +    coefficients: Vector,
           numClasses: Int,
           fitIntercept: Boolean,
           featuresStd: Array[Double],
           featuresMean: Array[Double]) extends Serializable {
       
      -  private var totalCnt: Long = 0L
      +  private var weightSum = 0.0
         private var lossSum = 0.0
       
      -  private val weightsArray = weights match {
      +  private val coefficientsArray = coefficients match {
           case dv: DenseVector => dv.values
           case _ =>
             throw new IllegalArgumentException(
      -        s"weights only supports dense vector but got type ${weights.getClass}.")
      +        s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
         }
       
      -  private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length
      +  private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length
       
      -  private val gradientSumArray = Array.ofDim[Double](weightsArray.length)
      +  private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length)
       
         /**
          * Add a new training data to this LogisticAggregator, and update the loss and gradient
      @@ -417,33 +761,33 @@ private class LogisticAggregator(
          * @param label The label for this data point.
          * @param data The features for one data point in dense/sparse vector format to be added
          *             into this aggregator.
      +   * @param weight The weight used to over-/under-sample this training instance. Default is 1.0.
          * @return This LogisticAggregator object.
          */
      -  def add(label: Double, data: Vector): this.type = {
      -    require(dim == data.size, s"Dimensions mismatch when adding new sample." +
      +  def add(label: Double, data: Vector, weight: Double = 1.0): this.type = {
      +    require(dim == data.size, s"Dimensions mismatch when adding new instance." +
             s" Expecting $dim but got ${data.size}.")
      +    require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
       
      -    val dataSize = data.size
      +    if (weight == 0.0) return this
       
      -    val localWeightsArray = weightsArray
      +    val localCoefficientsArray = coefficientsArray
           val localGradientSumArray = gradientSumArray
       
           numClasses match {
             case 2 =>
      -        /**
      -         * For Binary Logistic Regression.
      -         */
      +        // For Binary Logistic Regression.
               val margin = - {
                 var sum = 0.0
                 data.foreachActive { (index, value) =>
                   if (featuresStd(index) != 0.0 && value != 0.0) {
      -              sum += localWeightsArray(index) * (value / featuresStd(index))
      +              sum += localCoefficientsArray(index) * (value / featuresStd(index))
                   }
                 }
      -          sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 }
      +          sum + { if (fitIntercept) localCoefficientsArray(dim) else 0.0 }
               }
       
      -        val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
      +        val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
       
               data.foreachActive { (index, value) =>
                 if (featuresStd(index) != 0.0 && value != 0.0) {
      @@ -457,15 +801,15 @@ private class LogisticAggregator(
       
               if (label > 0) {
                 // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
      -          lossSum += MLUtils.log1pExp(margin)
      +          lossSum += weight * MLUtils.log1pExp(margin)
               } else {
      -          lossSum += MLUtils.log1pExp(margin) - margin
      +          lossSum += weight * (MLUtils.log1pExp(margin) - margin)
               }
             case _ =>
               new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " +
                 "binary classification for now.")
           }
      -    totalCnt += 1
      +    weightSum += weight
           this
         }
       
      @@ -481,8 +825,8 @@ private class LogisticAggregator(
           require(dim == other.dim, s"Dimensions mismatch when merging with another " +
             s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
       
      -    if (other.totalCnt != 0) {
      -      totalCnt += other.totalCnt
      +    if (other.weightSum != 0.0) {
      +      weightSum += other.weightSum
             lossSum += other.lossSum
       
             var i = 0
      @@ -497,13 +841,17 @@ private class LogisticAggregator(
           this
         }
       
      -  def count: Long = totalCnt
      -
      -  def loss: Double = lossSum / totalCnt
      +  def loss: Double = {
      +    require(weightSum > 0.0, s"The effective number of instances should be " +
      +      s"greater than 0.0, but $weightSum.")
      +    lossSum / weightSum
      +  }
       
         def gradient: Vector = {
      +    require(weightSum > 0.0, s"The effective number of instances should be " +
      +      s"greater than 0.0, but $weightSum.")
           val result = Vectors.dense(gradientSumArray.clone())
      -    scal(1.0 / totalCnt, result)
      +    scal(1.0 / weightSum, result)
           result
         }
       }
      @@ -515,46 +863,65 @@ private class LogisticAggregator(
        * It's used in Breeze's convex optimization routines.
        */
       private class LogisticCostFun(
      -    data: RDD[(Double, Vector)],
      +    data: RDD[Instance],
           numClasses: Int,
           fitIntercept: Boolean,
      +    standardization: Boolean,
           featuresStd: Array[Double],
           featuresMean: Array[Double],
           regParamL2: Double) extends DiffFunction[BDV[Double]] {
       
      -  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
      -    val w = Vectors.fromBreeze(weights)
      +  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
      +    val numFeatures = featuresStd.length
      +    val w = Vectors.fromBreeze(coefficients)
       
      -    val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
      -      featuresStd, featuresMean))(
      -        seqOp = (c, v) => (c, v) match {
      -          case (aggregator, (label, features)) => aggregator.add(label, features)
      -        },
      -        combOp = (c1, c2) => (c1, c2) match {
      -          case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
      -        })
      +    val logisticAggregator = {
      +      val seqOp = (c: LogisticAggregator, instance: Instance) =>
      +        c.add(instance.label, instance.features, instance.weight)
      +      val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
       
      -    // regVal is the sum of weight squares for L2 regularization
      -    val norm = if (regParamL2 == 0.0) {
      -      0.0
      -    } else if (fitIntercept) {
      -      brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
      -    } else {
      -      brzNorm(weights, 2.0)
      +      data.treeAggregate(
      +        new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean)
      +      )(seqOp, combOp)
           }
      -    val regVal = 0.5 * regParamL2 * norm * norm
       
      -    val loss = logisticAggregator.loss + regVal
      -    val gradient = logisticAggregator.gradient
      +    val totalGradientArray = logisticAggregator.gradient.toArray
       
      -    if (fitIntercept) {
      -      val wArray = w.toArray.clone()
      -      wArray(wArray.length - 1) = 0.0
      -      axpy(regParamL2, Vectors.dense(wArray), gradient)
      +    // regVal is the sum of squared coefficients, excluding the intercept, for L2 regularization.
      +    val regVal = if (regParamL2 == 0.0) {
      +      0.0
           } else {
      -      axpy(regParamL2, w, gradient)
      +      var sum = 0.0
      +      w.foreachActive { (index, value) =>
      +        // If `fitIntercept` is true, the last term, which is the intercept, doesn't
      +        // contribute to the regularization.
      +        if (index != numFeatures) {
      +          // The following code computes the regularization loss and its gradient,
      +          // and adds the gradient back to totalGradientArray.
      +          sum += {
      +            if (standardization) {
      +              totalGradientArray(index) += regParamL2 * value
      +              value * value
      +            } else {
      +              if (featuresStd(index) != 0.0) {
      +                // If `standardization` is false, we still standardize the data
      +                // to improve the rate of convergence; as a result, we have to
      +                // perform this reverse standardization by penalizing each component
      +                // differently to get effectively the same objective function when
      +                // the training dataset is not standardized.
      +                val temp = value / (featuresStd(index) * featuresStd(index))
      +                totalGradientArray(index) += regParamL2 * temp
      +                value * temp
      +              } else {
      +                0.0
      +              }
      +            }
      +          }
      +        }
      +      }
      +      0.5 * regParamL2 * sum
           }
       
      -    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
      +    (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
         }
       }
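
The reverse-standardization comments above (for both the L1 penalty in `regParamL1Fun` and the L2 term in `LogisticCostFun`) amount to dividing each component's penalty by the feature's variance when `standardization` is false. A standalone sketch of the L2 case, with illustrative names and the same per-component formulas as the code above:

```scala
object ReverseStandardizationPenalty {
  // When standardization is disabled but the data is still standardized internally,
  // each coefficient's L2 penalty (and its gradient contribution) is divided by the
  // squared feature standard deviation, so the objective matches the unstandardized one.
  def l2Contribution(
      value: Double,            // coefficient in the standardized space
      std: Double,              // standard deviation of the corresponding feature
      regParamL2: Double,
      standardization: Boolean): (Double, Double) = {   // (loss term, gradient term)
    if (standardization) {
      (0.5 * regParamL2 * value * value, regParamL2 * value)
    } else if (std != 0.0) {
      val temp = value / (std * std)
      (0.5 * regParamL2 * value * temp, regParamL2 * temp)
    } else {
      (0.0, 0.0)
    }
  }

  def main(args: Array[String]): Unit = {
    println(l2Contribution(value = 1.5, std = 2.0, regParamL2 = 0.1, standardization = false))
  }
}
```
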
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
      new file mode 100644
      index 000000000000..5f60dea91fcf
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
      @@ -0,0 +1,204 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.classification
      +
      +import scala.collection.JavaConverters._
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed}
      +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
      +import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap}
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology}
      +import org.apache.spark.mllib.linalg.{Vectors, Vector}
      +import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.sql.DataFrame
      +
      +/** Params for Multilayer Perceptron. */
      +private[ml] trait MultilayerPerceptronParams extends PredictorParams
      +  with HasSeed with HasMaxIter with HasTol {
      +  /**
      +   * Layer sizes including input size and output size.
      +   * Default: Array(1, 1)
      +   * @group param
      +   */
      +  final val layers: IntArrayParam = new IntArrayParam(this, "layers",
      +    "Sizes of layers from input layer to output layer" +
      +      " E.g., Array(780, 100, 10) means 780 inputs, " +
      +      "one hidden layer with 100 neurons and output layer of 10 neurons.",
      +    // TODO: how to check ALSO that all elements are greater than 0?
      +    ParamValidators.arrayLengthGt(1)
      +  )
      +
      +  /** @group getParam */
      +  final def getLayers: Array[Int] = $(layers)
      +
      +  /**
      +   * Block size for stacking input data in matrices to speed up the computation.
      +   * Data is stacked within partitions. If the block size is larger than the remaining data in
      +   * a partition, then it is adjusted to the size of that data.
      +   * Recommended size is between 10 and 1000.
      +   * Default: 128
      +   * @group expertParam
      +   */
      +  final val blockSize: IntParam = new IntParam(this, "blockSize",
      +    "Block size for stacking input data in matrices. Data is stacked within partitions." +
      +      " If block size is more than remaining data in a partition then " +
      +      "it is adjusted to the size of this data. Recommended size is between 10 and 1000",
      +    ParamValidators.gt(0))
      +
      +  /** @group getParam */
      +  final def getBlockSize: Int = $(blockSize)
      +
      +  setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128)
      +}
      +
      +/** Label to vector converter. */
      +private object LabelConverter {
      +  // TODO: Use OneHotEncoder instead
      +  /**
      +   * Encodes a label as a vector.
      +   * Returns a vector of given length with zeroes at all positions
      +   * and value 1.0 at the position that corresponds to the label.
      +   *
      +   * @param labeledPoint labeled point
      +   * @param labelCount total number of labels
      +   * @return pair of features and vector encoding of a label
      +   */
      +  def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = {
      +    val output = Array.fill(labelCount)(0.0)
      +    output(labeledPoint.label.toInt) = 1.0
      +    (labeledPoint.features, Vectors.dense(output))
      +  }
      +
      +  /**
      +   * Converts a vector to a label.
      +   * Returns the position of the maximal element of a vector.
      +   *
      +   * @param output label encoded with a vector
      +   * @return label
      +   */
      +  def decodeLabel(output: Vector): Double = {
      +    output.argmax.toDouble
      +  }
      +}
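
`LabelConverter` is essentially a one-hot encode/decode pair. A standalone round-trip sketch (plain Scala; names are illustrative):

```scala
object OneHotLabelSketch {
  // Encode label k (0-based) as a length-`labelCount` array with 1.0 at position k...
  def encode(label: Int, labelCount: Int): Array[Double] = {
    val output = Array.fill(labelCount)(0.0)
    output(label) = 1.0
    output
  }

  // ...and decode by taking the index of the maximal element.
  def decode(output: Array[Double]): Int =
    output.indexOf(output.max)

  def main(args: Array[String]): Unit = {
    val encoded = encode(label = 2, labelCount = 4)
    println(encoded.mkString("[", ", ", "]"))   // [0.0, 0.0, 1.0, 0.0]
    println(decode(encoded))                    // 2
  }
}
```
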
      +
      +/**
      + * :: Experimental ::
      + * Classifier trainer based on the Multilayer Perceptron.
      + * Each layer has a sigmoid activation function; the output layer has softmax.
      + * The number of inputs has to be equal to the size of the feature vectors.
      + * The number of outputs has to be equal to the total number of labels.
      + *
      + */
      +@Experimental
      +class MultilayerPerceptronClassifier(override val uid: String)
      +  extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
      +  with MultilayerPerceptronParams {
      +
      +  def this() = this(Identifiable.randomUID("mlpc"))
      +
      +  /** @group setParam */
      +  def setLayers(value: Array[Int]): this.type = set(layers, value)
      +
      +  /** @group setParam */
      +  def setBlockSize(value: Int): this.type = set(blockSize, value)
      +
      +  /**
      +   * Set the maximum number of iterations.
      +   * Default is 100.
      +   * @group setParam
      +   */
      +  def setMaxIter(value: Int): this.type = set(maxIter, value)
      +
      +  /**
      +   * Set the convergence tolerance of iterations.
      +   * Smaller value will lead to higher accuracy with the cost of more iterations.
      +   * Default is 1E-4.
      +   * @group setParam
      +   */
      +  def setTol(value: Double): this.type = set(tol, value)
      +
      +  /**
      +   * Set the seed for weights initialization.
      +   * @group setParam
      +   */
      +  def setSeed(value: Long): this.type = set(seed, value)
      +
      +  override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
      +
      +  /**
      +   * Train a model using the given dataset and parameters.
      +   * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
      +   * and copying parameters into the model.
      +   *
      +   * @param dataset Training dataset
      +   * @return Fitted model
      +   */
      +  override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = {
      +    val myLayers = $(layers)
      +    val labels = myLayers.last
      +    val lpData = extractLabeledPoints(dataset)
      +    val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels))
      +    val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true)
      +    val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
      +    trainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter))
      +    trainer.setStackSize($(blockSize))
      +    val mlpModel = trainer.train(data)
      +    new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights())
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * Classification model based on the Multilayer Perceptron.
      + * Each layer has a sigmoid activation function; the output layer has softmax.
      + * @param uid uid
      + * @param layers array of layer sizes including input and output layers
      + * @param weights the vector of weights for the model, consisting of the weights of all layers
      + * @return prediction model
      + */
      +@Experimental
      +class MultilayerPerceptronClassificationModel private[ml] (
      +    override val uid: String,
      +    val layers: Array[Int],
      +    val weights: Vector)
      +  extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
      +  with Serializable {
      +
      +  private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
      +
      +  /**
      +   * Returns layers in a Java List.
      +   */
      +  private[ml] def javaLayers: java.util.List[Int] = {
      +    layers.toList.asJava
      +  }
      +
      +  /**
      +   * Predict label for the given features.
      +   * This internal method is used to implement [[transform()]] and output [[predictionCol]].
      +   */
      +  override protected def predict(features: Vector): Double = {
      +    LabelConverter.decodeLabel(mlpModel.predict(features))
      +  }
      +
      +  override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
      +    copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
      +  }
      +}
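
A hedged usage sketch of the new classifier on a tiny XOR-style dataset. The Spark setup, the data, and the object name are assumptions for illustration; the setters (`setLayers`, `setBlockSize`, `setSeed`, `setMaxIter`) are the ones defined above, and the column names rely on the `Predictor` defaults ("label", "features").

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

object MultilayerPerceptronExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("mlp").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))).toDF()

    val trainer = new MultilayerPerceptronClassifier()
      .setLayers(Array(2, 5, 2))   // 2 inputs, one hidden layer of 5, 2 output classes
      .setBlockSize(128)
      .setSeed(11L)
      .setMaxIter(100)

    val model = trainer.fit(data)
    model.transform(data).select("features", "prediction").show()
    sc.stop()
  }
}
```
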
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
      new file mode 100644
      index 000000000000..082ea1ffad58
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
      @@ -0,0 +1,219 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.classification
      +
      +import org.apache.spark.SparkException
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.PredictorParams
      +import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
      +import org.apache.spark.mllib.linalg._
      +import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.sql.DataFrame
      +
      +/**
      + * Params for Naive Bayes Classifiers.
      + */
      +private[ml] trait NaiveBayesParams extends PredictorParams {
      +
      +  /**
      +   * The smoothing parameter.
      +   * (default = 1.0).
      +   * @group param
      +   */
      +  final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.",
      +    ParamValidators.gtEq(0))
      +
      +  /** @group getParam */
      +  final def getSmoothing: Double = $(smoothing)
      +
      +  /**
      +   * The model type which is a string (case-sensitive).
      +   * Supported options: "multinomial" and "bernoulli".
      +   * (default = multinomial)
      +   * @group param
      +   */
      +  final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
      +    "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
      +    ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray))
      +
      +  /** @group getParam */
      +  final def getModelType: String = $(modelType)
      +}
      +
      +/**
      + * :: Experimental ::
      + * Naive Bayes Classifiers.
      + * It supports Multinomial NB
      + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]),
      + * which can handle finitely supported discrete data. For example, by converting documents into
      + * TF-IDF vectors, it can be used for document classification. By making every vector
      + * binary (0/1) data, it can also be used as Bernoulli NB
      + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
      + * The input feature values must be nonnegative.
      + */
      +@Experimental
      +class NaiveBayes(override val uid: String)
      +  extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
      +  with NaiveBayesParams {
      +
      +  def this() = this(Identifiable.randomUID("nb"))
      +
      +  /**
      +   * Set the smoothing parameter.
      +   * Default is 1.0.
      +   * @group setParam
      +   */
      +  def setSmoothing(value: Double): this.type = set(smoothing, value)
      +  setDefault(smoothing -> 1.0)
      +
      +  /**
      +   * Set the model type using a string (case-sensitive).
      +   * Supported options: "multinomial" and "bernoulli".
      +   * Default is "multinomial"
      +   * @group setParam
      +   */
      +  def setModelType(value: String): this.type = set(modelType, value)
      +  setDefault(modelType -> OldNaiveBayes.Multinomial)
      +
      +  override protected def train(dataset: DataFrame): NaiveBayesModel = {
      +    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
      +    val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
      +    NaiveBayesModel.fromOld(oldModel, this)
      +  }
      +
      +  override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
      +}
      +
      +/**
      + * :: Experimental ::
      + * Model produced by [[NaiveBayes]]
      + * @param pi log of class priors, whose dimension is C (number of classes)
      + * @param theta log of class conditional probabilities, whose dimension is C (number of classes)
      + *              by D (number of features)
      + */
      +@Experimental
      +class NaiveBayesModel private[ml] (
      +    override val uid: String,
      +    val pi: Vector,
      +    val theta: Matrix)
      +  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
      +
      +  import OldNaiveBayes.{Bernoulli, Multinomial}
      +
      +  /**
      +   * Bernoulli scoring requires log(condprob) if a feature value is 1 and log(1 - condprob) if it is 0.
      +   * This precomputes log(1.0 - exp(theta)) and its sum, which are used for the linear-algebra
      +   * application of this condition (in the predict function).
      +   */
      +  private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
      +    case Multinomial => (None, None)
      +    case Bernoulli =>
      +      val negTheta = theta.map(value => math.log(1.0 - math.exp(value)))
      +      val ones = new DenseVector(Array.fill(theta.numCols){1.0})
      +      val thetaMinusNegTheta = theta.map { value =>
      +        value - math.log(1.0 - math.exp(value))
      +      }
      +      (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
      +    case _ =>
      +      // This should never happen.
      +      throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
      +  }
      +
      +  override val numClasses: Int = pi.size
      +
      +  private def multinomialCalculation(features: Vector) = {
      +    val prob = theta.multiply(features)
      +    BLAS.axpy(1.0, pi, prob)
      +    prob
      +  }
      +
      +  private def bernoulliCalculation(features: Vector) = {
      +    features.foreachActive((_, value) =>
      +      if (value != 0.0 && value != 1.0) {
      +        throw new SparkException(
      +          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
      +      }
      +    )
      +    val prob = thetaMinusNegTheta.get.multiply(features)
      +    BLAS.axpy(1.0, pi, prob)
      +    BLAS.axpy(1.0, negThetaSum.get, prob)
      +    prob
      +  }
      +
      +  override protected def predictRaw(features: Vector): Vector = {
      +    $(modelType) match {
      +      case Multinomial =>
      +        multinomialCalculation(features)
      +      case Bernoulli =>
      +        bernoulliCalculation(features)
      +      case _ =>
      +        // This should never happen.
      +        throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
      +    }
      +  }
      +
      +  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
      +    rawPrediction match {
      +      case dv: DenseVector =>
      +        var i = 0
      +        val size = dv.size
      +        val maxLog = dv.values.max
      +        while (i < size) {
      +          dv.values(i) = math.exp(dv.values(i) - maxLog)
      +          i += 1
      +        }
      +        val probSum = dv.values.sum
      +        i = 0
      +        while (i < size) {
      +          dv.values(i) = dv.values(i) / probSum
      +          i += 1
      +        }
      +        dv
      +      case sv: SparseVector =>
      +        throw new RuntimeException("Unexpected error in NaiveBayesModel:" +
      +          " raw2probabilityInPlace encountered SparseVector")
      +    }
      +  }
      +
      +  override def copy(extra: ParamMap): NaiveBayesModel = {
      +    copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
      +  }
      +
      +  override def toString: String = {
      +    s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
      +  }
      +
      +}
      +
      +private[ml] object NaiveBayesModel {
      +
      +  /** Convert a model from the old API */
      +  def fromOld(
      +      oldModel: OldNaiveBayesModel,
      +      parent: NaiveBayes): NaiveBayesModel = {
      +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
      +    val labels = Vectors.dense(oldModel.labels)
      +    val pi = Vectors.dense(oldModel.pi)
      +    val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length,
      +      oldModel.theta.flatten, true)
      +    new NaiveBayesModel(uid, pi, theta)
      +  }
      +}
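
A hedged usage sketch of the new `ml.classification.NaiveBayes` wrapper. The Spark setup, the toy count data, and the object name are assumptions; `setSmoothing`, `setModelType`, `pi`, and the default "probability"/"prediction" output columns come from the code above.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

object NaiveBayesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("nb").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Feature values must be nonnegative (counts here), per the class docs above.
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))).toDF()

    val model = new NaiveBayes()
      .setSmoothing(1.0)
      .setModelType("multinomial")   // or "bernoulli" with 0/1 features
      .fit(data)

    println(s"log class priors: ${model.pi}")
    model.transform(data).select("features", "probability", "prediction").show()
    sc.stop()
  }
}
```
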
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
      index b657882f8ad3..debc164bf243 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
      @@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams {
       
         /**
          * param for the base binary classifier that we reduce multiclass classification into.
      +   * The base classifier input and output columns are ignored in favor of
      +   * the ones specified in [[OneVsRest]].
          * @group param
          */
         val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
      @@ -88,9 +90,8 @@ final class OneVsRestModel private[ml] (
       
           // add an accumulator column to store predictions of all the models
           val accColName = "mbc$acc" + UUID.randomUUID().toString
      -    val init: () => Map[Int, Double] = () => {Map()}
      -    val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
      -    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
      +    val initUDF = udf { () => Map[Int, Double]() }
      +    val newDataset = dataset.withColumn(accColName, initUDF())
       
           // persist if underlying dataset is not persistent.
           val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
      @@ -106,13 +107,12 @@ final class OneVsRestModel private[ml] (
       
               // add temporary column to store intermediate scores and update
               val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
      -        val update: (Map[Int, Double], Vector) => Map[Int, Double] =
      -          (predictions: Map[Int, Double], prediction: Vector) => {
      -            predictions + ((index, prediction(1)))
      -          }
      -        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
      +        val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
      +          predictions + ((index, prediction(1)))
      +        }
               val transformedDataset = model.transform(df).select(columns : _*)
      -        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
      +        val updatedDataset = transformedDataset
      +          .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
               val newColumns = origCols ++ List(col(tmpColName))
       
               // switch out the intermediate column with the accumulator column
      @@ -124,20 +124,20 @@ final class OneVsRestModel private[ml] (
           }
       
           // output the index of the classifier with highest confidence as prediction
      -    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
      +    val labelUDF = udf { (predictions: Map[Int, Double]) =>
             predictions.maxBy(_._2)._1.toDouble
           }
       
           // output label and label metadata as prediction
      -    val labelUdf = callUDF(label, DoubleType, col(accColName))
      -    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
      +    aggregatedDataset
      +      .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
             .drop(accColName)
         }
       
         override def copy(extra: ParamMap): OneVsRestModel = {
           val copied = new OneVsRestModel(
             uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
         }
       }
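
The rewritten `transform` above replaces `callUDF(func, returnType, columns...)` with `udf { ... }`, which infers the return type from the Scala closure. Shown in isolation as a hedged sketch; the DataFrame and its "scores" column are assumptions:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

object UdfMigrationSketch {
  // udf(...) infers the return type from the closure, so no explicit MapType/DoubleType
  // is needed, unlike the old callUDF(func, returnType, columns...) form.
  def addPrediction(df: DataFrame): DataFrame = {
    val labelUDF = udf { (scores: Map[Int, Double]) =>
      scores.maxBy(_._2)._1.toDouble
    }
    // `df` is assumed to carry a Map[Int, Double] column named "scores".
    df.withColumn("prediction", labelUDF(col("scores")))
  }
}
```
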
       
      @@ -161,6 +161,15 @@ final class OneVsRest(override val uid: String)
           set(classifier, value.asInstanceOf[ClassifierType])
         }
       
      +  /** @group setParam */
      +  def setLabelCol(value: String): this.type = set(labelCol, value)
      +
      +  /** @group setParam */
      +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
      +
      +  /** @group setParam */
      +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
      +
         override def transformSchema(schema: StructType): StructType = {
           validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
         }
      @@ -185,20 +194,17 @@ final class OneVsRest(override val uid: String)
       
           // create k columns, one for each binary classifier.
           val models = Range(0, numClasses).par.map { index =>
      -
      -      val label: Double => Double = (label: Double) => {
      -        if (label.toInt == index) 1.0 else 0.0
      -      }
      -
             // generate new label metadata for the binary problem.
      -      // TODO: use when ... otherwise after SPARK-7321 is merged
      -      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
             val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
             val labelColName = "mc2b$" + index
      -      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
      -      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
      +      val trainingDataset = multiclassLabeled.withColumn(
      +        labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta)
             val classifier = getClassifier
      -      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
      +      val paramMap = new ParamMap()
      +      paramMap.put(classifier.labelCol -> labelColName)
      +      paramMap.put(classifier.featuresCol -> getFeaturesCol)
      +      paramMap.put(classifier.predictionCol -> getPredictionCol)
      +      classifier.fit(trainingDataset, paramMap)
           }.toArray[ClassificationModel[_, _]]
       
           if (handlePersistence) {
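
      The hunk above drops the callUDF-based label remapping in favor of a when/otherwise column
      expression: for classifier index i, rows whose label equals i become 1.0 and every other row
      becomes 0.0, with the binary-label metadata attached through the withColumn overload used in
      the patch. A minimal sketch of the same expression, assuming a hypothetical DataFrame
      `training` with a Double "label" column and a class index `index`:

          import org.apache.spark.sql.functions.{col, when}

          val index = 2  // class currently being separated from the rest
          val binarized = training.withColumn(
            "mc2b$" + index,
            when(col("label") === index.toDouble, 1.0).otherwise(0.0))
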
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
      index 330ae2938f4e..fdd1851ae550 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
      @@ -20,17 +20,16 @@ package org.apache.spark.ml.classification
       import org.apache.spark.annotation.DeveloperApi
       import org.apache.spark.ml.param.shared._
       import org.apache.spark.ml.util.SchemaUtils
      -import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
      +import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors}
       import org.apache.spark.sql.DataFrame
       import org.apache.spark.sql.functions._
      -import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
      +import org.apache.spark.sql.types.{DataType, StructType}
       
       /**
        * (private[classification])  Params for probabilistic classification.
        */
       private[classification] trait ProbabilisticClassifierParams
      -  extends ClassifierParams with HasProbabilityCol {
      -
      +  extends ClassifierParams with HasProbabilityCol with HasThresholds {
         override protected def validateAndTransformSchema(
             schema: StructType,
             fitting: Boolean,
      @@ -51,7 +50,7 @@ private[classification] trait ProbabilisticClassifierParams
        * @tparam M  Concrete Model type
        */
       @DeveloperApi
      -private[spark] abstract class ProbabilisticClassifier[
      +abstract class ProbabilisticClassifier[
           FeaturesType,
           E <: ProbabilisticClassifier[FeaturesType, E, M],
           M <: ProbabilisticClassificationModel[FeaturesType, M]]
      @@ -59,6 +58,9 @@ private[spark] abstract class ProbabilisticClassifier[
       
         /** @group setParam */
         def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
      +
      +  /** @group setParam */
      +  def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
       }
       
       
      @@ -72,7 +74,7 @@ private[spark] abstract class ProbabilisticClassifier[
        * @tparam M  Concrete Model type
        */
       @DeveloperApi
      -private[spark] abstract class ProbabilisticClassificationModel[
      +abstract class ProbabilisticClassificationModel[
           FeaturesType,
           M <: ProbabilisticClassificationModel[FeaturesType, M]]
         extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
      @@ -80,6 +82,9 @@ private[spark] abstract class ProbabilisticClassificationModel[
         /** @group setParam */
         def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
       
      +  /** @group setParam */
      +  def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
      +
         /**
          * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
          * parameters:
      @@ -92,32 +97,45 @@ private[spark] abstract class ProbabilisticClassificationModel[
          */
         override def transform(dataset: DataFrame): DataFrame = {
           transformSchema(dataset.schema, logging = true)
      +    if (isDefined(thresholds)) {
      +      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
      +        ".transform() called with non-matching numClasses and thresholds.length." +
      +        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
      +    }
       
           // Output selected columns only.
           // This is a bit complicated since it tries to avoid repeated computation.
           var outputData = dataset
           var numColsOutput = 0
           if ($(rawPredictionCol).nonEmpty) {
      -      outputData = outputData.withColumn(getRawPredictionCol,
      -        callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
      +      val predictRawUDF = udf { (features: Any) =>
      +        predictRaw(features.asInstanceOf[FeaturesType])
      +      }
      +      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
             numColsOutput += 1
           }
           if ($(probabilityCol).nonEmpty) {
             val probUDF = if ($(rawPredictionCol).nonEmpty) {
      -        callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol)))
      +        udf(raw2probability _).apply(col($(rawPredictionCol)))
             } else {
      -        callUDF(predictProbability _, new VectorUDT, col($(featuresCol)))
      +        val probabilityUDF = udf { (features: Any) =>
      +          predictProbability(features.asInstanceOf[FeaturesType])
      +        }
      +        probabilityUDF(col($(featuresCol)))
             }
             outputData = outputData.withColumn($(probabilityCol), probUDF)
             numColsOutput += 1
           }
           if ($(predictionCol).nonEmpty) {
             val predUDF = if ($(rawPredictionCol).nonEmpty) {
      -        callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol)))
      +        udf(raw2prediction _).apply(col($(rawPredictionCol)))
             } else if ($(probabilityCol).nonEmpty) {
      -        callUDF(probability2prediction _, DoubleType, col($(probabilityCol)))
      +        udf(probability2prediction _).apply(col($(probabilityCol)))
             } else {
      -        callUDF(predict _, DoubleType, col($(featuresCol)))
      +        val predictUDF = udf { (features: Any) =>
      +          predict(features.asInstanceOf[FeaturesType])
      +        }
      +        predictUDF(col($(featuresCol)))
             }
             outputData = outputData.withColumn($(predictionCol), predUDF)
             numColsOutput += 1
      @@ -147,6 +165,14 @@ private[spark] abstract class ProbabilisticClassificationModel[
           raw2probabilityInPlace(probs)
         }
       
      +  override protected def raw2prediction(rawPrediction: Vector): Double = {
      +    if (!isDefined(thresholds)) {
      +      rawPrediction.argmax
      +    } else {
      +      probability2prediction(raw2probability(rawPrediction))
      +    }
      +  }
      +
         /**
          * Predict the probability of each class given the features.
          * These predictions are also called class conditional probabilities.
      @@ -162,8 +188,44 @@ private[spark] abstract class ProbabilisticClassificationModel[
       
         /**
          * Given a vector of class conditional probabilities, select the predicted label.
      -   * This may be overridden to support thresholds which favor particular labels.
      +   * This supports thresholds which favor particular labels.
          * @return  predicted label
          */
      -  protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax
      +  protected def probability2prediction(probability: Vector): Double = {
      +    if (!isDefined(thresholds)) {
      +      probability.argmax
      +    } else {
      +      val thresholds: Array[Double] = getThresholds
      +      val scaledProbability: Array[Double] =
      +        probability.toArray.zip(thresholds).map { case (p, t) =>
      +          if (t == 0.0) Double.PositiveInfinity else p / t
      +        }
      +      Vectors.dense(scaledProbability).argmax
      +    }
      +  }
      +}
      +
      +private[ml] object ProbabilisticClassificationModel {
      +
      +  /**
      +   * Normalize a vector of raw predictions to be a multinomial probability vector, in place.
      +   *
      +   * The input raw predictions should be >= 0.
      +   * The output vector sums to 1, unless the input vector is all-0 (in which case the output is
      +   * all-0 too).
      +   *
      +   * NOTE: This is NOT applicable to all models, only ones which effectively use class
      +   *       instance counts for raw predictions.
      +   */
      +  def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = {
      +    val sum = v.values.sum
      +    if (sum != 0) {
      +      var i = 0
      +      val size = v.size
      +      while (i < size) {
      +        v.values(i) /= sum
      +        i += 1
      +      }
      +    }
      +  }
       }
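
      With the new thresholds param, probability2prediction is no longer a plain argmax: each class
      probability is divided by its class threshold (a zero threshold maps to positive infinity) and
      the class with the largest scaled value wins; raw2prediction routes through the same rule when
      thresholds are set. A standalone arithmetic sketch of that rule, using plain Scala arrays
      rather than MLlib vectors:

          val probability = Array(0.2, 0.5, 0.3)
          val thresholds  = Array(0.4, 0.6, 0.2)   // one threshold per class
          val scaled = probability.zip(thresholds).map { case (p, t) =>
            if (t == 0.0) Double.PositiveInfinity else p / t
          }
          val prediction = scaled.indexOf(scaled.max).toDouble   // class 2 wins: 0.3 / 0.2 = 1.5
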
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
      index d3c67494a31e..a6ebee1bb10a 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
      @@ -17,20 +17,19 @@
       
       package org.apache.spark.ml.classification
       
      -import scala.collection.mutable
      -
       import org.apache.spark.annotation.Experimental
      -import org.apache.spark.ml.{PredictionModel, Predictor}
      +import org.apache.spark.ml.tree.impl.RandomForest
       import org.apache.spark.ml.param.ParamMap
       import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
       import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
      -import org.apache.spark.mllib.linalg.Vector
      +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
       import org.apache.spark.mllib.regression.LabeledPoint
      -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
       import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.functions._
      +
       
       /**
        * :: Experimental ::
      @@ -41,7 +40,7 @@ import org.apache.spark.sql.DataFrame
        */
       @Experimental
       final class RandomForestClassifier(override val uid: String)
      -  extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
      +  extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
         with RandomForestParams with TreeClassifierParams {
       
         def this() = this(Identifiable.randomUID("rfc"))
      @@ -93,9 +92,11 @@ final class RandomForestClassifier(override val uid: String)
           val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
           val strategy =
             super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
      -    val oldModel = OldRandomForest.trainClassifier(
      -      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
      -    RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
      +    val trees =
      +      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
      +        .map(_.asInstanceOf[DecisionTreeClassificationModel])
      +    val numFeatures = oldDataset.first().features.size
      +    new RandomForestClassificationModel(trees, numFeatures, numClasses)
         }
       
         override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
      @@ -118,16 +119,29 @@ object RandomForestClassifier {
        * features.
        * @param _trees  Decision trees in the ensemble.
        *               Warning: These have null parents.
      + * @param numFeatures  Number of features used by this model
        */
       @Experimental
       final class RandomForestClassificationModel private[ml] (
           override val uid: String,
      -    private val _trees: Array[DecisionTreeClassificationModel])
      -  extends PredictionModel[Vector, RandomForestClassificationModel]
      +    private val _trees: Array[DecisionTreeClassificationModel],
      +    val numFeatures: Int,
      +    override val numClasses: Int)
      +  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
         with TreeEnsembleModel with Serializable {
       
         require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
       
      +  /**
      +   * Construct a random forest classification model, with all trees weighted equally.
      +   * @param trees  Component trees
      +   */
      +  private[ml] def this(
      +      trees: Array[DecisionTreeClassificationModel],
      +      numFeatures: Int,
      +      numClasses: Int) =
      +    this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
      +
         override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
       
         // Note: We may add support for weights (based on tree performance) later on.
      @@ -135,27 +149,70 @@ final class RandomForestClassificationModel private[ml] (
       
         override def treeWeights: Array[Double] = _treeWeights
       
      -  override protected def predict(features: Vector): Double = {
      -    // TODO: Override transform() to broadcast model.  SPARK-7127
      +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
      +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
      +    val predictUDF = udf { (features: Any) =>
      +      bcastModel.value.predict(features.asInstanceOf[Vector])
      +    }
      +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      +  }
      +
      +  override protected def predictRaw(features: Vector): Vector = {
           // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
           // Classifies using majority votes.
      -    // Ignore the weights since all are 1.0 for now.
      -    val votes = mutable.Map.empty[Int, Double]
      +    // Ignore the tree weights since all are 1.0 for now.
      +    val votes = Array.fill[Double](numClasses)(0.0)
           _trees.view.foreach { tree =>
      -      val prediction = tree.rootNode.predict(features).toInt
      -      votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
      +      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
      +      val total = classCounts.sum
      +      if (total != 0) {
      +        var i = 0
      +        while (i < numClasses) {
      +          votes(i) += classCounts(i) / total
      +          i += 1
      +        }
      +      }
      +    }
      +    Vectors.dense(votes)
      +  }
      +
      +  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
      +    rawPrediction match {
      +      case dv: DenseVector =>
      +        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
      +        dv
      +      case sv: SparseVector =>
      +        throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
      +          " raw2probabilityInPlace encountered SparseVector")
           }
      -    votes.maxBy(_._2)._1
         }
       
         override def copy(extra: ParamMap): RandomForestClassificationModel = {
      -    copyValues(new RandomForestClassificationModel(uid, _trees), extra)
      +    copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
      +      .setParent(parent)
         }
       
         override def toString: String = {
      -    s"RandomForestClassificationModel with $numTrees trees"
      +    s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
         }
       
      +  /**
      +   * Estimate of the importance of each feature.
      +   *
      +   * This generalizes the idea of "Gini" importance to other losses,
      +   * following the explanation of Gini importance from "Random Forests" documentation
      +   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
      +   *
      +   * This feature importance is calculated as follows:
      +   *  - Average over trees:
      +   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
      +   *       where gain is scaled by the number of instances passing through node
      +   *     - Normalize importances for tree based on total number of training instances used
      +   *       to build tree.
      +   *  - Normalize feature importance vector to sum to 1.
      +   */
      +  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
      +
         /** (private[ml]) Convert to a model in the old API */
         private[ml] def toOld: OldRandomForestModel = {
           new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
      @@ -168,7 +225,8 @@ private[ml] object RandomForestClassificationModel {
         def fromOld(
             oldModel: OldRandomForestModel,
             parent: RandomForestClassifier,
      -      categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
      +      categoricalFeatures: Map[Int, Int],
      +      numClasses: Int): RandomForestClassificationModel = {
           require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
             s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
           val newTrees = oldModel.trees.map { tree =>
      @@ -176,6 +234,6 @@ private[ml] object RandomForestClassificationModel {
             DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
           }
           val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
      -    new RandomForestClassificationModel(uid, newTrees)
      +    new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
         }
       }
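
      Since RandomForestClassificationModel now extends ProbabilisticClassificationModel, predictRaw
      returns per-class soft votes (each tree contributes its leaf class distribution, normalized to
      sum to 1), and the probability and thresholds machinery above applies to forests as well. A
      hypothetical usage sketch, assuming DataFrames `trainingData` and `testData` with the default
      "label" and "features" columns:

          import org.apache.spark.ml.classification.RandomForestClassifier

          val rf = new RandomForestClassifier()
            .setNumTrees(20)
            .setThresholds(Array(1.0, 0.2))   // favor class 1 in a binary problem
          val model = rf.fit(trainingData)
          model.transform(testData)
            .select("prediction", "probability", "rawPrediction")
            .show()
          println(model.featureImportances)   // normalized importances, one entry per feature
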
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
      new file mode 100644
      index 000000000000..f40ab71fb22a
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
      @@ -0,0 +1,200 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.clustering
      +
      +import org.apache.spark.annotation.{Experimental, Since}
      +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
      +import org.apache.spark.ml.param.shared._
      +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
      +import org.apache.spark.ml.{Estimator, Model}
      +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
      +import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
      +import org.apache.spark.sql.functions.{col, udf}
      +import org.apache.spark.sql.types.{IntegerType, StructType}
      +import org.apache.spark.sql.{DataFrame, Row}
      +
      +
      +/**
      + * Common params for KMeans and KMeansModel
      + */
      +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
      +  with HasSeed with HasPredictionCol with HasTol {
      +
      +  /**
      +   * Set the number of clusters to create (k). Must be > 1. Default: 2.
      +   * @group param
      +   */
      +  @Since("1.5.0")
      +  final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
      +
      +  /** @group getParam */
      +  @Since("1.5.0")
      +  def getK: Int = $(k)
      +
      +  /**
      +   * Param for the initialization algorithm. This can be either "random" to choose random points as
      +   * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
      +   * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
      +   * @group expertParam
      +   */
      +  @Since("1.5.0")
      +  final val initMode = new Param[String](this, "initMode", "initialization algorithm",
      +    (value: String) => MLlibKMeans.validateInitMode(value))
      +
      +  /** @group expertGetParam */
      +  @Since("1.5.0")
      +  def getInitMode: String = $(initMode)
      +
      +  /**
      +   * Param for the number of steps for the k-means|| initialization mode. This is an advanced
      +   * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5.
      +   * @group expertParam
      +   */
      +  @Since("1.5.0")
      +  final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||",
      +    (value: Int) => value > 0)
      +
      +  /** @group expertGetParam */
      +  @Since("1.5.0")
      +  def getInitSteps: Int = $(initSteps)
      +
      +  /**
      +   * Validates and transforms the input schema.
      +   * @param schema input schema
      +   * @return output schema
      +   */
      +  protected def validateAndTransformSchema(schema: StructType): StructType = {
      +    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
      +    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * Model fitted by KMeans.
      + *
      + * @param parentModel a model trained by spark.mllib.clustering.KMeans.
      + */
      +@Since("1.5.0")
      +@Experimental
      +class KMeansModel private[ml] (
      +    @Since("1.5.0") override val uid: String,
      +    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
      +
      +  @Since("1.5.0")
      +  override def copy(extra: ParamMap): KMeansModel = {
      +    val copied = new KMeansModel(uid, parentModel)
      +    copyValues(copied, extra)
      +  }
      +
      +  @Since("1.5.0")
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    val predictUDF = udf((vector: Vector) => predict(vector))
      +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      +  }
      +
      +  @Since("1.5.0")
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema)
      +  }
      +
      +  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
      +
      +  @Since("1.5.0")
      +  def clusterCenters: Array[Vector] = parentModel.clusterCenters
      +}
      +
      +/**
      + * :: Experimental ::
      + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al.
      + *
      + * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]]
      + */
      +@Since("1.5.0")
      +@Experimental
      +class KMeans @Since("1.5.0") (
      +    @Since("1.5.0") override val uid: String)
      +  extends Estimator[KMeansModel] with KMeansParams {
      +
      +  setDefault(
      +    k -> 2,
      +    maxIter -> 20,
      +    initMode -> MLlibKMeans.K_MEANS_PARALLEL,
      +    initSteps -> 5,
      +    tol -> 1e-4)
      +
      +  @Since("1.5.0")
      +  override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
      +
      +  @Since("1.5.0")
      +  def this() = this(Identifiable.randomUID("kmeans"))
      +
      +  /** @group setParam */
      +  @Since("1.5.0")
      +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
      +
      +  /** @group setParam */
      +  @Since("1.5.0")
      +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
      +
      +  /** @group setParam */
      +  @Since("1.5.0")
      +  def setK(value: Int): this.type = set(k, value)
      +
      +  /** @group expertSetParam */
      +  @Since("1.5.0")
      +  def setInitMode(value: String): this.type = set(initMode, value)
      +
      +  /** @group expertSetParam */
      +  @Since("1.5.0")
      +  def setInitSteps(value: Int): this.type = set(initSteps, value)
      +
      +  /** @group setParam */
      +  @Since("1.5.0")
      +  def setMaxIter(value: Int): this.type = set(maxIter, value)
      +
      +  /** @group setParam */
      +  @Since("1.5.0")
      +  def setTol(value: Double): this.type = set(tol, value)
      +
      +  /** @group setParam */
      +  @Since("1.5.0")
      +  def setSeed(value: Long): this.type = set(seed, value)
      +
      +  @Since("1.5.0")
      +  override def fit(dataset: DataFrame): KMeansModel = {
      +    val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
      +
      +    val algo = new MLlibKMeans()
      +      .setK($(k))
      +      .setInitializationMode($(initMode))
      +      .setInitializationSteps($(initSteps))
      +      .setMaxIterations($(maxIter))
      +      .setSeed($(seed))
      +      .setEpsilon($(tol))
      +    val parentModel = algo.run(rdd)
      +    val model = new KMeansModel(uid, parentModel)
      +    copyValues(model)
      +  }
      +
      +  @Since("1.5.0")
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema)
      +  }
      +}
      +
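
      The new spark.ml KMeans wraps the existing spark.mllib implementation behind the
      Estimator/Model API: fit() selects the features column, runs MLlibKMeans with the configured
      k, init mode, and tolerance, and returns a KMeansModel whose transform() appends an integer
      cluster assignment. A hypothetical usage sketch, assuming a DataFrame `pointsDF` with a Vector
      column named "features":

          import org.apache.spark.ml.clustering.KMeans

          val kmeans = new KMeans()
            .setK(3)
            .setMaxIter(30)
            .setFeaturesCol("features")
            .setPredictionCol("cluster")
          val model = kmeans.fit(pointsDF)
          val clustered = model.transform(pointsDF)   // adds the integer "cluster" column
          model.clusterCenters.foreach(println)
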
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
      index 4a82b77f0edc..08df2919a8a8 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
      @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.DoubleType
       
       /**
        * :: Experimental ::
      - * Evaluator for binary classification, which expects two input columns: score and label.
      + * Evaluator for binary classification, which expects two input columns: rawPrediction and label.
        */
       @Experimental
       class BinaryClassificationEvaluator(override val uid: String)
      @@ -38,10 +38,14 @@ class BinaryClassificationEvaluator(override val uid: String)
       
         /**
          * param for metric name in evaluation
      +   * Default: areaUnderROC
          * @group param
          */
      -  val metricName: Param[String] = new Param(this, "metricName",
      -    "metric name in evaluation (areaUnderROC|areaUnderPR)")
      +  val metricName: Param[String] = {
      +    val allowedParams = ParamValidators.inArray(Array("areaUnderROC", "areaUnderPR"))
      +    new Param(
      +      this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", allowedParams)
      +  }
       
         /** @group getParam */
         def getMetricName: String = $(metricName)
      @@ -50,6 +54,13 @@ class BinaryClassificationEvaluator(override val uid: String)
         def setMetricName(value: String): this.type = set(metricName, value)
       
         /** @group setParam */
      +  def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
      +
      +  /**
      +   * @group setParam
      +   * @deprecated use [[setRawPredictionCol()]] instead
      +   */
      +  @deprecated("use setRawPredictionCol instead", "1.5.0")
         def setScoreCol(value: String): this.type = set(rawPredictionCol, value)
       
         /** @group setParam */
      @@ -69,16 +80,17 @@ class BinaryClassificationEvaluator(override val uid: String)
             }
           val metrics = new BinaryClassificationMetrics(scoreAndLabels)
           val metric = $(metricName) match {
      -      case "areaUnderROC" =>
      -        metrics.areaUnderROC()
      -      case "areaUnderPR" =>
      -        metrics.areaUnderPR()
      -      case other =>
      -        throw new IllegalArgumentException(s"Does not support metric $other.")
      +      case "areaUnderROC" => metrics.areaUnderROC()
      +      case "areaUnderPR" => metrics.areaUnderPR()
           }
           metrics.unpersist()
           metric
         }
       
      +  override def isLargerBetter: Boolean = $(metricName) match {
      +    case "areaUnderROC" => true
      +    case "areaUnderPR" => true
      +  }
      +
         override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
       }
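
      setScoreCol is deprecated in favor of setRawPredictionCol, and the metric name is now checked
      up front by ParamValidators.inArray, which is why the old catch-all IllegalArgumentException
      case could be removed. A hypothetical usage sketch, assuming `predictions` is the output of a
      binary classifier's transform():

          import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

          val evaluator = new BinaryClassificationEvaluator()
            .setRawPredictionCol("rawPrediction")
            .setLabelCol("label")
            .setMetricName("areaUnderPR")   // or "areaUnderROC" (the default)
          val aupr = evaluator.evaluate(predictions)
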
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
      index e56c946a063e..13bd3307f8a2 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
      @@ -46,5 +46,12 @@ abstract class Evaluator extends Params {
          */
         def evaluate(dataset: DataFrame): Double
       
      +  /**
      +   * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)
      +   * or minimized (false).
      +   * A given evaluator may support multiple metrics which may be maximized or minimized.
      +   */
      +  def isLargerBetter: Boolean = true
      +
         override def copy(extra: ParamMap): Evaluator
       }
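
      isLargerBetter lets model-selection code compare metrics without hard-coding sign conventions
      for each metric. A minimal selection sketch, assuming `metrics: Array[Double]` holds one score
      per candidate in `models` and `evaluator` is any Evaluator:

          val bestIndex =
            if (evaluator.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)._2
            else metrics.zipWithIndex.minBy(_._1)._2
          val bestModel = models(bestIndex)
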
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
      new file mode 100644
      index 000000000000..f73d2345078e
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
      @@ -0,0 +1,93 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.evaluation
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
      +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
      +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
      +import org.apache.spark.mllib.evaluation.MulticlassMetrics
      +import org.apache.spark.sql.{DataFrame, Row}
      +import org.apache.spark.sql.types.DoubleType
      +
      +/**
      + * :: Experimental ::
      + * Evaluator for multiclass classification, which expects two input columns: prediction and label.
      + */
      +@Experimental
      +class MulticlassClassificationEvaluator (override val uid: String)
      +  extends Evaluator with HasPredictionCol with HasLabelCol {
      +
      +  def this() = this(Identifiable.randomUID("mcEval"))
      +
      +  /**
      +   * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
      +   * `"weightedPrecision"`, `"weightedRecall"`)
      +   * @group param
      +   */
      +  val metricName: Param[String] = {
      +    val allowedParams = ParamValidators.inArray(Array("f1", "precision",
      +      "recall", "weightedPrecision", "weightedRecall"))
      +    new Param(this, "metricName", "metric name in evaluation " +
      +      "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams)
      +  }
      +
      +  /** @group getParam */
      +  def getMetricName: String = $(metricName)
      +
      +  /** @group setParam */
      +  def setMetricName(value: String): this.type = set(metricName, value)
      +
      +  /** @group setParam */
      +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
      +
      +  /** @group setParam */
      +  def setLabelCol(value: String): this.type = set(labelCol, value)
      +
      +  setDefault(metricName -> "f1")
      +
      +  override def evaluate(dataset: DataFrame): Double = {
      +    val schema = dataset.schema
      +    SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
      +    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
      +
      +    val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
      +      .map { case Row(prediction: Double, label: Double) =>
      +        (prediction, label)
      +      }
      +    val metrics = new MulticlassMetrics(predictionAndLabels)
      +    val metric = $(metricName) match {
      +      case "f1" => metrics.weightedFMeasure
      +      case "precision" => metrics.precision
      +      case "recall" => metrics.recall
      +      case "weightedPrecision" => metrics.weightedPrecision
      +      case "weightedRecall" => metrics.weightedRecall
      +    }
      +    metric
      +  }
      +
      +  override def isLargerBetter: Boolean = $(metricName) match {
      +    case "f1" => true
      +    case "precision" => true
      +    case "recall" => true
      +    case "weightedPrecision" => true
      +    case "weightedRecall" => true
      +  }
      +
      +  override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
      +}
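
      The new multiclass evaluator mirrors the binary one but reads the prediction column directly
      and delegates to mllib's MulticlassMetrics. A hypothetical usage sketch, assuming `predictions`
      has Double "prediction" and "label" columns:

          import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

          val evaluator = new MulticlassClassificationEvaluator()
            .setPredictionCol("prediction")
            .setLabelCol("label")
            .setMetricName("weightedRecall")   // default is "f1"
          val score = evaluator.evaluate(predictions)
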
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
      index 01c000b47514..d21c88ab9b10 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
      @@ -73,17 +73,20 @@ final class RegressionEvaluator(override val uid: String)
             }
           val metrics = new RegressionMetrics(predictionAndLabels)
           val metric = $(metricName) match {
      -      case "rmse" =>
      -        -metrics.rootMeanSquaredError
      -      case "mse" =>
      -        -metrics.meanSquaredError
      -      case "r2" =>
      -        metrics.r2
      -      case "mae" =>
      -        -metrics.meanAbsoluteError
      +      case "rmse" => metrics.rootMeanSquaredError
      +      case "mse" => metrics.meanSquaredError
      +      case "r2" => metrics.r2
      +      case "mae" => metrics.meanAbsoluteError
           }
           metric
         }
       
      +  override def isLargerBetter: Boolean = $(metricName) match {
      +    case "rmse" => false
      +    case "mse" => false
      +    case "r2" => true
      +    case "mae" => false
      +  }
      +
         override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
       }
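
      RegressionEvaluator now returns rmse, mse, and mae as positive values instead of negating them,
      and reports the direction separately through isLargerBetter. A hypothetical check of the new
      behavior, assuming `predictions` holds regression output:

          import org.apache.spark.ml.evaluation.RegressionEvaluator

          val evaluator = new RegressionEvaluator().setMetricName("rmse")
          val rmse = evaluator.evaluate(predictions)   // positive RMSE, no longer negated
          assert(!evaluator.isLargerBetter)            // smaller RMSE is better
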
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
      index 46314854d5e3..edad75443645 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
      @@ -41,6 +41,7 @@ final class Binarizer(override val uid: String)
          * Param for threshold used to binarize continuous features.
          * The features greater than the threshold will be binarized to 1.0.
          * The features equal to or less than the threshold will be binarized to 0.0.
      +   * Default: 0.0
          * @group param
          */
         val threshold: DoubleParam =
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
      index 67e4785bc355..6fdf25b015b0 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
      @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
           }
           val newCol = bucketizer(dataset($(inputCol)))
           val newField = prepOutputField(dataset.schema)
      -    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
      +    dataset.withColumn($(outputCol), newCol, newField.metadata)
         }
       
         private def prepOutputField(schema: StructType): StructField = {
      @@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String)
           SchemaUtils.appendColumn(schema, prepOutputField(schema))
         }
       
      -  override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
      +  override def copy(extra: ParamMap): Bucketizer = {
      +    defaultCopy[Bucketizer](extra).setParent(parent)
      +  }
       }
       
       private[feature] object Bucketizer {
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
      new file mode 100644
      index 000000000000..49028e4b8506
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
      @@ -0,0 +1,235 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.broadcast.Broadcast
      +import org.apache.spark.ml.param._
      +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
      +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
      +import org.apache.spark.ml.{Estimator, Model}
      +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.sql.functions._
      +import org.apache.spark.sql.types._
      +import org.apache.spark.sql.DataFrame
      +import org.apache.spark.util.collection.OpenHashMap
      +
      +/**
      + * Params for [[CountVectorizer]] and [[CountVectorizerModel]].
      + */
      +private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol {
      +
      +  /**
      +   * Max size of the vocabulary.
      +   * CountVectorizer will build a vocabulary that only considers the top
      +   * vocabSize terms ordered by term frequency across the corpus.
      +   *
      +   * Default: 2^18^
      +   * @group param
      +   */
      +  val vocabSize: IntParam =
      +    new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0))
      +
      +  /** @group getParam */
      +  def getVocabSize: Int = $(vocabSize)
      +
      +  /**
      +   * Specifies the minimum number of different documents a term must appear in to be included
      +   * in the vocabulary.
      +   * If this is an integer >= 1, this specifies the number of documents the term must appear in;
      +   * if this is a double in [0,1), then this specifies the fraction of documents.
      +   *
      +   * Default: 1
      +   * @group param
      +   */
      +  val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" +
      +    " different documents a term must appear in to be included in the vocabulary." +
      +    " If this is an integer >= 1, this specifies the number of documents the term must" +
      +    " appear in; if this is a double in [0,1), then this specifies the fraction of documents.",
      +    ParamValidators.gtEq(0.0))
      +
      +  /** @group getParam */
      +  def getMinDF: Double = $(minDF)
      +
      +  /** Validates and transforms the input schema. */
      +  protected def validateAndTransformSchema(schema: StructType): StructType = {
      +    SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
      +    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
      +  }
      +
      +  /**
      +   * Filter to ignore rare words in a document. For each document, terms with
      +   * frequency/count less than the given threshold are ignored.
      +   * If this is an integer >= 1, then this specifies a count (of times the term must appear
      +   * in the document);
      +   * if this is a double in [0,1), then this specifies a fraction (out of the document's token
      +   * count).
      +   *
      +   * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not
      +   * affect fitting.
      +   *
      +   * Default: 1
      +   * @group param
      +   */
      +  val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" +
      +    " a document. For each document, terms with frequency/count less than the given threshold are" +
      +    " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" +
      +    " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" +
      +    " of the document's token count). Note that the parameter is only used in transform of" +
      +    " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0))
      +
      +  setDefault(minTF -> 1)
      +
      +  /** @group getParam */
      +  def getMinTF: Double = $(minTF)
      +}
      +
      +/**
      + * :: Experimental ::
      + * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]].
      + */
      +@Experimental
      +class CountVectorizer(override val uid: String)
      +  extends Estimator[CountVectorizerModel] with CountVectorizerParams {
      +
      +  def this() = this(Identifiable.randomUID("cntVec"))
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /** @group setParam */
      +  def setVocabSize(value: Int): this.type = set(vocabSize, value)
      +
      +  /** @group setParam */
      +  def setMinDF(value: Double): this.type = set(minDF, value)
      +
      +  /** @group setParam */
      +  def setMinTF(value: Double): this.type = set(minTF, value)
      +
      +  setDefault(vocabSize -> (1 << 18), minDF -> 1)
      +
      +  override def fit(dataset: DataFrame): CountVectorizerModel = {
      +    transformSchema(dataset.schema, logging = true)
      +    val vocSize = $(vocabSize)
      +    val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
      +    val minDf = if ($(minDF) >= 1.0) {
      +      $(minDF)
      +    } else {
      +      $(minDF) * input.cache().count()
      +    }
      +    val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
      +      val wc = new OpenHashMap[String, Long]
      +      tokens.foreach { w =>
      +        wc.changeValue(w, 1L, _ + 1L)
      +      }
      +      wc.map { case (word, count) => (word, (count, 1)) }
      +    }.reduceByKey { case ((wc1, df1), (wc2, df2)) =>
      +      (wc1 + wc2, df1 + df2)
      +    }.filter { case (word, (wc, df)) =>
      +      df >= minDf
      +    }.map { case (word, (count, dfCount)) =>
      +      (word, count)
      +    }.cache()
      +    val fullVocabSize = wordCounts.count()
      +    val vocab: Array[String] = {
      +      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
      +        // Use all terms
      +        wordCounts.collect().sortBy(-_._2)
      +      } else {
      +        // Sort terms to select vocab
      +        wordCounts.sortBy(_._2, ascending = false).take(vocSize)
      +      }
      +      tmpSortedWC.map(_._1)
      +    }
      +
      +    require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
      +    copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema)
      +  }
      +
      +  override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
      +}
      +
      +/**
      + * :: Experimental ::
      + * Converts a text document to a sparse vector of token counts.
      + * @param vocabulary An array of terms. Only the terms in the vocabulary will be counted.
      + */
      +@Experimental
      +class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
      +  extends Model[CountVectorizerModel] with CountVectorizerParams {
      +
      +  def this(vocabulary: Array[String]) = {
      +    this(Identifiable.randomUID("cntVecModel"), vocabulary)
      +    set(vocabSize, vocabulary.length)
      +  }
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /** @group setParam */
      +  def setMinTF(value: Double): this.type = set(minTF, value)
      +
      +  /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
      +  private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    if (broadcastDict.isEmpty) {
      +      val dict = vocabulary.zipWithIndex.toMap
      +      broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
      +    }
      +    val dictBr = broadcastDict.get
      +    val minTf = $(minTF)
      +    val vectorizer = udf { (document: Seq[String]) =>
      +      val termCounts = new OpenHashMap[Int, Double]
      +      var tokenCount = 0L
      +      document.foreach { term =>
      +        dictBr.value.get(term) match {
      +          case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0)
      +          case None => // ignore terms not in the vocabulary
      +        }
      +        tokenCount += 1
      +      }
      +      val effectiveMinTF = if (minTf >= 1.0) {
      +        minTf
      +      } else {
      +        tokenCount * minTf
      +      }
      +      Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq)
      +    }
      +    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema)
      +  }
      +
      +  override def copy(extra: ParamMap): CountVectorizerModel = {
      +    val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
      +    copyValues(copied, extra)
      +  }
      +}
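
      CountVectorizer builds its vocabulary from document frequencies: minDF filters out terms that
      appear in too few documents, vocabSize caps the vocabulary at the most frequent terms, and
      minTF only filters terms within each document at transform time. A hypothetical usage sketch,
      assuming a DataFrame `docs` with a tokenized Seq[String] column named "words":

          import org.apache.spark.ml.feature.CountVectorizer

          val cv = new CountVectorizer()
            .setInputCol("words")
            .setOutputCol("features")
            .setVocabSize(10000)
            .setMinDF(2)   // a term must occur in at least two documents
          val cvModel = cv.fit(docs)
          println(cvModel.vocabulary.take(10).mkString(", "))
          val vectorized = cvModel.setMinTF(1).transform(docs)   // sparse term-count vectors
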
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
      new file mode 100644
      index 000000000000..228347635c92
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
      @@ -0,0 +1,72 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import edu.emory.mathcs.jtransforms.dct._
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.UnaryTransformer
      +import org.apache.spark.ml.param.BooleanParam
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
      +import org.apache.spark.sql.types.DataType
      +
      +/**
      + * :: Experimental ::
      + * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero
      + * padding is performed on the input vector.
      + * It returns a real vector of the same length representing the DCT. The return vector is scaled
      + * such that the transform matrix is unitary (aka scaled DCT-II).
      + *
      + * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]].
      + */
      +@Experimental
      +class DCT(override val uid: String)
      +  extends UnaryTransformer[Vector, Vector, DCT] {
      +
      +  def this() = this(Identifiable.randomUID("dct"))
      +
      +  /**
      +   * Indicates whether to perform the inverse DCT (true) or forward DCT (false).
      +   * Default: false
      +   * @group param
      +   */
      +  val inverse: BooleanParam = new BooleanParam(
      +    this, "inverse", "Set transformer to perform inverse DCT")
      +
      +  /** @group setParam */
      +  def setInverse(value: Boolean): this.type = set(inverse, value)
      +
      +  /** @group getParam */
      +  def getInverse: Boolean = $(inverse)
      +
      +  setDefault(inverse -> false)
      +
      +  override protected def createTransformFunc: Vector => Vector = { vec =>
      +    val result = vec.toArray
      +    val jTransformer = new DoubleDCT_1D(result.length)
      +    if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
      +    Vectors.dense(result)
      +  }
      +
      +  override protected def validateInputType(inputType: DataType): Unit = {
      +    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
      +  }
      +
      +  override protected def outputDataType: DataType = new VectorUDT
      +}
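
      DCT is a UnaryTransformer that applies JTransforms' scaled DCT-II (or its inverse) to each
      input vector, producing a vector of the same length. A hypothetical usage sketch, assuming a
      DataFrame `vectorsDF` with a Vector column named "features":

          import org.apache.spark.ml.feature.DCT

          val dct = new DCT()
            .setInputCol("features")
            .setOutputCol("featuresDCT")
            .setInverse(false)   // forward transform; setting true applies the inverse DCT
          val transformed = dct.transform(vectorsDF)
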
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
      index ecde80810580..4c36df75d8aa 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
      @@ -35,6 +35,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
       
         /**
          * The minimum number of documents in which a term should appear.
      +   * Default: 0
          * @group param
          */
         final val minDocFreq = new IntParam(
      @@ -114,6 +115,6 @@ class IDFModel private[ml] (
       
         override def copy(extra: ParamMap): IDFModel = {
           val copied = new IDFModel(uid, idfModel)
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
         }
       }
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
      new file mode 100644
      index 000000000000..1b494ec8b172
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
      @@ -0,0 +1,178 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
      +import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.ml.{Estimator, Model}
      +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
      +import org.apache.spark.mllib.stat.Statistics
      +import org.apache.spark.sql._
      +import org.apache.spark.sql.functions._
      +import org.apache.spark.sql.types.{StructField, StructType}
      +
      +/**
      + * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]].
      + */
      +private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol {
      +
      +  /**
      +   * lower bound after transformation, shared by all features
      +   * Default: 0.0
      +   * @group param
      +   */
      +  val min: DoubleParam = new DoubleParam(this, "min",
      +    "lower bound of the output feature range")
      +
      +  /** @group getParam */
      +  def getMin: Double = $(min)
      +
      +  /**
      +   * upper bound after transformation, shared by all features
      +   * Default: 1.0
      +   * @group param
      +   */
      +  val max: DoubleParam = new DoubleParam(this, "max",
      +    "upper bound of the output feature range")
      +
      +  /** @group getParam */
      +  def getMax: Double = $(max)
      +
      +  /** Validates and transforms the input schema. */
      +  protected def validateAndTransformSchema(schema: StructType): StructType = {
      +    val inputType = schema($(inputCol)).dataType
      +    require(inputType.isInstanceOf[VectorUDT],
      +      s"Input column ${$(inputCol)} must be a vector column")
      +    require(!schema.fieldNames.contains($(outputCol)),
      +      s"Output column ${$(outputCol)} already exists.")
      +    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
      +    StructType(outputFields)
      +  }
      +
      +  override def validateParams(): Unit = {
       +    require($(min) < $(max), s"The specified min(${$(min)}) is larger than or equal to max(${$(max)})")
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * Rescale each feature individually to a common range [min, max] linearly using column summary
       + * statistics, which is also known as min-max normalization or rescaling. The rescaled value
       + * for feature E is calculated as
       + *
       + * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min
       + *
       + * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min).
       + * Note that since zero values will probably be transformed to non-zero values, the output of
       + * the transformer will be a DenseVector even for sparse input.
      + */
      +@Experimental
      +class MinMaxScaler(override val uid: String)
      +  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
      +
      +  def this() = this(Identifiable.randomUID("minMaxScal"))
      +
      +  setDefault(min -> 0.0, max -> 1.0)
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /** @group setParam */
      +  def setMin(value: Double): this.type = set(min, value)
      +
      +  /** @group setParam */
      +  def setMax(value: Double): this.type = set(max, value)
      +
      +  override def fit(dataset: DataFrame): MinMaxScalerModel = {
      +    transformSchema(dataset.schema, logging = true)
      +    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
      +    val summary = Statistics.colStats(input)
      +    copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema)
      +  }
      +
      +  override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
      +}
      +
      +/**
      + * :: Experimental ::
      + * Model fitted by [[MinMaxScaler]].
      + *
      + * @param originalMin min value for each original column during fitting
      + * @param originalMax max value for each original column during fitting
      + *
      + * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529).
      + */
      +@Experimental
      +class MinMaxScalerModel private[ml] (
      +    override val uid: String,
      +    val originalMin: Vector,
      +    val originalMax: Vector)
      +  extends Model[MinMaxScalerModel] with MinMaxScalerParams {
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /** @group setParam */
      +  def setMin(value: Double): this.type = set(min, value)
      +
      +  /** @group setParam */
      +  def setMax(value: Double): this.type = set(max, value)
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
      +    val minArray = originalMin.toArray
      +
      +    val reScale = udf { (vector: Vector) =>
      +      val scale = $(max) - $(min)
      +
       +      // 0 in a sparse vector will probably be rescaled to a non-zero value
      +      val values = vector.toArray
      +      val size = values.size
      +      var i = 0
      +      while (i < size) {
      +        val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
      +        values(i) = raw * scale + $(min)
      +        i += 1
      +      }
      +      Vectors.dense(values)
      +    }
      +
      +    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema)
      +  }
      +
      +  override def copy(extra: ParamMap): MinMaxScalerModel = {
      +    val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
      +    copyValues(copied, extra).setParent(parent)
      +  }
      +}
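
// Illustrative usage sketch for the new MinMaxScaler (not part of the patch itself).
// It assumes a spark-shell style environment where `sqlContext` is in scope; the data
// and column names are hypothetical.
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -1.0)),
  (1, Vectors.dense(2.0, 1.1, 1.0)),
  (2, Vectors.dense(3.0, 10.1, 3.0))
)).toDF("id", "features")

val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setMin(0.0)
  .setMax(1.0)

// fit() computes the per-feature original min/max; transform() applies the
// Rescaled(e_i) formula above and always produces dense output vectors.
val model = scaler.fit(df)
model.transform(df).show()
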
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
      new file mode 100644
      index 000000000000..8de10eb51f92
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
      @@ -0,0 +1,69 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.UnaryTransformer
      +import org.apache.spark.ml.param._
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
      +
      +/**
      + * :: Experimental ::
      + * A feature transformer that converts the input array of strings into an array of n-grams. Null
      + * values in the input array are ignored.
      + * It returns an array of n-grams where each n-gram is represented by a space-separated string of
      + * words.
      + *
      + * When the input is empty, an empty array is returned.
      + * When the input array length is less than n (number of elements per n-gram), no n-grams are
      + * returned.
      + */
      +@Experimental
      +class NGram(override val uid: String)
      +  extends UnaryTransformer[Seq[String], Seq[String], NGram] {
      +
      +  def this() = this(Identifiable.randomUID("ngram"))
      +
      +  /**
       +   * Number of elements per n-gram (>= 1).
       +   * Default: 2 (bigram features)
       +   * @group param
       +   */
       +  val n: IntParam = new IntParam(this, "n", "number of elements per n-gram (>= 1)",
      +    ParamValidators.gtEq(1))
      +
      +  /** @group setParam */
      +  def setN(value: Int): this.type = set(n, value)
      +
      +  /** @group getParam */
      +  def getN: Int = $(n)
      +
      +  setDefault(n -> 2)
      +
      +  override protected def createTransformFunc: Seq[String] => Seq[String] = {
      +    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
      +  }
      +
      +  override protected def validateInputType(inputType: DataType): Unit = {
      +    require(inputType.sameType(ArrayType(StringType)),
      +      s"Input type must be ArrayType(StringType) but got $inputType.")
      +  }
      +
      +  override protected def outputDataType: DataType = new ArrayType(StringType, false)
      +}
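
// Illustrative usage sketch for the new NGram transformer (not part of the patch itself).
// Assumes `sqlContext` is in scope; the sentences are hypothetical.
import org.apache.spark.ml.feature.NGram

val wordsDF = sqlContext.createDataFrame(Seq(
  (0, Seq("Hi", "I", "heard", "about", "Spark")),
  (1, Seq("Logistic", "regression", "models", "are", "neat"))
)).toDF("id", "words")

val ngram = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams")

// Each output element is a space-separated bigram, e.g. "Hi I", "I heard", ...
// Input arrays shorter than n produce an empty output array.
ngram.transform(wordsDF).select("ngrams").show()
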
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
      index 382594279564..9c60d4084ec4 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
      @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
         def setOutputCol(value: String): this.type = set(outputCol, value)
       
         override def transformSchema(schema: StructType): StructType = {
      -    val is = "_is_"
           val inputColName = $(inputCol)
           val outputColName = $(outputCol)
       
      @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
           val outputAttrNames: Option[Array[String]] = inputAttr match {
             case nominal: NominalAttribute =>
               if (nominal.values.isDefined) {
      -          nominal.values.map(_.map(v => inputColName + is + v))
      +          nominal.values
               } else if (nominal.numValues.isDefined) {
      -          nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
      +          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
               } else {
                 None
               }
             case binary: BinaryAttribute =>
               if (binary.values.isDefined) {
      -          binary.values.map(_.map(v => inputColName + is + v))
      +          binary.values
               } else {
      -          Some(Array.tabulate(2)(i => inputColName + is + i))
      +          Some(Array.tabulate(2)(_.toString))
               }
             case _: NumericAttribute =>
               throw new RuntimeException(
      @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
       
         override def transform(dataset: DataFrame): DataFrame = {
           // schema transformation
      -    val is = "_is_"
           val inputColName = $(inputCol)
           val outputColName = $(outputCol)
           val shouldDropLast = $(dropLast)
      @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
                   math.max(m0, m1)
                 }
               ).toInt + 1
      -      val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
      +      val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
             val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
             val outputAttrs: Array[Attribute] =
               filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
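
// Illustrative sketch of the attribute-name change above (not part of the patch itself).
// Assumes `sqlContext` is in scope; data and column names are hypothetical.
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

val indexed = new StringIndexer()
  .setInputCol("category").setOutputCol("categoryIndex")
  .fit(df).transform(df)

// With this change the output column's ML attribute names are the raw category values
// ("a", "b", ...) instead of the old "category_is_a", "category_is_b", ... convention.
val encoded = new OneHotEncoder()
  .setInputCol("categoryIndex").setOutputCol("categoryVec")
  .transform(indexed)
encoded.show()
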
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
      new file mode 100644
      index 000000000000..539084704b65
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
      @@ -0,0 +1,130 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml._
      +import org.apache.spark.ml.param._
      +import org.apache.spark.ml.param.shared._
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.mllib.feature
      +import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
      +import org.apache.spark.sql._
      +import org.apache.spark.sql.functions._
      +import org.apache.spark.sql.types.{StructField, StructType}
      +
      +/**
      + * Params for [[PCA]] and [[PCAModel]].
      + */
      +private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol {
      +
      +  /**
      +   * The number of principal components.
      +   * @group param
      +   */
      +  final val k: IntParam = new IntParam(this, "k", "the number of principal components")
      +
      +  /** @group getParam */
      +  def getK: Int = $(k)
      +
      +}
      +
      +/**
      + * :: Experimental ::
       + * PCA trains a model to project vectors to a lower-dimensional space using principal component analysis.
      + */
      +@Experimental
      +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
      +
      +  def this() = this(Identifiable.randomUID("pca"))
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /** @group setParam */
      +  def setK(value: Int): this.type = set(k, value)
      +
      +  /**
      +   * Computes a [[PCAModel]] that contains the principal components of the input vectors.
      +   */
      +  override def fit(dataset: DataFrame): PCAModel = {
      +    transformSchema(dataset.schema, logging = true)
      +    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
      +    val pca = new feature.PCA(k = $(k))
      +    val pcaModel = pca.fit(input)
      +    copyValues(new PCAModel(uid, pcaModel).setParent(this))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    val inputType = schema($(inputCol)).dataType
      +    require(inputType.isInstanceOf[VectorUDT],
      +      s"Input column ${$(inputCol)} must be a vector column")
      +    require(!schema.fieldNames.contains($(outputCol)),
      +      s"Output column ${$(outputCol)} already exists.")
      +    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
      +    StructType(outputFields)
      +  }
      +
      +  override def copy(extra: ParamMap): PCA = defaultCopy(extra)
      +}
      +
      +/**
      + * :: Experimental ::
      + * Model fitted by [[PCA]].
      + */
      +@Experimental
      +class PCAModel private[ml] (
      +    override val uid: String,
      +    pcaModel: feature.PCAModel)
      +  extends Model[PCAModel] with PCAParams {
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /**
       +   * Transform a vector by the computed principal components.
      +   * NOTE: Vectors to be transformed must be the same length
      +   * as the source vectors given to [[PCA.fit()]].
      +   */
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    transformSchema(dataset.schema, logging = true)
      +    val pcaOp = udf { pcaModel.transform _ }
      +    dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    val inputType = schema($(inputCol)).dataType
      +    require(inputType.isInstanceOf[VectorUDT],
      +      s"Input column ${$(inputCol)} must be a vector column")
      +    require(!schema.fieldNames.contains($(outputCol)),
      +      s"Output column ${$(outputCol)} already exists.")
      +    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
      +    StructType(outputFields)
      +  }
      +
      +  override def copy(extra: ParamMap): PCAModel = {
      +    val copied = new PCAModel(uid, pcaModel)
      +    copyValues(copied, extra).setParent(parent)
      +  }
      +}
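
// Illustrative usage sketch for the new ml.feature.PCA wrapper (not part of the patch itself).
// Assumes `sqlContext` is in scope; the data are hypothetical.
import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

val data = Seq(
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)
val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")

val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)

// fit() delegates to mllib.feature.PCA; transform() projects each input vector
// onto the top-k principal components.
val model = pca.fit(df)
model.transform(df).select("pcaFeatures").show()
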
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
      new file mode 100644
      index 000000000000..dcd6fe3c406a
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
      @@ -0,0 +1,220 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import scala.collection.mutable
      +import scala.collection.mutable.ArrayBuffer
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
      +import org.apache.spark.ml.param.{Param, ParamMap}
      +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.mllib.linalg.VectorUDT
      +import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.types._
      +
      +/**
      + * Base trait for [[RFormula]] and [[RFormulaModel]].
      + */
      +private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
      +
      +  protected def hasLabelCol(schema: StructType): Boolean = {
      +    schema.map(_.name).contains($(labelCol))
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * Implements the transforms required for fitting a dataset against an R model formula. Currently
      + * we support a limited subset of the R operators, including '.', '~', '+', and '-'. Also see the
      + * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
      + */
      +@Experimental
      +class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
      +
      +  def this() = this(Identifiable.randomUID("rFormula"))
      +
      +  /**
      +   * R formula parameter. The formula is provided in string form.
      +   * @group param
      +   */
      +  val formula: Param[String] = new Param(this, "formula", "R model formula")
      +
      +  /**
      +   * Sets the formula to use for this transformer. Must be called before use.
      +   * @group setParam
      +   * @param value an R formula in string form (e.g. "y ~ x + z")
      +   */
      +  def setFormula(value: String): this.type = set(formula, value)
      +
      +  /** @group getParam */
      +  def getFormula: String = $(formula)
      +
      +  /** @group setParam */
      +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
      +
      +  /** @group setParam */
      +  def setLabelCol(value: String): this.type = set(labelCol, value)
      +
      +  /** Whether the formula specifies fitting an intercept. */
      +  private[ml] def hasIntercept: Boolean = {
      +    require(isDefined(formula), "Formula must be defined first.")
      +    RFormulaParser.parse($(formula)).hasIntercept
      +  }
      +
      +  override def fit(dataset: DataFrame): RFormulaModel = {
      +    require(isDefined(formula), "Formula must be defined first.")
      +    val parsedFormula = RFormulaParser.parse($(formula))
      +    val resolvedFormula = parsedFormula.resolve(dataset.schema)
      +    // StringType terms and terms representing interactions need to be encoded before assembly.
      +    // TODO(ekl) add support for feature interactions
      +    val encoderStages = ArrayBuffer[PipelineStage]()
      +    val tempColumns = ArrayBuffer[String]()
      +    val takenNames = mutable.Set(dataset.columns: _*)
      +    val encodedTerms = resolvedFormula.terms.map { term =>
      +      dataset.schema(term) match {
      +        case column if column.dataType == StringType =>
      +          val indexCol = term + "_idx_" + uid
      +          val encodedCol = {
      +            var tmp = term
      +            while (takenNames.contains(tmp)) {
      +              tmp += "_"
      +            }
      +            tmp
      +          }
      +          takenNames.add(indexCol)
      +          takenNames.add(encodedCol)
      +          encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
      +          encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
      +          tempColumns += indexCol
      +          tempColumns += encodedCol
      +          encodedCol
      +        case _ =>
      +          term
      +      }
      +    }
      +    encoderStages += new VectorAssembler(uid)
      +      .setInputCols(encodedTerms.toArray)
      +      .setOutputCol($(featuresCol))
      +    encoderStages += new ColumnPruner(tempColumns.toSet)
      +    val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
      +    copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
      +  }
      +
      +  // optimistic schema; does not contain any ML attributes
      +  override def transformSchema(schema: StructType): StructType = {
      +    if (hasLabelCol(schema)) {
      +      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
      +    } else {
      +      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
      +        StructField($(labelCol), DoubleType, true))
      +    }
      +  }
      +
      +  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
      +
      +  override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
      +}
      +
      +/**
      + * :: Experimental ::
      + * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
      + * @param resolvedFormula the fitted R formula.
      + * @param pipelineModel the fitted feature model, including factor to index mappings.
      + */
      +@Experimental
      +class RFormulaModel private[feature](
      +    override val uid: String,
      +    resolvedFormula: ResolvedRFormula,
      +    pipelineModel: PipelineModel)
      +  extends Model[RFormulaModel] with RFormulaBase {
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    checkCanTransform(dataset.schema)
      +    transformLabel(pipelineModel.transform(dataset))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    checkCanTransform(schema)
      +    val withFeatures = pipelineModel.transformSchema(schema)
      +    if (hasLabelCol(schema)) {
      +      withFeatures
      +    } else if (schema.exists(_.name == resolvedFormula.label)) {
      +      val nullable = schema(resolvedFormula.label).dataType match {
      +        case _: NumericType | BooleanType => false
      +        case _ => true
      +      }
      +      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
      +    } else {
      +      // Ignore the label field. This is a hack so that this transformer can also work on test
      +      // datasets in a Pipeline.
      +      withFeatures
      +    }
      +  }
      +
      +  override def copy(extra: ParamMap): RFormulaModel = copyValues(
      +    new RFormulaModel(uid, resolvedFormula, pipelineModel))
      +
      +  override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)"
      +
      +  private def transformLabel(dataset: DataFrame): DataFrame = {
      +    val labelName = resolvedFormula.label
      +    if (hasLabelCol(dataset.schema)) {
      +      dataset
      +    } else if (dataset.schema.exists(_.name == labelName)) {
      +      dataset.schema(labelName).dataType match {
      +        case _: NumericType | BooleanType =>
      +          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
      +        case other =>
      +          throw new IllegalArgumentException("Unsupported type for label: " + other)
      +      }
      +    } else {
      +      // Ignore the label field. This is a hack so that this transformer can also work on test
      +      // datasets in a Pipeline.
      +      dataset
      +    }
      +  }
      +
      +  private def checkCanTransform(schema: StructType) {
      +    val columnNames = schema.map(_.name)
      +    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
      +    require(
      +      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
      +      "Label column already exists and is not of type DoubleType.")
      +  }
      +}
      +
      +/**
      + * Utility transformer for removing temporary columns from a DataFrame.
      + * TODO(ekl) make this a public transformer
      + */
      +private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
      +  override val uid = Identifiable.randomUID("columnPruner")
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
      +    dataset.select(columnsToKeep.map(dataset.col) : _*)
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
      +  }
      +
      +  override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
      +}
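
// Illustrative usage sketch for the new RFormula estimator (not part of the patch itself).
// Assumes `sqlContext` is in scope; data, formula and column names are hypothetical.
import org.apache.spark.ml.feature.RFormula

val df = sqlContext.createDataFrame(Seq(
  (7, "US", 18, 1.0),
  (8, "CA", 12, 0.0),
  (9, "NZ", 15, 0.0)
)).toDF("id", "country", "hour", "clicked")

val formula = new RFormula()
  .setFormula("clicked ~ country + hour")
  .setFeaturesCol("features")
  .setLabelCol("label")

// String terms ("country") are indexed and one-hot encoded, numeric terms ("hour")
// pass through unchanged, and everything is assembled into a single features vector;
// the label column is the "clicked" column cast to DoubleType.
formula.fit(df).transform(df).select("features", "label").show()
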
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
      new file mode 100644
      index 000000000000..1ca3b92a7d92
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
      @@ -0,0 +1,129 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import scala.util.parsing.combinator.RegexParsers
      +
      +import org.apache.spark.mllib.linalg.VectorUDT
      +import org.apache.spark.sql.types._
      +
      +/**
      + * Represents a parsed R formula.
      + */
      +private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
      +  /**
      +   * Resolves formula terms into column names. A schema is necessary for inferring the meaning
      +   * of the special '.' term. Duplicate terms will be removed during resolution.
      +   */
      +  def resolve(schema: StructType): ResolvedRFormula = {
      +    var includedTerms = Seq[String]()
      +    terms.foreach {
      +      case Dot =>
      +        includedTerms ++= simpleTypes(schema).filter(_ != label.value)
      +      case ColumnRef(value) =>
      +        includedTerms :+= value
      +      case Deletion(term: Term) =>
      +        term match {
      +          case ColumnRef(value) =>
      +            includedTerms = includedTerms.filter(_ != value)
      +          case Dot =>
      +            // e.g. "- .", which removes all first-order terms
      +            val fromSchema = simpleTypes(schema)
      +            includedTerms = includedTerms.filter(fromSchema.contains(_))
      +          case _: Deletion =>
      +            assert(false, "Deletion terms cannot be nested")
      +          case _: Intercept =>
      +        }
      +      case _: Intercept =>
      +    }
      +    ResolvedRFormula(label.value, includedTerms.distinct)
      +  }
      +
      +  /** Whether this formula specifies fitting with an intercept term. */
      +  def hasIntercept: Boolean = {
      +    var intercept = true
      +    terms.foreach {
      +      case Intercept(enabled) =>
      +        intercept = enabled
      +      case Deletion(Intercept(enabled)) =>
      +        intercept = !enabled
      +      case _ =>
      +    }
      +    intercept
      +  }
      +
      +  // the dot operator excludes complex column types
      +  private def simpleTypes(schema: StructType): Seq[String] = {
      +    schema.fields.filter(_.dataType match {
      +      case _: NumericType | StringType | BooleanType | _: VectorUDT => true
      +      case _ => false
      +    }).map(_.name)
      +  }
      +}
      +
      +/**
      + * Represents a fully evaluated and simplified R formula.
      + */
      +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
      +
      +/**
      + * R formula terms. See the R formula docs here for more information:
      + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
      + */
      +private[ml] sealed trait Term
      +
      +/* R formula reference to all available columns, e.g. "." in a formula */
      +private[ml] case object Dot extends Term
      +
      +/* R formula reference to a column, e.g. "+ Species" in a formula */
      +private[ml] case class ColumnRef(value: String) extends Term
      +
      +/* R formula intercept toggle, e.g. "+ 0" in a formula */
      +private[ml] case class Intercept(enabled: Boolean) extends Term
      +
      +/* R formula deletion of a variable, e.g. "- Species" in a formula */
      +private[ml] case class Deletion(term: Term) extends Term
      +
      +/**
      + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'.
      + */
      +private[ml] object RFormulaParser extends RegexParsers {
      +  def intercept: Parser[Intercept] =
      +    "([01])".r ^^ { case a => Intercept(a == "1") }
      +
      +  def columnRef: Parser[ColumnRef] =
      +    "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
      +
      +  def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
      +
      +  def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
      +    case op ~ list => list.foldLeft(List(op)) {
      +      case (left, "+" ~ right) => left ++ Seq(right)
      +      case (left, "-" ~ right) => left ++ Seq(Deletion(right))
      +    }
      +  }
      +
      +  def formula: Parser[ParsedRFormula] =
      +    (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
      +
      +  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
      +    case Success(result, _) => result
      +    case failure: NoSuccess => throw new IllegalArgumentException(
      +      "Could not parse formula: " + value)
      +  }
      +}
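
// Illustrative sketch of the parser's behaviour (not part of the patch itself). RFormulaParser
// is private[ml], so this only compiles from inside org.apache.spark.ml.feature; it is shown
// purely to clarify how '.', '+' and '-' resolve against a schema.
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("y", DoubleType),
  StructField("a", DoubleType),
  StructField("b", StringType),
  StructField("c", DoubleType)))

val parsed = RFormulaParser.parse("y ~ . - c")
parsed.hasIntercept      // true: no intercept toggle ("+ 0" / "- 1") appears in the formula
parsed.resolve(schema)   // ResolvedRFormula("y", Seq("a", "b")): '.' expands to every simple
                         // column except the label, then "- c" deletes "c"
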
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
      new file mode 100644
      index 000000000000..95e430563873
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
      @@ -0,0 +1,72 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkContext
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.param.{ParamMap, Param}
      +import org.apache.spark.ml.Transformer
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.sql.{SQLContext, DataFrame, Row}
      +import org.apache.spark.sql.types.StructType
      +
      +/**
      + * :: Experimental ::
       + * Implements the transformations that are defined by a SQL statement.
      + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__'
      + * where '__THIS__' represents the underlying table of the input dataset.
      + */
      +@Experimental
      +class SQLTransformer (override val uid: String) extends Transformer {
      +
      +  def this() = this(Identifiable.randomUID("sql"))
      +
      +  /**
      +   * SQL statement parameter. The statement is provided in string form.
      +   * @group param
      +   */
      +  final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")
      +
      +  /** @group setParam */
      +  def setStatement(value: String): this.type = set(statement, value)
      +
      +  /** @group getParam */
      +  def getStatement: String = $(statement)
      +
      +  private val tableIdentifier: String = "__THIS__"
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    val tableName = Identifiable.randomUID(uid)
      +    dataset.registerTempTable(tableName)
      +    val realStatement = $(statement).replace(tableIdentifier, tableName)
      +    val outputDF = dataset.sqlContext.sql(realStatement)
      +    outputDF
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    val sc = SparkContext.getOrCreate()
      +    val sqlContext = SQLContext.getOrCreate(sc)
      +    val dummyRDD = sc.parallelize(Seq(Row.empty))
      +    val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
      +    dummyDF.registerTempTable(tableIdentifier)
      +    val outputSchema = sqlContext.sql($(statement)).schema
      +    outputSchema
      +  }
      +
      +  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
      +}
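
// Illustrative usage sketch for the new SQLTransformer (not part of the patch itself).
// Assumes `sqlContext` is in scope; data and the statement are hypothetical.
import org.apache.spark.ml.feature.SQLTransformer

val df = sqlContext.createDataFrame(Seq(
  (0, 1.0, 3.0),
  (2, 2.0, 5.0)
)).toDF("id", "v1", "v2")

val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

// "__THIS__" is replaced by a temporary table registered for the input DataFrame.
sqlTrans.transform(df).show()
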
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
      index ca3c1cfb56b7..f6d0b0c0e9e7 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
      @@ -106,6 +106,12 @@ class StandardScalerModel private[ml] (
           scaler: feature.StandardScalerModel)
         extends Model[StandardScalerModel] with StandardScalerParams {
       
      +  /** Standard deviation of the StandardScalerModel */
      +  val std: Vector = scaler.std
      +
      +  /** Mean of the StandardScalerModel */
      +  val mean: Vector = scaler.mean
      +
         /** @group setParam */
         def setInputCol(value: String): this.type = set(inputCol, value)
       
      @@ -130,6 +136,6 @@ class StandardScalerModel private[ml] (
       
         override def copy(extra: ParamMap): StandardScalerModel = {
           val copied = new StandardScalerModel(uid, scaler)
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
         }
       }
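
// Illustrative sketch for the newly exposed std/mean members (not part of the patch itself).
// Assumes `sqlContext` is in scope; data are hypothetical.
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 10.0)),
  (1, Vectors.dense(2.0, 20.0)),
  (2, Vectors.dense(3.0, 30.0))
)).toDF("id", "features")

val model = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .setWithMean(true)
  .setWithStd(true)
  .fit(df)

// The fitted column statistics are now accessible directly on the model.
println(model.mean)  // per-feature mean
println(model.std)   // per-feature standard deviation
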
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
      new file mode 100644
      index 000000000000..2a79582625e9
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
      @@ -0,0 +1,157 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.Transformer
      +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
      +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.functions.{col, udf}
      +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
      +
      +/**
       + * Stop words list.
      + */
      +private[spark] object StopWords {
      +
      +  /**
      +   * Use the same default stopwords list as scikit-learn.
       +   * The original list can be found at the "Glasgow Information Retrieval Group"
      +   * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
      +   */
      +  val English = Array( "a", "about", "above", "across", "after", "afterwards", "again",
      +    "against", "all", "almost", "alone", "along", "already", "also", "although", "always",
      +    "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
      +    "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
      +    "around", "as", "at", "back", "be", "became", "because", "become",
      +    "becomes", "becoming", "been", "before", "beforehand", "behind", "being",
      +    "below", "beside", "besides", "between", "beyond", "bill", "both",
      +    "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con",
      +    "could", "couldnt", "cry", "de", "describe", "detail", "do", "done",
      +    "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else",
      +    "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone",
      +    "everything", "everywhere", "except", "few", "fifteen", "fify", "fill",
      +    "find", "fire", "first", "five", "for", "former", "formerly", "forty",
      +    "found", "four", "from", "front", "full", "further", "get", "give", "go",
      +    "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter",
      +    "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his",
      +    "how", "however", "hundred", "i", "ie", "if", "in", "inc", "indeed",
      +    "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter",
      +    "latterly", "least", "less", "ltd", "made", "many", "may", "me",
      +    "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly",
      +    "move", "much", "must", "my", "myself", "name", "namely", "neither",
      +    "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone",
      +    "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on",
      +    "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our",
      +    "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps",
      +    "please", "put", "rather", "re", "same", "see", "seem", "seemed",
      +    "seeming", "seems", "serious", "several", "she", "should", "show", "side",
      +    "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone",
      +    "something", "sometime", "sometimes", "somewhere", "still", "such",
      +    "system", "take", "ten", "than", "that", "the", "their", "them",
      +    "themselves", "then", "thence", "there", "thereafter", "thereby",
      +    "therefore", "therein", "thereupon", "these", "they", "thick", "thin",
      +    "third", "this", "those", "though", "three", "through", "throughout",
      +    "thru", "thus", "to", "together", "too", "top", "toward", "towards",
      +    "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us",
      +    "very", "via", "was", "we", "well", "were", "what", "whatever", "when",
      +    "whence", "whenever", "where", "whereafter", "whereas", "whereby",
      +    "wherein", "whereupon", "wherever", "whether", "which", "while", "whither",
      +    "who", "whoever", "whole", "whom", "whose", "why", "will", "with",
      +    "within", "without", "would", "yet", "you", "your", "yours", "yourself", "yourselves")
      +}
      +
      +/**
      + * :: Experimental ::
      + * A feature transformer that filters out stop words from input.
       + * Note: null values in the input array are preserved unless null is explicitly added to stopWords.
      + * @see [[http://en.wikipedia.org/wiki/Stop_words]]
      + */
      +@Experimental
      +class StopWordsRemover(override val uid: String)
      +  extends Transformer with HasInputCol with HasOutputCol {
      +
      +  def this() = this(Identifiable.randomUID("stopWords"))
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /**
       +   * The set of stop words to be filtered out.
      +   * Default: [[StopWords.English]]
      +   * @group param
      +   */
      +  val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words")
      +
      +  /** @group setParam */
      +  def setStopWords(value: Array[String]): this.type = set(stopWords, value)
      +
      +  /** @group getParam */
      +  def getStopWords: Array[String] = $(stopWords)
      +
      +  /**
       +   * Whether to do a case-sensitive comparison over the stop words.
      +   * Default: false
      +   * @group param
      +   */
      +  val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive",
      +    "whether to do case-sensitive comparison during filtering")
      +
      +  /** @group setParam */
      +  def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value)
      +
      +  /** @group getParam */
      +  def getCaseSensitive: Boolean = $(caseSensitive)
      +
      +  setDefault(stopWords -> StopWords.English, caseSensitive -> false)
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    val outputSchema = transformSchema(dataset.schema)
       +    val t = if ($(caseSensitive)) {
       +      val stopWordsSet = $(stopWords).toSet
       +      udf { terms: Seq[String] =>
       +        terms.filter(s => !stopWordsSet.contains(s))
       +      }
       +    } else {
       +      val toLower = (s: String) => if (s != null) s.toLowerCase else s
       +      val lowerStopWords = $(stopWords).map(toLower(_)).toSet
       +      udf { terms: Seq[String] =>
       +        terms.filter(s => !lowerStopWords.contains(toLower(s)))
       +      }
       +    }
      +
      +    val metadata = outputSchema($(outputCol)).metadata
      +    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    val inputType = schema($(inputCol)).dataType
      +    require(inputType.sameType(ArrayType(StringType)),
      +      s"Input type must be ArrayType(StringType) but got $inputType.")
      +    val outputFields = schema.fields :+
      +      StructField($(outputCol), inputType, schema($(inputCol)).nullable)
      +    StructType(outputFields)
      +  }
      +
      +  override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
      +}
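
// Illustrative usage sketch for the new StopWordsRemover (not part of the patch itself).
// Assumes `sqlContext` is in scope; data are hypothetical.
import org.apache.spark.ml.feature.StopWordsRemover

val df = sqlContext.createDataFrame(Seq(
  (0, Seq("I", "saw", "the", "red", "balloon")),
  (1, Seq("Mary", "had", "a", "little", "lamb"))
)).toDF("id", "raw")

// By default the StopWords.English list is used and matching is case-insensitive;
// both can be overridden with setStopWords(...) and setCaseSensitive(true).
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")

remover.transform(df).show()
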
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
      index bf7be363b822..2b1592930e77 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
      @@ -20,19 +20,21 @@ package org.apache.spark.ml.feature
       import org.apache.spark.SparkException
       import org.apache.spark.annotation.Experimental
       import org.apache.spark.ml.{Estimator, Model}
      -import org.apache.spark.ml.attribute.NominalAttribute
      +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
       import org.apache.spark.ml.param._
       import org.apache.spark.ml.param.shared._
      +import org.apache.spark.ml.Transformer
       import org.apache.spark.ml.util.Identifiable
       import org.apache.spark.sql.DataFrame
       import org.apache.spark.sql.functions._
      -import org.apache.spark.sql.types.{NumericType, StringType, StructType}
      +import org.apache.spark.sql.types._
       import org.apache.spark.util.collection.OpenHashMap
       
       /**
        * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
        */
      -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
      +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
      +    with HasHandleInvalid {
       
         /** Validates and transforms the input schema. */
         protected def validateAndTransformSchema(schema: StructType): StructType = {
      @@ -57,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
        * If the input column is numeric, we cast it to string and index the string values.
        * The indices are in [0, numLabels), ordered by label frequencies.
        * So the most frequent label gets index 0.
      + *
      + * @see [[IndexToString]] for the inverse transformation
        */
       @Experimental
       class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
      @@ -64,13 +68,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
       
         def this() = this(Identifiable.randomUID("strIdx"))
       
      +  /** @group setParam */
      +  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
      +  setDefault(handleInvalid, "error")
      +
         /** @group setParam */
         def setInputCol(value: String): this.type = set(inputCol, value)
       
         /** @group setParam */
         def setOutputCol(value: String): this.type = set(outputCol, value)
       
      -  // TODO: handle unseen labels
       
         override def fit(dataset: DataFrame): StringIndexerModel = {
           val counts = dataset.select(col($(inputCol)).cast(StringType))
      @@ -90,14 +97,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
       /**
        * :: Experimental ::
        * Model fitted by [[StringIndexer]].
      + *
        * NOTE: During transformation, if the input column does not exist,
        * [[StringIndexerModel.transform]] would return the input dataset unmodified.
        * This is a temporary fix for the case when target labels do not exist during prediction.
      + *
      + * @param labels  Ordered list of labels, corresponding to indices to be assigned.
        */
       @Experimental
      -class StringIndexerModel private[ml] (
      +class StringIndexerModel (
           override val uid: String,
      -    labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
      +    val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
      +
      +  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
       
         private val labelToIndex: OpenHashMap[String, Double] = {
           val n = labels.length
      @@ -110,6 +122,10 @@ class StringIndexerModel private[ml] (
           map
         }
       
      +  /** @group setParam */
      +  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
      +  setDefault(handleInvalid, "error")
      +
         /** @group setParam */
         def setInputCol(value: String): this.type = set(inputCol, value)
       
      @@ -127,14 +143,24 @@ class StringIndexerModel private[ml] (
             if (labelToIndex.contains(label)) {
               labelToIndex(label)
             } else {
      -        // TODO: handle unseen labels
               throw new SparkException(s"Unseen label: $label.")
             }
           }
      +
           val outputColName = $(outputCol)
           val metadata = NominalAttribute.defaultAttr
             .withName(outputColName).withValues(labels).toMetadata()
      -    dataset.select(col("*"),
      +    // If we are skipping invalid records, filter them out.
       +    val filteredDataset = getHandleInvalid match {
       +      case "skip" =>
       +        val filterer = udf { label: String =>
       +          labelToIndex.contains(label)
       +        }
       +        dataset.where(filterer(dataset($(inputCol))))
       +      case _ => dataset
       +    }
      +    filteredDataset.select(col("*"),
             indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
         }
       
      @@ -149,6 +175,87 @@ class StringIndexerModel private[ml] (
       
         override def copy(extra: ParamMap): StringIndexerModel = {
           val copied = new StringIndexerModel(uid, labels)
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * A [[Transformer]] that maps a column of indices back to a new column of corresponding
      + * string values.
      + * The index-string mapping is either from the ML attributes of the input column,
      + * or from user-supplied labels (which take precedence over ML attributes).
      + *
      + * @see [[StringIndexer]] for converting strings into indices
      + */
      +@Experimental
      +class IndexToString private[ml] (
      +  override val uid: String) extends Transformer
      +    with HasInputCol with HasOutputCol {
      +
      +  def this() =
      +    this(Identifiable.randomUID("idxToStr"))
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  /** @group setParam */
      +  def setLabels(value: Array[String]): this.type = set(labels, value)
      +
      +  /**
      +   * Optional param for array of labels specifying index-string mapping.
      +   *
      +   * Default: Empty array, in which case [[inputCol]] metadata is used for labels.
      +   * @group param
      +   */
      +  final val labels: StringArrayParam = new StringArrayParam(this, "labels",
      +    "Optional array of labels specifying index-string mapping." +
      +      " If not provided or if empty, then metadata from inputCol is used instead.")
      +  setDefault(labels, Array.empty[String])
      +
      +  /** @group getParam */
      +  final def getLabels: Array[String] = $(labels)
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    val inputColName = $(inputCol)
      +    val inputDataType = schema(inputColName).dataType
      +    require(inputDataType.isInstanceOf[NumericType],
      +      s"The input column $inputColName must be a numeric type, " +
      +        s"but got $inputDataType.")
      +    val inputFields = schema.fields
      +    val outputColName = $(outputCol)
      +    require(inputFields.forall(_.name != outputColName),
      +      s"Output column $outputColName already exists.")
      +    val outputFields = inputFields :+ StructField($(outputCol), StringType)
      +    StructType(outputFields)
      +  }
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    val inputColSchema = dataset.schema($(inputCol))
      +    // If the labels array is empty use column metadata
      +    val values = if ($(labels).isEmpty) {
      +      Attribute.fromStructField(inputColSchema)
      +        .asInstanceOf[NominalAttribute].values.get
      +    } else {
      +      $(labels)
      +    }
      +    val indexer = udf { index: Double =>
      +      val idx = index.toInt
      +      if (0 <= idx && idx < values.length) {
      +        values(idx)
      +      } else {
      +        throw new SparkException(s"Unseen index: $index ??")
      +      }
      +    }
      +    val outputColName = $(outputCol)
      +    dataset.select(col("*"),
      +      indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
      +  }
      +
      +  override def copy(extra: ParamMap): IndexToString = {
      +    defaultCopy(extra)
         }
       }
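
// Illustrative sketch for handleInvalid and the new IndexToString (not part of the patch
// itself). Assumes `sqlContext` is in scope; data and column names are hypothetical.
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val train = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))).toDF("id", "label")
val test = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "d"))).toDF("id", "label")   // "d" was never seen during fitting

val indexerModel = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .setHandleInvalid("skip")   // drop rows with unseen labels instead of throwing
  .fit(train)

val indexed = indexerModel.transform(test)   // the "d" row is filtered out

// IndexToString reverses the mapping using the labels stored in the column metadata
// (or an explicit setLabels(...) array, which takes precedence).
val converter = new IndexToString()
  .setInputCol("labelIndex")
  .setOutputCol("originalLabel")
converter.transform(indexed).show()
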
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
      index 5f9f57a2ebcf..248288ca73e9 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
      @@ -42,7 +42,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
           require(inputType == StringType, s"Input type must be string type but got $inputType.")
         }
       
      -  override protected def outputDataType: DataType = new ArrayType(StringType, false)
      +  override protected def outputDataType: DataType = new ArrayType(StringType, true)
       
         override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
       }
      @@ -50,7 +50,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
       /**
        * :: Experimental ::
        * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split
      - * the text (default) or repeatedly matching the regex (if `gaps` is true).
      + * the text (default) or repeatedly matching the regex (if `gaps` is false).
        * Optional parameters also allow filtering tokens using a minimal length.
        * It returns an array of strings that can be empty.
        */
      @@ -113,7 +113,7 @@ class RegexTokenizer(override val uid: String)
           require(inputType == StringType, s"Input type must be string type but got $inputType.")
         }
       
      -  override protected def outputDataType: DataType = new ArrayType(StringType, false)
      +  override protected def outputDataType: DataType = new ArrayType(StringType, true)
       
         override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
       }
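
A brief sketch of the two RegexTokenizer modes referred to in the corrected scaladoc above (column names and patterns are illustrative):

    import org.apache.spark.ml.feature.RegexTokenizer

    // gaps = true (default): the pattern is used to split the text
    val splitter = new RegexTokenizer()
      .setInputCol("text").setOutputCol("words")
      .setPattern("\\s+")

    // gaps = false: the pattern is matched repeatedly to extract tokens
    val matcher = new RegexTokenizer()
      .setInputCol("text").setOutputCol("words")
      .setGaps(false)
      .setPattern("\\w+")
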
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
      index 9f83c2ee1617..086917fa680f 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
      @@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
           if (schema.fieldNames.contains(outputColName)) {
             throw new IllegalArgumentException(s"Output column $outputColName already exists.")
           }
      -    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
      +    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
         }
       
         override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
      index f4854a5e4b7b..52e0599e38d8 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
      @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.shared._
       import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
       import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
       import org.apache.spark.sql.{DataFrame, Row}
      -import org.apache.spark.sql.functions.callUDF
      +import org.apache.spark.sql.functions.udf
       import org.apache.spark.sql.types.{StructField, StructType}
       import org.apache.spark.util.collection.OpenHashSet
       
      @@ -43,6 +43,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
          * Must be >= 2.
          *
          * (default = 20)
      +   * @group param
          */
         val maxCategories = new IntParam(this, "maxCategories",
           "Threshold for the number of values a categorical feature can take (>= 2)." +
      @@ -339,8 +340,9 @@ class VectorIndexerModel private[ml] (
         override def transform(dataset: DataFrame): DataFrame = {
           transformSchema(dataset.schema, logging = true)
           val newField = prepOutputField(dataset.schema)
      -    val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol)))
      -    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
      +    val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
      +    val newCol = transformUDF(dataset($(inputCol)))
      +    dataset.withColumn($(outputCol), newCol, newField.metadata)
         }
       
         override def transformSchema(schema: StructType): StructType = {
      @@ -404,6 +406,6 @@ class VectorIndexerModel private[ml] (
       
         override def copy(extra: ParamMap): VectorIndexerModel = {
           val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
         }
       }
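
The callUDF-to-udf migration above follows a general pattern; a hedged sketch with an illustrative stand-in function, assuming a DataFrame `df` with a vector column "in":

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.functions.udf

    val transformFunc: Vector => Vector = identity  // stand-in for the model's transform function
    val transformUDF = udf { v: Vector => transformFunc(v) }
    val transformed = df.withColumn("out", transformUDF(df("in")))
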
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
      new file mode 100644
      index 000000000000..fb3387d4aa9b
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
      @@ -0,0 +1,171 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.Transformer
      +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
      +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
      +import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
      +import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
      +import org.apache.spark.mllib.linalg._
      +import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.functions._
      +import org.apache.spark.sql.types.StructType
      +
      +/**
      + * :: Experimental ::
      + * This class takes a feature vector and outputs a new feature vector with a subarray of the
      + * original features.
      + *
      + * The subset of features can be specified with either indices ([[setIndices()]])
      + * or names ([[setNames()]]).  At least one feature must be selected. Duplicate features
      + * are not allowed, so there can be no overlap between selected indices and names.
      + *
      + * The output vector will order features with the selected indices first (in the order given),
      + * followed by the selected names (in the order given).
      + */
      +@Experimental
      +final class VectorSlicer(override val uid: String)
      +  extends Transformer with HasInputCol with HasOutputCol {
      +
      +  def this() = this(Identifiable.randomUID("vectorSlicer"))
      +
      +  /**
      +   * An array of indices to select features from a vector column.
      +   * There can be no overlap with [[names]].
      +   * Default: Empty array
      +   * @group param
      +   */
      +  val indices = new IntArrayParam(this, "indices",
      +    "An array of indices to select features from a vector column." +
      +      " There can be no overlap with names.", VectorSlicer.validIndices)
      +
      +  setDefault(indices -> Array.empty[Int])
      +
      +  /** @group getParam */
      +  def getIndices: Array[Int] = $(indices)
      +
      +  /** @group setParam */
      +  def setIndices(value: Array[Int]): this.type = set(indices, value)
      +
      +  /**
      +   * An array of feature names to select features from a vector column.
      +   * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s.
      +   * There can be no overlap with [[indices]].
+   * Default: Empty array
      +   * @group param
      +   */
      +  val names = new StringArrayParam(this, "names",
      +    "An array of feature names to select features from a vector column." +
      +      " There can be no overlap with indices.", VectorSlicer.validNames)
      +
      +  setDefault(names -> Array.empty[String])
      +
      +  /** @group getParam */
      +  def getNames: Array[String] = $(names)
      +
      +  /** @group setParam */
      +  def setNames(value: Array[String]): this.type = set(names, value)
      +
      +  /** @group setParam */
      +  def setInputCol(value: String): this.type = set(inputCol, value)
      +
      +  /** @group setParam */
      +  def setOutputCol(value: String): this.type = set(outputCol, value)
      +
      +  override def validateParams(): Unit = {
      +    require($(indices).length > 0 || $(names).length > 0,
      +      s"VectorSlicer requires that at least one feature be selected.")
      +  }
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    // Validity checks
      +    transformSchema(dataset.schema)
      +    val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
      +    inputAttr.numAttributes.foreach { numFeatures =>
      +      val maxIndex = $(indices).max
      +      require(maxIndex < numFeatures,
      +        s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
      +    }
      +
      +    // Prepare output attributes
      +    val inds = getSelectedFeatureIndices(dataset.schema)
      +    val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
      +      inds.map(index => attrs(index))
      +    }
      +    val outputAttr = selectedAttrs match {
      +      case Some(attrs) => new AttributeGroup($(outputCol), attrs)
      +      case None => new AttributeGroup($(outputCol), inds.length)
      +    }
      +
      +    // Select features
      +    val slicer = udf { vec: Vector =>
      +      vec match {
      +        case features: DenseVector => Vectors.dense(inds.map(features.apply))
      +        case features: SparseVector => features.slice(inds)
      +      }
      +    }
      +    dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata())
      +  }
      +
      +  /** Get the feature indices in order: indices, names */
      +  private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
      +    val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
      +    val indFeatures = $(indices)
      +    val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
      +    lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
      +      s" sets of features, but they overlap." +
      +      s" indices: ${indFeatures.mkString("[", ",", "]")}." +
      +      s" names: " +
      +      nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]")
      +    require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg)
      +    indFeatures ++ nameFeatures
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
      +
      +    if (schema.fieldNames.contains($(outputCol))) {
      +      throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
      +    }
      +    val numFeaturesSelected = $(indices).length + $(names).length
      +    val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
      +    val outputFields = schema.fields :+ outputAttr.toStructField()
      +    StructType(outputFields)
      +  }
      +
      +  override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
      +}
      +
      +private[feature] object VectorSlicer {
      +
      +  /** Return true if given feature indices are valid */
      +  def validIndices(indices: Array[Int]): Boolean = {
      +    if (indices.isEmpty) {
      +      true
      +    } else {
      +      indices.length == indices.distinct.length && indices.forall(_ >= 0)
      +    }
      +  }
      +
      +  /** Return true if given feature names are valid */
      +  def validNames(names: Array[String]): Boolean = {
      +    names.forall(_.nonEmpty) && names.length == names.distinct.length
      +  }
      +}
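
A minimal usage sketch of VectorSlicer, assuming a DataFrame `dataset` whose "userFeatures" column carries ML attribute names (column and feature names are illustrative):

    import org.apache.spark.ml.feature.VectorSlicer

    val slicer = new VectorSlicer()
      .setInputCol("userFeatures")
      .setOutputCol("features")
      .setIndices(Array(1))      // select by position ...
      .setNames(Array("f3"))     // ... and/or by attribute name; the two sets must be disjoint
    val output = slicer.transform(dataset)
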
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
      index 6ea659095630..9edab3af913c 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
      @@ -18,15 +18,17 @@
       package org.apache.spark.ml.feature
       
       import org.apache.spark.annotation.Experimental
      +import org.apache.spark.SparkContext
       import org.apache.spark.ml.{Estimator, Model}
       import org.apache.spark.ml.param._
       import org.apache.spark.ml.param.shared._
       import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
       import org.apache.spark.mllib.feature
      -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
      +import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
       import org.apache.spark.mllib.linalg.BLAS._
       import org.apache.spark.sql.DataFrame
       import org.apache.spark.sql.functions._
      +import org.apache.spark.sql.SQLContext
       import org.apache.spark.sql.types._
       
       /**
      @@ -37,6 +39,7 @@ private[feature] trait Word2VecBase extends Params
       
         /**
          * The dimension of the code that you want to transform from words.
      +   * Default: 100
          * @group param
          */
         final val vectorSize = new IntParam(
      @@ -48,6 +51,7 @@ private[feature] trait Word2VecBase extends Params
       
         /**
          * Number of partitions for sentences of words.
      +   * Default: 1
          * @group param
          */
         final val numPartitions = new IntParam(
      @@ -60,6 +64,7 @@ private[feature] trait Word2VecBase extends Params
         /**
          * The minimum number of times a token must appear to be included in the word2vec model's
          * vocabulary.
      +   * Default: 5
          * @group param
          */
         final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " +
      @@ -146,6 +151,40 @@ class Word2VecModel private[ml] (
           wordVectors: feature.Word2VecModel)
         extends Model[Word2VecModel] with Word2VecBase {
       
      +
      +  /**
+   * Returns a DataFrame with two fields, "word" and "vector", where "word" is a String and
+   * "vector" is the DenseVector that it is mapped to.
      +   */
      +  @transient lazy val getVectors: DataFrame = {
      +    val sc = SparkContext.getOrCreate()
      +    val sqlContext = SQLContext.getOrCreate(sc)
      +    import sqlContext.implicits._
      +    val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
      +    sc.parallelize(wordVec.toSeq).toDF("word", "vector")
      +  }
      +
      +  /**
      +   * Find "num" number of words closest in similarity to the given word.
+   * Returns a DataFrame with the words and the cosine similarities between the
      +   * synonyms and the given word.
      +   */
      +  def findSynonyms(word: String, num: Int): DataFrame = {
      +    findSynonyms(wordVectors.transform(word), num)
      +  }
      +
      +  /**
      +   * Find "num" number of words closest to similarity to the given vector representation
      +   * of the word. Returns a dataframe with the words and the cosine similarities between the
      +   * synonyms and the given word vector.
      +   */
      +  def findSynonyms(word: Vector, num: Int): DataFrame = {
      +    val sc = SparkContext.getOrCreate()
      +    val sqlContext = SQLContext.getOrCreate(sc)
      +    import sqlContext.implicits._
      +    sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
      +  }
      +
         /** @group setParam */
         def setInputCol(value: String): this.type = set(inputCol, value)
       
      @@ -185,6 +224,6 @@ class Word2VecModel private[ml] (
       
         override def copy(extra: ParamMap): Word2VecModel = {
           val copied = new Word2VecModel(uid, wordVectors)
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
         }
       }
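
A short sketch of the new Word2VecModel helpers, assuming `model` is a fitted Word2VecModel:

    model.getVectors.show()                // DataFrame with "word" and "vector" columns
    model.findSynonyms("spark", 5).show()  // DataFrame with "word" and "similarity" columns
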
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala
      new file mode 100644
      index 000000000000..4571ab26800c
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala
      @@ -0,0 +1,89 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml
      +
      +import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, VectorAssembler}
      +import org.apache.spark.sql.DataFrame
      +
      +/**
      + * == Feature transformers ==
      + *
      + * The `ml.feature` package provides common feature transformers that help convert raw data or
      + * features into more suitable forms for model fitting.
      + * Most feature transformers are implemented as [[Transformer]]s, which transform one [[DataFrame]]
      + * into another, e.g., [[HashingTF]].
      + * Some feature transformers are implemented as [[Estimator]]s, because the transformation requires
      + * some aggregated information of the dataset, e.g., document frequencies in [[IDF]].
      + * For those feature transformers, calling [[Estimator!.fit]] is required to obtain the model first,
+ * e.g., [[IDFModel]], in order to apply the transformation.
      + * The transformation is usually done by appending new columns to the input [[DataFrame]], so all
      + * input columns are carried over.
      + *
      + * We try to make each transformer minimal, so it becomes flexible to assemble feature
      + * transformation pipelines.
      + * [[Pipeline]] can be used to chain feature transformers, and [[VectorAssembler]] can be used to
      + * combine multiple feature transformations, for example:
      + *
      + * {{{
      + *   import org.apache.spark.ml.feature._
      + *   import org.apache.spark.ml.Pipeline
      + *
      + *   // a DataFrame with three columns: id (integer), text (string), and rating (double).
      + *   val df = sqlContext.createDataFrame(Seq(
      + *     (0, "Hi I heard about Spark", 3.0),
      + *     (1, "I wish Java could use case classes", 4.0),
      + *     (2, "Logistic regression models are neat", 4.0)
      + *   )).toDF("id", "text", "rating")
      + *
      + *   // define feature transformers
      + *   val tok = new RegexTokenizer()
      + *     .setInputCol("text")
      + *     .setOutputCol("words")
      + *   val sw = new StopWordsRemover()
      + *     .setInputCol("words")
      + *     .setOutputCol("filtered_words")
      + *   val tf = new HashingTF()
      + *     .setInputCol("filtered_words")
      + *     .setOutputCol("tf")
      + *     .setNumFeatures(10000)
      + *   val idf = new IDF()
      + *     .setInputCol("tf")
      + *     .setOutputCol("tf_idf")
      + *   val assembler = new VectorAssembler()
      + *     .setInputCols(Array("tf_idf", "rating"))
      + *     .setOutputCol("features")
      + *
      + *   // assemble and fit the feature transformation pipeline
      + *   val pipeline = new Pipeline()
      + *     .setStages(Array(tok, sw, tf, idf, assembler))
      + *   val model = pipeline.fit(df)
      + *
      + *   // save transformed features with raw data
      + *   model.transform(df)
      + *     .select("id", "text", "rating", "features")
      + *     .write.format("parquet").save("/output/path")
      + * }}}
      + *
      + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn.
      + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire
      + * input dataset, while MLlib's feature transformers operate lazily on individual columns,
+ * which is more efficient and flexible for handling large and complex datasets.
      + *
      + * @see [[http://scikit-learn.org/stable/modules/preprocessing.html scikit-learn.preprocessing]]
      + */
      +package object feature
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
      new file mode 100644
      index 000000000000..0ff8931b0bab
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
      @@ -0,0 +1,296 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.optim
      +
      +import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
      +import org.netlib.util.intW
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.mllib.linalg._
      +import org.apache.spark.mllib.linalg.distributed.RowMatrix
      +import org.apache.spark.rdd.RDD
      +
      +/**
      + * Model fitted by [[WeightedLeastSquares]].
      + * @param coefficients model coefficients
      + * @param intercept model intercept
      + */
      +private[ml] class WeightedLeastSquaresModel(
      +    val coefficients: DenseVector,
      +    val intercept: Double) extends Serializable
      +
      +/**
      + * Weighted least squares solver via normal equation.
      + * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares
      + * formulation:
      + *
+ * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,,
      + *   + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^,
      + *
      + * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by
      + * [[standardizeLabel]] and [[standardizeFeatures]], respectively.
      + *
      + * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to
      + * match R's `lm`.
      + * Turn on [[standardizeLabel]] to match R's `glmnet`.
      + *
      + * @param fitIntercept whether to fit intercept. If false, z is 0.0.
      + * @param regParam L2 regularization parameter (lambda)
      + * @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the
      + *                            population standard deviation of the j-th column of A. Otherwise,
      + *                            sigma,,j,, is 1.0.
      + * @param standardizeLabel whether to standardize label. If true, delta is the population standard
      + *                         deviation of the label column b. Otherwise, delta is 1.0.
      + */
      +private[ml] class WeightedLeastSquares(
      +    val fitIntercept: Boolean,
      +    val regParam: Double,
      +    val standardizeFeatures: Boolean,
      +    val standardizeLabel: Boolean) extends Logging with Serializable {
      +  import WeightedLeastSquares._
      +
      +  require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
      +  if (regParam == 0.0) {
      +    logWarning("regParam is zero, which might cause numerical instability and overfitting.")
      +  }
      +
      +  /**
      +   * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
      +   */
      +  def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
      +    val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
      +    summary.validate()
      +    logInfo(s"Number of instances: ${summary.count}.")
      +    val triK = summary.triK
      +    val bBar = summary.bBar
      +    val bStd = summary.bStd
      +    val aBar = summary.aBar
      +    val aVar = summary.aVar
      +    val abBar = summary.abBar
      +    val aaBar = summary.aaBar
      +    val aaValues = aaBar.values
      +
      +    if (fitIntercept) {
      +      // shift centers
      +      // A^T A - aBar aBar^T
      +      BLAS.spr(-1.0, aBar, aaValues)
      +      // A^T b - bBar aBar
      +      BLAS.axpy(-bBar, aBar, abBar)
      +    }
      +
      +    // add regularization to diagonals
      +    var i = 0
      +    var j = 2
      +    while (i < triK) {
      +      var lambda = regParam
      +      if (standardizeFeatures) {
      +        lambda *= aVar(j - 2)
      +      }
      +      if (standardizeLabel) {
      +        // TODO: handle the case when bStd = 0
      +        lambda /= bStd
      +      }
      +      aaValues(i) += lambda
      +      i += j
      +      j += 1
      +    }
      +
      +    val x = choleskySolve(aaBar.values, abBar)
      +
      +    // compute intercept
      +    val intercept = if (fitIntercept) {
      +      bBar - BLAS.dot(aBar, x)
      +    } else {
      +      0.0
      +    }
      +
      +    new WeightedLeastSquaresModel(x, intercept)
      +  }
      +
      +  /**
      +   * Solves a symmetric positive definite linear system via Cholesky factorization.
      +   * The input arguments are modified in-place to store the factorization and the solution.
      +   * @param A the upper triangular part of A
      +   * @param bx right-hand side
      +   * @return the solution vector
      +   */
      +  // TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS
      +  private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = {
      +    val k = bx.size
      +    val info = new intW(0)
      +    lapack.dppsv("U", k, 1, A, bx.values, k, info)
      +    val code = info.`val`
+    assert(code == 0, s"lapack.dppsv returned $code.")
      +    bx
      +  }
      +}
      +
      +private[ml] object WeightedLeastSquares {
      +
      +  /**
      +   * Case class for weighted observations.
+   * @param w weight, must be non-negative
      +   * @param a features
      +   * @param b label
      +   */
      +  case class Instance(w: Double, a: Vector, b: Double) {
      +    require(w >= 0.0, s"Weight cannot be negative: $w.")
      +  }
      +
      +  /**
      +   * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
      +   */
      +  // TODO: consolidate aggregates for summary statistics
      +  private class Aggregator extends Serializable {
      +    var initialized: Boolean = false
      +    var k: Int = _
      +    var count: Long = _
      +    var triK: Int = _
      +    private var wSum: Double = _
      +    private var wwSum: Double = _
      +    private var bSum: Double = _
      +    private var bbSum: Double = _
      +    private var aSum: DenseVector = _
      +    private var abSum: DenseVector = _
      +    private var aaSum: DenseVector = _
      +
      +    private def init(k: Int): Unit = {
      +      require(k <= 4096, "In order to take the normal equation approach efficiently, " +
      +        s"we set the max number of features to 4096 but got $k.")
      +      this.k = k
      +      triK = k * (k + 1) / 2
      +      count = 0L
      +      wSum = 0.0
      +      wwSum = 0.0
      +      bSum = 0.0
      +      bbSum = 0.0
      +      aSum = new DenseVector(Array.ofDim(k))
      +      abSum = new DenseVector(Array.ofDim(k))
      +      aaSum = new DenseVector(Array.ofDim(triK))
      +      initialized = true
      +    }
      +
      +    /**
      +     * Adds an instance.
      +     */
      +    def add(instance: Instance): this.type = {
      +      val Instance(w, a, b) = instance
      +      val ak = a.size
      +      if (!initialized) {
      +        init(ak)
      +        initialized = true
      +      }
      +      assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
      +      count += 1L
      +      wSum += w
      +      wwSum += w * w
      +      bSum += w * b
      +      bbSum += w * b * b
      +      BLAS.axpy(w, a, aSum)
      +      BLAS.axpy(w * b, a, abSum)
      +      BLAS.spr(w, a, aaSum)
      +      this
      +    }
      +
      +    /**
      +     * Merges another [[Aggregator]].
      +     */
      +    def merge(other: Aggregator): this.type = {
      +      if (!other.initialized) {
      +        this
      +      } else {
      +        if (!initialized) {
      +          init(other.k)
      +        }
      +        assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
      +        count += other.count
      +        wSum += other.wSum
      +        wwSum += other.wwSum
      +        bSum += other.bSum
      +        bbSum += other.bbSum
      +        BLAS.axpy(1.0, other.aSum, aSum)
      +        BLAS.axpy(1.0, other.abSum, abSum)
      +        BLAS.axpy(1.0, other.aaSum, aaSum)
      +        this
      +      }
      +    }
      +
      +    /**
      +     * Validates that we have seen observations.
      +     */
      +    def validate(): Unit = {
      +      assert(initialized, "Training dataset is empty.")
      +      assert(wSum > 0.0, "Sum of weights cannot be zero.")
      +    }
      +
      +    /**
      +     * Weighted mean of features.
      +     */
      +    def aBar: DenseVector = {
      +      val output = aSum.copy
      +      BLAS.scal(1.0 / wSum, output)
      +      output
      +    }
      +
      +    /**
      +     * Weighted mean of labels.
      +     */
      +    def bBar: Double = bSum / wSum
      +
      +    /**
      +     * Weighted population standard deviation of labels.
      +     */
      +    def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
      +
      +    /**
      +     * Weighted mean of (label * features).
      +     */
      +    def abBar: DenseVector = {
      +      val output = abSum.copy
      +      BLAS.scal(1.0 / wSum, output)
      +      output
      +    }
      +
      +    /**
      +     * Weighted mean of (features * features^T^).
      +     */
      +    def aaBar: DenseVector = {
      +      val output = aaSum.copy
      +      BLAS.scal(1.0 / wSum, output)
      +      output
      +    }
      +
      +    /**
      +     * Weighted population variance of features.
      +     */
      +    def aVar: DenseVector = {
      +      val variance = Array.ofDim[Double](k)
      +      var i = 0
      +      var j = 2
      +      val aaValues = aaSum.values
      +      while (i < triK) {
      +        val l = j - 2
      +        val aw = aSum(l) / wSum
      +        variance(l) = aaValues(i) / wSum - aw * aw
      +        i += j
      +        j += 1
      +      }
      +      new DenseVector(variance)
      +    }
      +  }
      +}
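
A hedged sketch of using the (private[ml]) solver to match R's lm, as the scaladoc above suggests; `sc` is an assumed SparkContext and the data points are illustrative:

    import org.apache.spark.ml.optim.WeightedLeastSquares
    import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
    import org.apache.spark.mllib.linalg.Vectors

    val instances = sc.parallelize(Seq(            // Instance(weight, features, label)
      Instance(1.0, Vectors.dense(0.0, 1.0), 1.0),
      Instance(1.0, Vectors.dense(1.0, 2.0), 3.0),
      Instance(2.0, Vectors.dense(2.0, 1.0), 2.5),
      Instance(1.0, Vectors.dense(3.0, 3.0), 5.0)))
    val wls = new WeightedLeastSquares(fitIntercept = true, regParam = 0.0,
      standardizeFeatures = false, standardizeLabel = false)
    val model = wls.fit(instances)
    // model.coefficients: DenseVector, model.intercept: Double
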
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
      index 50c0d855066f..de32b7218c27 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
      @@ -166,6 +166,11 @@ object ParamValidators {
         def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) =>
           allowed.contains(value)
         }
      +
      +  /** Check that the array length is greater than lowerBound. */
      +  def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
      +    value.length > lowerBound
      +  }
       }
       
       // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
      @@ -295,6 +300,22 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
           w(value.asScala.map(_.asInstanceOf[Double]).toArray)
       }
       
      +/**
      + * :: DeveloperApi ::
      + * Specialized version of [[Param[Array[Int]]]] for Java.
      + */
      +@DeveloperApi
      +class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean)
      +  extends Param[Array[Int]](parent, name, doc, isValid) {
      +
      +  def this(parent: Params, name: String, doc: String) =
      +    this(parent, name, doc, ParamValidators.alwaysTrue)
      +
      +  /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
      +  def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
      +    w(value.asScala.map(_.asInstanceOf[Int]).toArray)
      +}
      +
       /**
        * :: Experimental ::
        * A param and its value.
      @@ -341,9 +362,7 @@ trait Params extends Identifiable with Serializable {
          * those are checked during schema validation.
          */
         def validateParams(): Unit = {
      -    params.filter(isDefined).foreach { param =>
      -      param.asInstanceOf[Param[Any]].validate($(param))
      -    }
      +    // Do nothing by default.  Override to handle Param interactions.
         }
       
         /**
      @@ -442,7 +461,8 @@ trait Params extends Identifiable with Serializable {
          */
         final def getOrDefault[T](param: Param[T]): T = {
           shouldOwn(param)
      -    get(param).orElse(getDefault(param)).get
      +    get(param).orElse(getDefault(param)).getOrElse(
      +      throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
         }
       
         /** An alias for [[getOrDefault()]]. */
      @@ -462,11 +482,14 @@ trait Params extends Identifiable with Serializable {
         /**
          * Sets default values for a list of params.
          *
      +   * Note: Java developers should use the single-parameter [[setDefault()]].
      +   *       Annotating this with varargs can cause compilation failures due to a Scala compiler bug.
      +   *       See SPARK-9268.
      +   *
          * @param paramPairs  a list of param pairs that specify params and their default values to set
          *                    respectively. Make sure that the params are initialized before this method
          *                    gets called.
          */
      -  @varargs
         protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
           paramPairs.foreach { p =>
             setDefault(p.param.asInstanceOf[Param[Any]], p.value)
      @@ -537,13 +560,26 @@ trait Params extends Identifiable with Serializable {
       
         /**
          * Copies param values from this instance to another instance for params shared by them.
      -   * @param to the target instance
      -   * @param extra extra params to be copied
      +   *
      +   * This handles default Params and explicitly set Params separately.
      +   * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are
      +   * copied from and to [[paramMap]].
      +   * Warning: This implicitly assumes that this [[Params]] instance and the target instance
      +   *          share the same set of default Params.
      +   *
      +   * @param to the target instance, which should work with the same set of default Params as this
      +   *           source instance
      +   * @param extra extra params to be copied to the target's [[paramMap]]
          * @return the target instance with param values copied
          */
         protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
      -    val map = extractParamMap(extra)
      +    val map = paramMap ++ extra
           params.foreach { param =>
      +      // copy default Params
      +      if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
      +        to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
      +      }
      +      // copy explicitly set Params
             if (map.contains(param) && to.hasParam(param.name)) {
               to.set(param.name, map(param))
             }
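
A hedged sketch of defining the new IntArrayParam together with the new arrayLengthGt validator in a custom Params class (all names here are illustrative):

    import org.apache.spark.ml.param.{IntArrayParam, ParamMap, Params, ParamValidators}
    import org.apache.spark.ml.util.Identifiable

    class MySelector(override val uid: String) extends Params {
      def this() = this(Identifiable.randomUID("mySelector"))

      // indices to keep; the validator rejects empty arrays
      val selected = new IntArrayParam(this, "selected",
        "indices to keep, must be non-empty", ParamValidators.arrayLengthGt[Int](0))
      setDefault(selected -> Array(0))
      def getSelected: Array[Int] = $(selected)

      override def copy(extra: ParamMap): MySelector = defaultCopy(extra)
    }
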
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
      index 8ffbcf0d8bc7..8049d51fee5e 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
      @@ -42,23 +42,38 @@ private[shared] object SharedParamsCodeGen {
               Some("\"rawPrediction\"")),
             ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" +
               " probabilities. Note: Not all models output well-calibrated probability estimates!" +
      -        " These probabilities should be treated as confidences, not precise probabilities.",
      +        " These probabilities should be treated as confidences, not precise probabilities",
               Some("\"probability\"")),
             ParamDesc[Double]("threshold",
      -        "threshold in binary classification prediction, in range [0, 1]",
      -        isValid = "ParamValidators.inRange(0, 1)"),
      +        "threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
      +        isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
      +      ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
      +        " to adjust the probability of predicting each class." +
      +        " Array must have length equal to the number of classes, with values >= 0." +
      +        " The class with largest value p/t is predicted, where p is the original probability" +
      +        " of that class and t is the class' threshold.",
      +        isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
             ParamDesc[String]("inputCol", "input column name"),
             ParamDesc[Array[String]]("inputCols", "input column names"),
             ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
      -      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
      +      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " +
      +        "the cache will get checkpointed every 10 iterations.",
               isValid = "ParamValidators.gtEq(1)"),
             ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
      +      ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
      +        "will filter out rows with bad values), or error (which will throw an errror). More " +
      +        "options may be added later.",
      +        isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
      +      ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
      +        " before fitting the model", Some("true")),
             ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
             ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
      -        " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
      +        " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty",
               isValid = "ParamValidators.inRange(0, 1)"),
             ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
      -      ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))
      +      ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."),
      +      ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
      +        "all instance weights as 1.0."))
       
           val code = genSharedParams(params)
           val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
      @@ -72,7 +87,8 @@ private[shared] object SharedParamsCodeGen {
             name: String,
             doc: String,
             defaultValueStr: Option[String] = None,
      -      isValid: String = "") {
      +      isValid: String = "",
      +      finalMethods: Boolean = true) {
       
           require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
           require(doc.nonEmpty) // TODO: more rigorous on doc
      @@ -86,6 +102,7 @@ private[shared] object SharedParamsCodeGen {
               case _ if c == classOf[Double] => "DoubleParam"
               case _ if c == classOf[Boolean] => "BooleanParam"
               case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
      +        case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam"
               case _ => s"Param[${getTypeString(c)}]"
             }
           }
      @@ -129,10 +146,15 @@ private[shared] object SharedParamsCodeGen {
           } else {
             ""
           }
      +    val methodStr = if (param.finalMethods) {
      +      "final def"
      +    } else {
      +      "def"
      +    }
       
           s"""
             |/**
      -      | * (private[ml]) Trait for shared param $name$defaultValueDoc.
      +      | * Trait for shared param $name$defaultValueDoc.
             | */
             |private[ml] trait Has$Name extends Params {
             |
      @@ -143,7 +165,7 @@ private[shared] object SharedParamsCodeGen {
             |  final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
             |$setDefault
             |  /** @group getParam */
      -      |  final def get$Name: $T = $$($name)
      +      |  $methodStr get$Name: $T = $$($name)
             |}
             |""".stripMargin
         }
      @@ -171,7 +193,6 @@ private[shared] object SharedParamsCodeGen {
               |package org.apache.spark.ml.param.shared
               |
               |import org.apache.spark.ml.param._
      -        |import org.apache.spark.util.Utils
               |
               |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
               |
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
      index a0c8ccdac9ad..aff47fc326c4 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
      @@ -18,14 +18,13 @@
       package org.apache.spark.ml.param.shared
       
       import org.apache.spark.ml.param._
      -import org.apache.spark.util.Utils
       
       // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
       
       // scalastyle:off
       
       /**
      - * (private[ml]) Trait for shared param regParam.
      + * Trait for shared param regParam.
        */
       private[ml] trait HasRegParam extends Params {
       
      @@ -40,7 +39,7 @@ private[ml] trait HasRegParam extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param maxIter.
      + * Trait for shared param maxIter.
        */
       private[ml] trait HasMaxIter extends Params {
       
      @@ -55,7 +54,7 @@ private[ml] trait HasMaxIter extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param featuresCol (default: "features").
      + * Trait for shared param featuresCol (default: "features").
        */
       private[ml] trait HasFeaturesCol extends Params {
       
      @@ -72,7 +71,7 @@ private[ml] trait HasFeaturesCol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param labelCol (default: "label").
      + * Trait for shared param labelCol (default: "label").
        */
       private[ml] trait HasLabelCol extends Params {
       
      @@ -89,7 +88,7 @@ private[ml] trait HasLabelCol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param predictionCol (default: "prediction").
      + * Trait for shared param predictionCol (default: "prediction").
        */
       private[ml] trait HasPredictionCol extends Params {
       
      @@ -106,7 +105,7 @@ private[ml] trait HasPredictionCol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
      + * Trait for shared param rawPredictionCol (default: "rawPrediction").
        */
       private[ml] trait HasRawPredictionCol extends Params {
       
      @@ -123,15 +122,15 @@ private[ml] trait HasRawPredictionCol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param probabilityCol (default: "probability").
      + * Trait for shared param probabilityCol (default: "probability").
        */
       private[ml] trait HasProbabilityCol extends Params {
       
         /**
      -   * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities..
      +   * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
          * @group param
          */
      -  final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
      +  final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities")
       
         setDefault(probabilityCol, "probability")
       
      @@ -140,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param threshold.
      + * Trait for shared param threshold (default: 0.5).
        */
       private[ml] trait HasThreshold extends Params {
       
      @@ -150,12 +149,29 @@ private[ml] trait HasThreshold extends Params {
          */
         final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
       
      +  setDefault(threshold, 0.5)
      +
      +  /** @group getParam */
      +  def getThreshold: Double = $(threshold)
      +}
      +
      +/**
      + * Trait for shared param thresholds.
      + */
      +private[ml] trait HasThresholds extends Params {
      +
      +  /**
      +   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold..
      +   * @group param
      +   */
      +  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))
      +
         /** @group getParam */
      -  final def getThreshold: Double = $(threshold)
      +  def getThresholds: Array[Double] = $(thresholds)
       }
       
       /**
      - * (private[ml]) Trait for shared param inputCol.
      + * Trait for shared param inputCol.
        */
       private[ml] trait HasInputCol extends Params {
       
      @@ -170,7 +186,7 @@ private[ml] trait HasInputCol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param inputCols.
      + * Trait for shared param inputCols.
        */
       private[ml] trait HasInputCols extends Params {
       
      @@ -185,7 +201,7 @@ private[ml] trait HasInputCols extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param outputCol (default: uid + "__output").
      + * Trait for shared param outputCol (default: uid + "__output").
        */
       private[ml] trait HasOutputCol extends Params {
       
      @@ -202,22 +218,22 @@ private[ml] trait HasOutputCol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param checkpointInterval.
      + * Trait for shared param checkpointInterval.
        */
       private[ml] trait HasCheckpointInterval extends Params {
       
         /**
      -   * Param for checkpoint interval (>= 1).
      +   * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations..
          * @group param
          */
      -  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
      +  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1))
       
         /** @group getParam */
         final def getCheckpointInterval: Int = $(checkpointInterval)
       }
       
       /**
      - * (private[ml]) Trait for shared param fitIntercept (default: true).
      + * Trait for shared param fitIntercept (default: true).
        */
       private[ml] trait HasFitIntercept extends Params {
       
      @@ -234,7 +250,39 @@ private[ml] trait HasFitIntercept extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
      + * Trait for shared param handleInvalid.
      + */
      +private[ml] trait HasHandleInvalid extends Params {
      +
      +  /**
+   * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later..
      +   * @group param
      +   */
+  final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", ParamValidators.inArray(Array("skip", "error")))
      +
      +  /** @group getParam */
      +  final def getHandleInvalid: String = $(handleInvalid)
      +}
      +
      +/**
      + * Trait for shared param standardization (default: true).
      + */
      +private[ml] trait HasStandardization extends Params {
      +
      +  /**
      +   * Param for whether to standardize the training features before fitting the model.
      +   * @group param
      +   */
      +  final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model")
      +
      +  setDefault(standardization, true)
      +
      +  /** @group getParam */
      +  final def getStandardization: Boolean = $(standardization)
      +}
      +
      +/**
      + * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
        */
       private[ml] trait HasSeed extends Params {
       
      @@ -251,22 +299,22 @@ private[ml] trait HasSeed extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param elasticNetParam.
      + * Trait for shared param elasticNetParam.
        */
       private[ml] trait HasElasticNetParam extends Params {
       
         /**
      -   * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty..
      +   * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
          * @group param
          */
      -  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
      +  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1))
       
         /** @group getParam */
         final def getElasticNetParam: Double = $(elasticNetParam)
       }
       
       /**
      - * (private[ml]) Trait for shared param tol.
      + * Trait for shared param tol.
        */
       private[ml] trait HasTol extends Params {
       
      @@ -281,7 +329,7 @@ private[ml] trait HasTol extends Params {
       }
       
       /**
      - * (private[ml]) Trait for shared param stepSize.
      + * Trait for shared param stepSize.
        */
       private[ml] trait HasStepSize extends Params {
       
      @@ -294,4 +342,19 @@ private[ml] trait HasStepSize extends Params {
         /** @group getParam */
         final def getStepSize: Double = $(stepSize)
       }
      +
      +/**
      + * Trait for shared param weightCol.
      + */
      +private[ml] trait HasWeightCol extends Params {
      +
      +  /**
      +   * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0..
      +   * @group param
      +   */
      +  final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.")
      +
      +  /** @group getParam */
      +  final def getWeightCol: String = $(weightCol)
      +}
       // scalastyle:on
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
      new file mode 100644
      index 000000000000..f5a022c31ed9
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
      @@ -0,0 +1,70 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.api.r
      +
      +import org.apache.spark.ml.attribute._
      +import org.apache.spark.ml.feature.RFormula
      +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
      +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
      +import org.apache.spark.ml.{Pipeline, PipelineModel}
      +import org.apache.spark.sql.DataFrame
      +
      +private[r] object SparkRWrappers {
      +  def fitRModelFormula(
      +      value: String,
      +      df: DataFrame,
      +      family: String,
      +      lambda: Double,
      +      alpha: Double): PipelineModel = {
      +    val formula = new RFormula().setFormula(value)
      +    val estimator = family match {
      +      case "gaussian" => new LinearRegression()
      +        .setRegParam(lambda)
      +        .setElasticNetParam(alpha)
      +        .setFitIntercept(formula.hasIntercept)
      +      case "binomial" => new LogisticRegression()
      +        .setRegParam(lambda)
      +        .setElasticNetParam(alpha)
      +        .setFitIntercept(formula.hasIntercept)
      +    }
      +    val pipeline = new Pipeline().setStages(Array(formula, estimator))
      +    pipeline.fit(df)
      +  }
      +
      +  def getModelWeights(model: PipelineModel): Array[Double] = {
      +    model.stages.last match {
      +      case m: LinearRegressionModel =>
      +        Array(m.intercept) ++ m.weights.toArray
      +      case _: LogisticRegressionModel =>
      +        throw new UnsupportedOperationException(
      +          "No weights available for LogisticRegressionModel")  // SPARK-9492
      +    }
      +  }
      +
      +  def getModelFeatures(model: PipelineModel): Array[String] = {
      +    model.stages.last match {
      +      case m: LinearRegressionModel =>
      +        val attrs = AttributeGroup.fromStructField(
      +          m.summary.predictions.schema(m.summary.featuresCol))
      +        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
      +      case _: LogisticRegressionModel =>
      +        throw new UnsupportedOperationException(
      +          "No features names available for LogisticRegressionModel")  // SPARK-9492
      +    }
      +  }
      +}
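
For reference, the gaussian branch of fitRModelFormula corresponds roughly to the following direct Scala usage; `training` is an assumed DataFrame whose columns match the formula, and the parameter values are placeholders:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.RFormula
    import org.apache.spark.ml.regression.LinearRegression

    val formula = new RFormula().setFormula("y ~ x1 + x2")
    val lr = new LinearRegression()
      .setRegParam(0.0)                       // lambda
      .setElasticNetParam(0.0)                // alpha
      .setFitIntercept(formula.hasIntercept)
    // RFormula produces the default "features"/"label" columns the estimator expects.
    val model = new Pipeline().setStages(Array(formula, lr)).fit(training)
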
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
      index 2e44cd4cc6a2..7db8ad8d2791 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
      @@ -219,7 +219,7 @@ class ALSModel private[ml] (
       
         override def copy(extra: ParamMap): ALSModel = {
           val copied = new ALSModel(uid, rank, userFactors, itemFactors)
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
         }
       }
       
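
The one-line ALS change makes copy() preserve the link back to the estimator that produced the model. A small sketch of the expected behavior, assuming an ALS instance `als` and a ratings DataFrame `ratings`:

    import org.apache.spark.ml.param.ParamMap

    val model = als.fit(ratings)
    val copied = model.copy(ParamMap.empty)
    // After the fix the copy still points at the original estimator.
    assert(copied.parent.eq(model.parent))
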
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
      index be1f8063d41d..d9a244bea28d 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
      @@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
       import org.apache.spark.ml.{PredictionModel, Predictor}
       import org.apache.spark.ml.param.ParamMap
       import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
      +import org.apache.spark.ml.tree.impl.RandomForest
       import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.regression.LabeledPoint
      -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
       import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
       import org.apache.spark.rdd.RDD
      @@ -67,8 +67,9 @@ final class DecisionTreeRegressor(override val uid: String)
             MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
           val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
           val strategy = getOldStrategy(categoricalFeatures)
      -    val oldModel = OldDecisionTree.train(oldDataset, strategy)
      -    DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures)
      +    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
      +      seed = 0L, parentUID = Some(uid))
      +    trees.head.asInstanceOf[DecisionTreeRegressionModel]
         }
       
         /** (private[ml]) Create a Strategy instance to use with the old API. */
      @@ -102,16 +103,22 @@ final class DecisionTreeRegressionModel private[ml] (
         require(rootNode != null,
           "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
       
      +  /**
      +   * Construct a decision tree regression model.
      +   * @param rootNode  Root node of tree, with other nodes attached.
      +   */
      +  private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
      +
         override protected def predict(features: Vector): Double = {
      -    rootNode.predict(features)
      +    rootNode.predictImpl(features).prediction
         }
       
         override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
      -    copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra)
      +    copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent)
         }
       
         override def toString: String = {
      -    s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes"
      +    s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes"
         }
       
         /** Convert to a model in the old API */
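
With the change above, a single-tree regressor is trained through the new ml-side RandomForest implementation (one tree, all features considered at each split). A usage sketch, assuming a DataFrame `training` with "label" and "features" columns:

    import org.apache.spark.ml.regression.DecisionTreeRegressor

    val dt = new DecisionTreeRegressor().setMaxDepth(5)
    val dtModel = dt.fit(training)
    // The new toString includes the uid, e.g.
    // "DecisionTreeRegressionModel (uid=dtr_...) of depth 5 with N nodes".
    println(dtModel)
    dtModel.transform(training).select("features", "prediction").show(5)
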
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
      index 036e3acb0741..d841ecb9e58d 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
      @@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
       import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.functions._
      +import org.apache.spark.sql.types.DoubleType
       
       /**
        * :: Experimental ::
      @@ -167,21 +169,27 @@ final class GBTRegressionModel(
       
         override def treeWeights: Array[Double] = _treeWeights
       
      +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
      +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
      +    val predictUDF = udf { (features: Any) =>
      +      bcastModel.value.predict(features.asInstanceOf[Vector])
      +    }
      +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      +  }
      +
         override protected def predict(features: Vector): Double = {
      -    // TODO: Override transform() to broadcast model. SPARK-7127
           // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
           // Classifies by thresholding sum of weighted tree predictions
      -    val treePredictions = _trees.map(_.rootNode.predict(features))
      -    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
      -    if (prediction > 0.0) 1.0 else 0.0
      +    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
      +    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
         }
       
         override def copy(extra: ParamMap): GBTRegressionModel = {
      -    copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra)
      +    copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent)
         }
       
         override def toString: String = {
      -    s"GBTRegressionModel with $numTrees trees"
      +    s"GBTRegressionModel (uid=$uid) with $numTrees trees"
         }
       
         /** (private[ml]) Convert to a model in the old API */
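
transformImpl above broadcasts the whole ensemble once and scores rows through a UDF, rather than serializing the model into every task closure. The same pattern, reduced to a generic sketch (`scorer` stands in for the model's predict method; `df` is an assumed DataFrame with a "features" column):

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, udf}

    def withPredictions(df: DataFrame, scorer: Vector => Double): DataFrame = {
      // Ship one copy of the (possibly large) scorer per executor, not per task.
      val bcast = df.sqlContext.sparkContext.broadcast(scorer)
      val predictUDF = udf { features: Vector => bcast.value(features) }
      df.withColumn("prediction", predictUDF(col("features")))
    }
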
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
      new file mode 100644
      index 000000000000..2ff500f291ab
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
      @@ -0,0 +1,223 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.regression
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.{Estimator, Model}
      +import org.apache.spark.ml.param._
      +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol}
      +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
      +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
      +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel}
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.sql.{DataFrame, Row}
      +import org.apache.spark.sql.functions.{col, lit, udf}
      +import org.apache.spark.sql.types.{DoubleType, StructType}
      +import org.apache.spark.storage.StorageLevel
      +
      +/**
      + * Params for isotonic regression.
      + */
      +private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol
      +  with HasLabelCol with HasPredictionCol with HasWeightCol with Logging {
      +
      +  /**
      +   * Param for whether the output sequence should be isotonic/increasing (true) or
      +   * antitonic/decreasing (false).
      +   * Default: true
      +   * @group param
      +   */
      +  final val isotonic: BooleanParam =
      +    new BooleanParam(this, "isotonic",
      +      "whether the output sequence should be isotonic/increasing (true) or" +
      +        "antitonic/decreasing (false)")
      +
      +  /** @group getParam */
      +  final def getIsotonic: Boolean = $(isotonic)
      +
      +  /**
      +   * Param for the index of the feature if [[featuresCol]] is a vector column (default: `0`), no
      +   * effect otherwise.
      +   * @group param
      +   */
      +  final val featureIndex: IntParam = new IntParam(this, "featureIndex",
      +    "The index of the feature if featuresCol is a vector column, no effect otherwise.")
      +
      +  /** @group getParam */
      +  final def getFeatureIndex: Int = $(featureIndex)
      +
      +  setDefault(isotonic -> true, featureIndex -> 0)
      +
      +  /** Checks whether the input has weight column. */
      +  protected[ml] def hasWeightCol: Boolean = {
      +    isDefined(weightCol) && $(weightCol) != ""
      +  }
      +
      +  /**
      +   * Extracts (label, feature, weight) from input dataset.
      +   */
      +  protected[ml] def extractWeightedLabeledPoints(
      +      dataset: DataFrame): RDD[(Double, Double, Double)] = {
      +    val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
      +      val idx = $(featureIndex)
      +      val extract = udf { v: Vector => v(idx) }
      +      extract(col($(featuresCol)))
      +    } else {
      +      col($(featuresCol))
      +    }
      +    val w = if (hasWeightCol) {
      +      col($(weightCol))
      +    } else {
      +      lit(1.0)
      +    }
      +    dataset.select(col($(labelCol)), f, w)
      +      .map { case Row(label: Double, feature: Double, weights: Double) =>
      +      (label, feature, weights)
      +    }
      +  }
      +
      +  /**
      +   * Validates and transforms input schema.
      +   * @param schema input schema
      +   * @param fitting whether this is in fitting or prediction
      +   * @return output schema
      +   */
      +  protected[ml] def validateAndTransformSchema(
      +      schema: StructType,
      +      fitting: Boolean): StructType = {
      +    if (fitting) {
      +      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
      +      if (hasWeightCol) {
      +        SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
      +      } else {
      +        logInfo("The weight column is not defined. Treat all instance weights as 1.0.")
      +      }
      +    }
      +    val featuresType = schema($(featuresCol)).dataType
      +    require(featuresType == DoubleType || featuresType.isInstanceOf[VectorUDT])
      +    SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * Isotonic regression.
      + *
      + * Currently implemented using the parallelized pool adjacent violators algorithm.
      + * Only the univariate (single feature) algorithm is supported.
      + *
      + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]].
      + */
      +@Experimental
      +class IsotonicRegression(override val uid: String) extends Estimator[IsotonicRegressionModel]
      +  with IsotonicRegressionBase {
      +
      +  def this() = this(Identifiable.randomUID("isoReg"))
      +
      +  /** @group setParam */
      +  def setLabelCol(value: String): this.type = set(labelCol, value)
      +
      +  /** @group setParam */
      +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
      +
      +  /** @group setParam */
      +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
      +
      +  /** @group setParam */
      +  def setIsotonic(value: Boolean): this.type = set(isotonic, value)
      +
      +  /** @group setParam */
      +  def setWeightCol(value: String): this.type = set(weightCol, value)
      +
      +  /** @group setParam */
      +  def setFeatureIndex(value: Int): this.type = set(featureIndex, value)
      +
      +  override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
      +
      +  override def fit(dataset: DataFrame): IsotonicRegressionModel = {
      +    validateAndTransformSchema(dataset.schema, fitting = true)
      +    // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
      +    val instances = extractWeightedLabeledPoints(dataset)
      +    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
      +    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
      +
      +    val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
      +    val oldModel = isotonicRegression.run(instances)
      +
      +    copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema, fitting = true)
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * Model fitted by IsotonicRegression.
      + * Predicts using a piecewise linear function.
      + *
      + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]].
      + *
      + * @param oldModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]]
      + *                 model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]].
      + */
      +@Experimental
      +class IsotonicRegressionModel private[ml] (
      +    override val uid: String,
      +    private val oldModel: MLlibIsotonicRegressionModel)
      +  extends Model[IsotonicRegressionModel] with IsotonicRegressionBase {
      +
      +  /** @group setParam */
      +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
      +
      +  /** @group setParam */
      +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
      +
      +  /** @group setParam */
      +  def setFeatureIndex(value: Int): this.type = set(featureIndex, value)
      +
      +  /** Boundaries in increasing order for which predictions are known. */
      +  def boundaries: Vector = Vectors.dense(oldModel.boundaries)
      +
      +  /**
      +   * Predictions associated with the boundaries at the same index, monotone because of isotonic
      +   * regression.
      +   */
      +  def predictions: Vector = Vectors.dense(oldModel.predictions)
      +
      +  override def copy(extra: ParamMap): IsotonicRegressionModel = {
      +    copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
      +  }
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    val predict = dataset.schema($(featuresCol)).dataType match {
      +      case DoubleType =>
      +        udf { feature: Double => oldModel.predict(feature) }
      +      case _: VectorUDT =>
      +        val idx = $(featureIndex)
      +        udf { features: Vector => oldModel.predict(features(idx)) }
      +    }
      +    dataset.withColumn($(predictionCol), predict(col($(featuresCol))))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    validateAndTransformSchema(schema, fitting = false)
      +  }
      +}
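
A usage sketch for the new estimator, assuming a SQLContext `sqlContext`; labels, features and weights here are plain double columns, and the weight column exercises the new HasWeightCol param:

    import org.apache.spark.ml.regression.IsotonicRegression

    val data = sqlContext.createDataFrame(Seq(
      (1.0, 1.0, 1.0), (3.0, 2.0, 1.0), (2.0, 3.0, 1.0), (4.0, 4.0, 1.0)
    )).toDF("label", "features", "weight")

    val ir = new IsotonicRegression()
      .setIsotonic(true)
      .setWeightCol("weight")
    val irModel = ir.fit(data)
    println(irModel.boundaries)    // increasing boundary points
    println(irModel.predictions)   // monotone predictions at those boundaries
    irModel.transform(data).show()
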
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
      index 01306545fc7c..e4602d36ccc8 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
      @@ -22,18 +22,21 @@ import scala.collection.mutable
       import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
       import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
       
      -import org.apache.spark.Logging
      +import org.apache.spark.{Logging, SparkException}
       import org.apache.spark.annotation.Experimental
       import org.apache.spark.ml.PredictorParams
       import org.apache.spark.ml.param.ParamMap
      -import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
      +import org.apache.spark.ml.param.shared._
       import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.mllib.evaluation.RegressionMetrics
       import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.linalg.BLAS._
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
       import org.apache.spark.rdd.RDD
      -import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.{DataFrame, Row}
      +import org.apache.spark.sql.functions.{col, udf}
      +import org.apache.spark.sql.types.StructField
       import org.apache.spark.storage.StorageLevel
       import org.apache.spark.util.StatCounter
       
      @@ -41,7 +44,8 @@ import org.apache.spark.util.StatCounter
        * Params for linear regression.
        */
       private[regression] trait LinearRegressionParams extends PredictorParams
      -  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
      +    with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
      +    with HasFitIntercept with HasStandardization
       
       /**
        * :: Experimental ::
      @@ -72,6 +76,26 @@ class LinearRegression(override val uid: String)
         def setRegParam(value: Double): this.type = set(regParam, value)
         setDefault(regParam -> 0.0)
       
      +  /**
      +   * Set if we should fit the intercept.
      +   * Default is true.
      +   * @group setParam
      +   */
      +  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
      +  setDefault(fitIntercept -> true)
      +
      +  /**
      +   * Whether to standardize the training features before fitting the model.
      +   * The coefficients of the models will always be returned on the original scale,
      +   * so it will be transparent for users. Note that with or without standardization,
      +   * the models should always converge to the same solution when no regularization
      +   * is applied. In R's GLMNET package, the default behavior is true as well.
      +   * Default is true.
      +   * @group setParam
      +   */
      +  def setStandardization(value: Boolean): this.type = set(standardization, value)
      +  setDefault(standardization -> true)
      +
         /**
          * Set the ElasticNet mixing parameter.
          * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
      @@ -130,7 +154,17 @@ class LinearRegression(override val uid: String)
             logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
               s"and the intercept will be the mean of the label; as a result, training is not needed.")
             if (handlePersistence) instances.unpersist()
      -      return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
      +      val weights = Vectors.sparse(numFeatures, Seq())
      +      val intercept = yMean
      +
      +      val model = new LinearRegressionModel(uid, weights, intercept)
      +      val trainingSummary = new LinearRegressionTrainingSummary(
      +        model.transform(dataset),
      +        $(predictionCol),
      +        $(labelCol),
      +        $(featuresCol),
      +        Array(0D))
      +      return copyValues(model.setSummary(trainingSummary))
           }
       
           val featuresMean = summarizer.mean.toArray
      @@ -142,31 +176,55 @@ class LinearRegression(override val uid: String)
           val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
           val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
       
      -    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
      -      featuresStd, featuresMean, effectiveL2RegParam)
      +    val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
      +      $(standardization), featuresStd, featuresMean, effectiveL2RegParam)
       
           val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
             new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
           } else {
      -      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol))
      +      def effectiveL1RegFun = (index: Int) => {
      +        if ($(standardization)) {
      +          effectiveL1RegParam
      +        } else {
      +          // If `standardization` is false, we still standardize the data
      +          // to improve the rate of convergence; as a result, we have to
      +          // perform this reverse standardization by penalizing each component
      +          // differently to get effectively the same objective function when
      +          // the training dataset is not standardized.
      +          if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
      +        }
      +      }
      +      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
           }
       
           val initialWeights = Vectors.zeros(numFeatures)
      -    val states =
      -      optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
      -
      -    var state = states.next()
      -    val lossHistory = mutable.ArrayBuilder.make[Double]
      -
      -    while (states.hasNext) {
      -      lossHistory += state.value
      -      state = states.next()
      -    }
      -    lossHistory += state.value
      +    val states = optimizer.iterations(new CachedDiffFunction(costFun),
      +      initialWeights.toBreeze.toDenseVector)
      +
      +    val (weights, objectiveHistory) = {
      +      /*
      +         Note that in Linear Regression, the objective history (loss + regularization) returned
      +         from optimizer is computed in the scaled space given by the following formula.
      +         {{{
      +         L = 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + regTerms
      +         }}}
      +       */
      +      val arrayBuilder = mutable.ArrayBuilder.make[Double]
      +      var state: optimizer.State = null
      +      while (states.hasNext) {
      +        state = states.next()
      +        arrayBuilder += state.adjustedValue
      +      }
      +      if (state == null) {
      +        val msg = s"${optimizer.getClass.getName} failed."
      +        logError(msg)
      +        throw new SparkException(msg)
      +      }
       
      -    // The weights are trained in the scaled space; we're converting them back to
      -    // the original space.
      -    val weights = {
      +      /*
      +         The weights are trained in the scaled space; we're converting them back to
      +         the original space.
      +       */
             val rawWeights = state.x.toArray.clone()
             var i = 0
             val len = rawWeights.length
      @@ -174,17 +232,27 @@ class LinearRegression(override val uid: String)
               rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
               i += 1
             }
      -      Vectors.dense(rawWeights)
      +
      +      (Vectors.dense(rawWeights).compressed, arrayBuilder.result())
           }
       
      -    // The intercept in R's GLMNET is computed using closed form after the coefficients are
      -    // converged. See the following discussion for detail.
      -    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
      -    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
      +    /*
      +       The intercept in R's GLMNET is computed using closed form after the coefficients are
      +       converged. See the following discussion for detail.
      +       http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
      +     */
      +    val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
      +
           if (handlePersistence) instances.unpersist()
       
      -    // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
      -    copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
      +    val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
      +    val trainingSummary = new LinearRegressionTrainingSummary(
      +      model.transform(dataset),
      +      $(predictionCol),
      +      $(labelCol),
      +      $(featuresCol),
      +      objectiveHistory)
      +    model.setSummary(trainingSummary)
         }
       
         override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
      @@ -202,13 +270,125 @@ class LinearRegressionModel private[ml] (
         extends RegressionModel[Vector, LinearRegressionModel]
         with LinearRegressionParams {
       
      +  private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
      +
      +  /**
      +   * Gets summary (e.g. residuals, mse, r-squared) of model on training set. An exception is
      +   * thrown if `trainingSummary == None`.
      +   */
      +  def summary: LinearRegressionTrainingSummary = trainingSummary match {
      +    case Some(summ) => summ
      +    case None =>
      +      throw new SparkException(
      +        "No training summary available for this LinearRegressionModel",
      +        new NullPointerException())
      +  }
      +
      +  private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
      +    this.trainingSummary = Some(summary)
      +    this
      +  }
      +
      +  /** Indicates whether a training summary exists for this model instance. */
      +  def hasSummary: Boolean = trainingSummary.isDefined
      +
      +  /**
      +   * Evaluates the model on a test set.
      +   * @param dataset Test dataset to evaluate model on.
      +   */
      +  // TODO: decide on a good name before exposing to public API
      +  private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
      +    val t = udf { features: Vector => predict(features) }
      +    val predictionAndObservations = dataset
      +      .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol)))
      +
      +    new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol))
      +  }
      +
         override protected def predict(features: Vector): Double = {
           dot(features, weights) + intercept
         }
       
         override def copy(extra: ParamMap): LinearRegressionModel = {
      -    copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
      +    val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
      +    if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
      +    newModel.setParent(parent)
      +  }
      +}
      +
      +/**
      + * :: Experimental ::
      + * Linear regression training results.
      + * @param predictions predictions outputted by the model's `transform` method.
      + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
      + */
      +@Experimental
      +class LinearRegressionTrainingSummary private[regression] (
      +    predictions: DataFrame,
      +    predictionCol: String,
      +    labelCol: String,
      +    val featuresCol: String,
      +    val objectiveHistory: Array[Double])
      +  extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
      +
      +  /** Number of training iterations until termination */
      +  val totalIterations = objectiveHistory.length
      +
      +}
      +
      +/**
      + * :: Experimental ::
      + * Linear regression results evaluated on a dataset.
      + * @param predictions predictions outputted by the model's `transform` method.
      + */
      +@Experimental
      +class LinearRegressionSummary private[regression] (
      +    @transient val predictions: DataFrame,
      +    val predictionCol: String,
      +    val labelCol: String) extends Serializable {
      +
      +  @transient private val metrics = new RegressionMetrics(
      +    predictions
      +      .select(predictionCol, labelCol)
      +      .map { case Row(pred: Double, label: Double) => (pred, label) } )
      +
      +  /**
      +   * Returns the explained variance regression score.
      +   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
      +   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
      +   */
      +  val explainedVariance: Double = metrics.explainedVariance
      +
      +  /**
      +   * Returns the mean absolute error, which is a risk function corresponding to the
      +   * expected value of the absolute error loss or l1-norm loss.
      +   */
      +  val meanAbsoluteError: Double = metrics.meanAbsoluteError
      +
      +  /**
      +   * Returns the mean squared error, which is a risk function corresponding to the
      +   * expected value of the squared error loss or quadratic loss.
      +   */
      +  val meanSquaredError: Double = metrics.meanSquaredError
      +
      +  /**
      +   * Returns the root mean squared error, which is defined as the square root of
      +   * the mean squared error.
      +   */
      +  val rootMeanSquaredError: Double = metrics.rootMeanSquaredError
      +
      +  /**
      +   * Returns R^2^, the coefficient of determination.
      +   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
      +   */
      +  val r2: Double = metrics.r2
      +
      +  /** Residuals (label - predicted value) */
      +  @transient lazy val residuals: DataFrame = {
      +    val t = udf { (pred: Double, label: Double) => label - pred }
      +    predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
         }
      +
       }
       
       /**
      @@ -234,6 +414,7 @@ class LinearRegressionModel private[ml] (
        * See this discussion for detail.
        * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
        *
      + * When training with the intercept enabled:
        * The objective function in the scaled space is given by
        * {{{
        * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
      @@ -241,6 +422,10 @@ class LinearRegressionModel private[ml] (
        * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
        * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
        *
      + * If fitting the intercept is disabled (that is, the intercept is forced through 0.0),
      + * we can use the same equation except that \bar{y} and \bar{x_i} are set to 0 instead
      + * of the respective means.
      + *
        * This can be rewritten as
        * {{{
        * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
      @@ -255,6 +440,7 @@ class LinearRegressionModel private[ml] (
        * \sum_i w_i^\prime x_i - y / \hat{y} + offset
        * }}}
        *
      + *
        * Note that the effective weights and offset don't depend on training dataset,
        * so they can be precomputed.
        *
      @@ -294,6 +480,7 @@ class LinearRegressionModel private[ml] (
        * @param weights The weights/coefficients corresponding to the features.
        * @param labelStd The standard deviation value of the label.
        * @param labelMean The mean value of the label.
      + * @param fitIntercept Whether to fit an intercept term.
        * @param featuresStd The standard deviation values of the features.
        * @param featuresMean The mean values of the features.
        */
      @@ -301,6 +488,7 @@ private class LeastSquaresAggregator(
           weights: Vector,
           labelStd: Double,
           labelMean: Double,
      +    fitIntercept: Boolean,
           featuresStd: Array[Double],
           featuresMean: Array[Double]) extends Serializable {
       
      @@ -321,7 +509,7 @@ private class LeastSquaresAggregator(
             }
             i += 1
           }
      -    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
      +    (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
         }
       
         private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
      @@ -404,6 +592,8 @@ private class LeastSquaresCostFun(
           data: RDD[(Double, Vector)],
           labelStd: Double,
           labelMean: Double,
      +    fitIntercept: Boolean,
      +    standardization: Boolean,
           featuresStd: Array[Double],
           featuresMean: Array[Double],
           effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
      @@ -412,7 +602,7 @@ private class LeastSquaresCostFun(
           val w = Vectors.fromBreeze(weights)
       
           val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
      -      labelMean, featuresStd, featuresMean))(
      +      labelMean, fitIntercept, featuresStd, featuresMean))(
               seqOp = (c, v) => (c, v) match {
                 case (aggregator, (label, features)) => aggregator.add(label, features)
               },
      @@ -420,14 +610,38 @@ private class LeastSquaresCostFun(
                 case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
               })
       
      -    // regVal is the sum of weight squares for L2 regularization
      -    val norm = brzNorm(weights, 2.0)
      -    val regVal = 0.5 * effectiveL2regParam * norm * norm
      +    val totalGradientArray = leastSquaresAggregator.gradient.toArray
       
      -    val loss = leastSquaresAggregator.loss + regVal
      -    val gradient = leastSquaresAggregator.gradient
      -    axpy(effectiveL2regParam, w, gradient)
      +    val regVal = if (effectiveL2regParam == 0.0) {
      +      0.0
      +    } else {
      +      var sum = 0.0
      +      w.foreachActive { (index, value) =>
      +        // The following code computes both the loss and the gradient of the
      +        // regularization term, and adds the gradient back to totalGradientArray.
      +        sum += {
      +          if (standardization) {
      +            totalGradientArray(index) += effectiveL2regParam * value
      +            value * value
      +          } else {
      +            if (featuresStd(index) != 0.0) {
      +              // If `standardization` is false, we still standardize the data
      +              // to improve the rate of convergence; as a result, we have to
      +              // perform this reverse standardization by penalizing each component
      +              // differently to get effectively the same objective function when
      +              // the training dataset is not standardized.
      +              val temp = value / (featuresStd(index) * featuresStd(index))
      +              totalGradientArray(index) += effectiveL2regParam * temp
      +              value * temp
      +            } else {
      +              0.0
      +            }
      +          }
      +        }
      +      }
      +      0.5 * effectiveL2regParam * sum
      +    }
       
      -    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
      +    (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray))
         }
       }
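
The main user-visible addition here is the training summary attached to the fitted model. A usage sketch, assuming a DataFrame `training` with "label" and "features" columns:

    import org.apache.spark.ml.regression.LinearRegression

    val lr = new LinearRegression()
      .setFitIntercept(true)
      .setStandardization(true)
      .setElasticNetParam(0.5)
      .setRegParam(0.1)
    val lrModel = lr.fit(training)

    val summary = lrModel.summary       // throws if no summary was attached
    println(summary.totalIterations)    // iterations until L-BFGS/OWL-QN terminated
    println(summary.rootMeanSquaredError)
    println(summary.r2)
    summary.residuals.show()            // label - prediction, one row per instance
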
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
      index 21c59061a02f..ddb7214416a6 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
      @@ -21,14 +21,16 @@ import org.apache.spark.annotation.Experimental
       import org.apache.spark.ml.{PredictionModel, Predictor}
       import org.apache.spark.ml.param.ParamMap
       import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
      +import org.apache.spark.ml.tree.impl.RandomForest
       import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.regression.LabeledPoint
      -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
       import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.functions._
      +
       
       /**
        * :: Experimental ::
      @@ -82,9 +84,11 @@ final class RandomForestRegressor(override val uid: String)
           val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
           val strategy =
             super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
      -    val oldModel = OldRandomForest.trainRegressor(
      -      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
      -    RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
      +    val trees =
      +      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
      +        .map(_.asInstanceOf[DecisionTreeRegressionModel])
      +    val numFeatures = oldDataset.first().features.size
      +    new RandomForestRegressionModel(trees, numFeatures)
         }
       
         override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
      @@ -105,16 +109,25 @@ object RandomForestRegressor {
        * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] model for regression.
        * It supports both continuous and categorical features.
        * @param _trees  Decision trees in the ensemble.
      + * @param numFeatures  Number of features used by this model
        */
       @Experimental
       final class RandomForestRegressionModel private[ml] (
           override val uid: String,
      -    private val _trees: Array[DecisionTreeRegressionModel])
      +    private val _trees: Array[DecisionTreeRegressionModel],
      +    val numFeatures: Int)
         extends PredictionModel[Vector, RandomForestRegressionModel]
         with TreeEnsembleModel with Serializable {
       
         require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
       
      +  /**
      +   * Construct a random forest regression model, with all trees weighted equally.
      +   * @param trees  Component trees
      +   */
      +  private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
      +    this(Identifiable.randomUID("rfr"), trees, numFeatures)
      +
         override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
       
         // Note: We may add support for weights (based on tree performance) later on.
      @@ -122,22 +135,46 @@ final class RandomForestRegressionModel private[ml] (
       
         override def treeWeights: Array[Double] = _treeWeights
       
      +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
      +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
      +    val predictUDF = udf { (features: Any) =>
      +      bcastModel.value.predict(features.asInstanceOf[Vector])
      +    }
      +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      +  }
      +
         override protected def predict(features: Vector): Double = {
      -    // TODO: Override transform() to broadcast model.  SPARK-7127
           // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
           // Predict average of tree predictions.
           // Ignore the weights since all are 1.0 for now.
      -    _trees.map(_.rootNode.predict(features)).sum / numTrees
      +    _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
         }
       
         override def copy(extra: ParamMap): RandomForestRegressionModel = {
      -    copyValues(new RandomForestRegressionModel(uid, _trees), extra)
      +    copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
         }
       
         override def toString: String = {
      -    s"RandomForestRegressionModel with $numTrees trees"
      +    s"RandomForestRegressionModel (uid=$uid) with $numTrees trees"
         }
       
      +  /**
      +   * Estimate of the importance of each feature.
      +   *
      +   * This generalizes the idea of "Gini" importance to other losses,
      +   * following the explanation of Gini importance from "Random Forests" documentation
      +   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
      +   *
      +   * This feature importance is calculated as follows:
      +   *  - Average over trees:
      +   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
      +   *       where gain is scaled by the number of instances passing through node
      +   *     - Normalize importances for tree based on total number of training instances used
      +   *       to build tree.
      +   *  - Normalize feature importance vector to sum to 1.
      +   */
      +  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
      +
         /** (private[ml]) Convert to a model in the old API */
         private[ml] def toOld: OldRandomForestModel = {
           new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
      @@ -157,6 +194,6 @@ private[ml] object RandomForestRegressionModel {
             // parent for each tree is null since there is no good way to set this.
             DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
           }
      -    new RandomForestRegressionModel(parent.uid, newTrees)
      +    new RandomForestRegressionModel(parent.uid, newTrees, -1)
         }
       }
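
A usage sketch for the new featureImportances member, assuming a DataFrame `training` with "label" and "features" columns:

    import org.apache.spark.ml.regression.RandomForestRegressor

    val rf = new RandomForestRegressor()
      .setNumTrees(20)
      .setFeatureSubsetStrategy("auto")
    val rfModel = rf.fit(training)
    // One entry per feature, normalized to sum to 1; larger values indicate features
    // whose splits contributed more gain across the ensemble.
    println(rfModel.featureImportances)
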
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
      new file mode 100644
      index 000000000000..1f627777fc68
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
      @@ -0,0 +1,116 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.source.libsvm
      +
      +import com.google.common.base.Objects
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.annotation.Since
      +import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
      +import org.apache.spark.mllib.util.MLUtils
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext}
      +import org.apache.spark.sql.sources._
      +import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
      +
      +/**
      + * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
      + * @param path File path of LibSVM format
      + * @param numFeatures The number of features
      + * @param vectorType The type of vector. It can be 'sparse' or 'dense'
      + * @param sqlContext The Spark SQLContext
      + */
      +private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
      +    (@transient val sqlContext: SQLContext)
      +  extends BaseRelation with TableScan with Logging with Serializable {
      +
      +  override def schema: StructType = StructType(
      +    StructField("label", DoubleType, nullable = false) ::
      +      StructField("features", new VectorUDT(), nullable = false) :: Nil
      +  )
      +
      +  override def buildScan(): RDD[Row] = {
      +    val sc = sqlContext.sparkContext
      +    val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
      +    val sparse = vectorType == "sparse"
      +    baseRdd.map { pt =>
      +      val features = if (sparse) pt.features.toSparse else pt.features.toDense
      +      Row(pt.label, features)
      +    }
      +  }
      +
      +  override def hashCode(): Int = {
      +    Objects.hashCode(path, Double.box(numFeatures), vectorType)
      +  }
      +
      +  override def equals(other: Any): Boolean = other match {
      +    case that: LibSVMRelation =>
      +      path == that.path &&
      +        numFeatures == that.numFeatures &&
      +        vectorType == that.vectorType
      +    case _ =>
      +      false
      +  }
      +}
      +
      +/**
      + * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
      + * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
      + * `features` containing feature vectors stored as [[Vector]]s.
      + *
      + * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and
      + * optionally specify options, for example:
      + * {{{
      + *   // Scala
      + *   val df = sqlContext.read.format("libsvm")
      + *     .option("numFeatures", "780")
      + *     .load("data/mllib/sample_libsvm_data.txt")
      + *
      + *   // Java
      + *   DataFrame df = sqlContext.read.format("libsvm")
      + *     .option("numFeatures, "780")
      + *     .load("data/mllib/sample_libsvm_data.txt");
      + * }}}
      + *
      + * LIBSVM data source supports the following options:
      + *  - "numFeatures": number of features.
      + *    If unspecified or nonpositive, the number of features will be determined automatically at the
      + *    cost of one additional pass.
      + *    This is also useful when the dataset is already split into multiple files and you want to load
      + *    them separately, because some features may not be present in certain files, which leads to
      + *    inconsistent feature dimensions.
      + *  - "vectorType": feature vector type, "sparse" (default) or "dense".
      + *
      + *  @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
      + */
      +@Since("1.6.0")
      +class DefaultSource extends RelationProvider with DataSourceRegister {
      +
      +  @Since("1.6.0")
      +  override def shortName(): String = "libsvm"
      +
      +  @Since("1.6.0")
      +  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
      +    : BaseRelation = {
      +    val path = parameters.getOrElse("path",
      +      throw new IllegalArgumentException("'path' must be specified"))
      +    val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
      +    val vectorType = parameters.getOrElse("vectorType", "sparse")
      +    new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
      +  }
      +}
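
A usage sketch for the new data source, assuming a SQLContext `sqlContext` and the hypothetical file paths below; pinning numFeatures keeps the feature dimension consistent when the data is split across files:

    val part1 = sqlContext.read.format("libsvm")
      .option("numFeatures", "780")
      .load("data/mllib/part-00000.txt")
    val part2 = sqlContext.read.format("libsvm")
      .option("numFeatures", "780")
      .load("data/mllib/part-00001.txt")
    val all = part1.unionAll(part2)
    all.printSchema()  // label: double, features: vector
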
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
      index 4242154be14c..cd2493129390 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
      @@ -19,8 +19,9 @@ package org.apache.spark.ml.tree
       
       import org.apache.spark.annotation.DeveloperApi
       import org.apache.spark.mllib.linalg.Vector
      +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
       import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
      -  Node => OldNode, Predict => OldPredict}
      +  Node => OldNode, Predict => OldPredict, ImpurityStats}
       
       /**
        * :: DeveloperApi ::
      @@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
         /** Impurity measure at this node (for training data) */
         def impurity: Double
       
      +  /**
      +   * Statistics aggregated from training data at this node, used to compute prediction, impurity,
      +   * and probabilities.
      +   * For classification, the array of class counts must be normalized to a probability distribution.
      +   */
      +  private[ml] def impurityStats: ImpurityCalculator
      +
         /** Recursive prediction helper method */
      -  private[ml] def predict(features: Vector): Double = prediction
      +  private[ml] def predictImpl(features: Vector): LeafNode
       
         /**
          * Get the number of nodes in tree below this node, including leaf nodes.
      @@ -64,6 +72,12 @@ sealed abstract class Node extends Serializable {
          * @param id  Node ID using old format IDs
          */
         private[ml] def toOld(id: Int): OldNode
      +
      +  /**
      +   * Trace down the tree, and return the largest feature index used in any split.
      +   * @return  Max feature index used in a split, or -1 if there are no splits (single leaf node).
      +   */
      +  private[ml] def maxSplitFeatureIndex(): Int
       }
       
       private[ml] object Node {
      @@ -75,7 +89,8 @@ private[ml] object Node {
           if (oldNode.isLeaf) {
             // TODO: Once the implementation has been moved to this API, then include sufficient
             //       statistics here.
      -      new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
      +      new LeafNode(prediction = oldNode.predict.predict,
      +        impurity = oldNode.impurity, impurityStats = null)
           } else {
             val gain = if (oldNode.stats.nonEmpty) {
               oldNode.stats.get.gain
      @@ -85,7 +100,7 @@ private[ml] object Node {
             new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
               gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
               rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
      -        split = Split.fromOld(oldNode.split.get, categoricalFeatures))
      +        split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
           }
         }
       }
      @@ -99,11 +114,13 @@ private[ml] object Node {
       @DeveloperApi
       final class LeafNode private[ml] (
           override val prediction: Double,
      -    override val impurity: Double) extends Node {
      +    override val impurity: Double,
      +    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
       
      -  override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
      +  override def toString: String =
      +    s"LeafNode(prediction = $prediction, impurity = $impurity)"
       
      -  override private[ml] def predict(features: Vector): Double = prediction
      +  override private[ml] def predictImpl(features: Vector): LeafNode = this
       
         override private[tree] def numDescendants: Int = 0
       
      @@ -115,10 +132,11 @@ final class LeafNode private[ml] (
         override private[tree] def subtreeDepth: Int = 0
       
         override private[ml] def toOld(id: Int): OldNode = {
      -    // NOTE: We do NOT store 'prob' in the new API currently.
      -    new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
      -      None, None, None, None)
      +    new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
      +      impurity, isLeaf = true, None, None, None, None)
         }
      +
      +  override private[ml] def maxSplitFeatureIndex(): Int = -1
       }
       
       /**
      @@ -139,17 +157,18 @@ final class InternalNode private[ml] (
           val gain: Double,
           val leftChild: Node,
           val rightChild: Node,
      -    val split: Split) extends Node {
      +    val split: Split,
      +    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
       
         override def toString: String = {
           s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
         }
       
      -  override private[ml] def predict(features: Vector): Double = {
      +  override private[ml] def predictImpl(features: Vector): LeafNode = {
           if (split.shouldGoLeft(features)) {
      -      leftChild.predict(features)
      +      leftChild.predictImpl(features)
           } else {
      -      rightChild.predict(features)
      +      rightChild.predictImpl(features)
           }
         }
       
      @@ -172,14 +191,18 @@ final class InternalNode private[ml] (
         override private[ml] def toOld(id: Int): OldNode = {
           assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
             + " since the old API does not support deep trees.")
      -    // NOTE: We do NOT store 'prob' in the new API currently.
      -    new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
      -      Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
      +    new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
      +      isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
             Some(rightChild.toOld(OldNode.rightChildIndex(id))),
             Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
               new OldPredict(leftChild.prediction, prob = 0.0),
               new OldPredict(rightChild.prediction, prob = 0.0))))
         }
      +
      +  override private[ml] def maxSplitFeatureIndex(): Int = {
      +    math.max(split.featureIndex,
      +      math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
      +  }
       }
       
       private object InternalNode {
      @@ -209,3 +232,130 @@ private object InternalNode {
           }
         }
       }
      +
      +/**
      + * Version of a node used in learning.  This uses vars so that we can modify nodes as we split the
      + * tree by adding children, etc.
      + *
      + * For now, we use node IDs.  These will be kept internal since we hope to remove node IDs
      + * in the future, or at least change the indexing (so that we can support much deeper trees).
      + *
      + * This node can either be:
      + *  - a leaf node, with leftChild, rightChild, split set to null, or
      + *  - an internal node, with all values set
      + *
      + * @param id  We currently use the same indexing as the old implementation in
      + *            [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
      + * @param isLeaf  Indicates whether this node will definitely be a leaf in the learned tree,
      + *                so that we do not need to consider splitting it further.
      + * @param stats  Impurity statistics for this node.
      + */
      +private[tree] class LearningNode(
      +    var id: Int,
      +    var leftChild: Option[LearningNode],
      +    var rightChild: Option[LearningNode],
      +    var split: Option[Split],
      +    var isLeaf: Boolean,
      +    var stats: ImpurityStats) extends Serializable {
      +
      +  /**
      +   * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
      +   */
      +  def toNode: Node = {
      +    if (leftChild.nonEmpty) {
      +      assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
      +        "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
      +      new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
      +        leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
      +    } else {
      +      if (stats.valid) {
      +        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
      +          stats.impurityCalculator)
      +      } else {
+        // Here we want to keep the same behavior as the old mllib.DecisionTreeModel
+        new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
+      }
+    }
      +  }
      +
      +}
      +
      +private[tree] object LearningNode {
      +
      +  /** Create a node with some of its fields set. */
      +  def apply(
      +      id: Int,
      +      isLeaf: Boolean,
      +      stats: ImpurityStats): LearningNode = {
      +    new LearningNode(id, None, None, None, false, stats)
      +  }
      +
      +  /** Create an empty node with the given node index.  Values must be set later on. */
      +  def emptyNode(nodeIndex: Int): LearningNode = {
      +    new LearningNode(nodeIndex, None, None, None, false, null)
      +  }
      +
      +  // The below indexing methods were copied from spark.mllib.tree.model.Node
      +
      +  /**
      +   * Return the index of the left child of this node.
      +   */
      +  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
      +
      +  /**
      +   * Return the index of the right child of this node.
      +   */
      +  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
      +
      +  /**
      +   * Get the parent index of the given node, or 0 if it is the root.
      +   */
      +  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
      +
      +  /**
+   * Return the level of the tree that the given node is in.
      +   */
      +  def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+    throw new IllegalArgumentException("0 is not a valid node index.")
      +  } else {
      +    java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
      +  }
      +
      +  /**
      +   * Returns true if this is a left child.
      +   * Note: Returns false for the root.
      +   */
      +  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
      +
      +  /**
      +   * Return the maximum number of nodes which can be in the given level of the tree.
      +   * @param level  Level of tree (0 = root).
      +   */
      +  def maxNodesInLevel(level: Int): Int = 1 << level
      +
      +  /**
      +   * Return the index of the first node in the given level.
      +   * @param level  Level of tree (0 = root).
      +   */
      +  def startIndexInLevel(level: Int): Int = 1 << level
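+  // Worked example of this 1-based indexing (illustrative only): with the root at index 1,
+  // leftChildIndex(3) = 6, rightChildIndex(3) = 7, parentIndex(6) = 3, indexToLevel(6) = 2,
+  // isLeftChild(6) = true, and maxNodesInLevel(2) = startIndexInLevel(2) = 4.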
      +
      +  /**
      +   * Traces down from a root node to get the node with the given node index.
      +   * This assumes the node exists.
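+   * For example, getNode(6, root) walks right then left (6 = binary 110) and returns the
+   * left child of the root's right child.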
      +   */
      +  def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
      +    var tmpNode: LearningNode = rootNode
      +    var levelsToGo = indexToLevel(nodeIndex)
      +    while (levelsToGo > 0) {
      +      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+        tmpNode = tmpNode.leftChild.get
+      } else {
+        tmpNode = tmpNode.rightChild.get
      +      }
      +      levelsToGo -= 1
      +    }
      +    tmpNode
      +  }
      +
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
      index 7acdeeee72d2..78199cc2df58 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
      @@ -34,9 +34,19 @@ sealed trait Split extends Serializable {
         /** Index of feature which this split tests */
         def featureIndex: Int
       
      -  /** Return true (split to left) or false (split to right) */
      +  /**
      +   * Return true (split to left) or false (split to right).
      +   * @param features  Vector of features (original values, not binned).
      +   */
         private[ml] def shouldGoLeft(features: Vector): Boolean
       
      +  /**
      +   * Return true (split to left) or false (split to right).
      +   * @param binnedFeature Binned feature value.
      +   * @param splits All splits for the given feature.
      +   */
      +  private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean
      +
         /** Convert to old Split format */
         private[tree] def toOld: OldSplit
       }
      @@ -94,6 +104,14 @@ final class CategoricalSplit private[ml] (
           }
         }
       
      +  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
      +    if (isLeft) {
      +      categories.contains(binnedFeature.toDouble)
      +    } else {
      +      !categories.contains(binnedFeature.toDouble)
      +    }
      +  }
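+  // Illustrative example: for a categorical feature the bin index is the category value itself,
+  // so with isLeft = true and categories containing {0.0, 2.0}, binnedFeature = 2 goes left
+  // while binnedFeature = 1 goes right.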
      +
         override def equals(o: Any): Boolean = {
           o match {
             case other: CategoricalSplit => featureIndex == other.featureIndex &&
      @@ -144,6 +162,16 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
           features(featureIndex) <= threshold
         }
       
      +  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
      +    if (binnedFeature == splits.length) {
      +      // > last split, so split right
      +      false
      +    } else {
      +      val featureValueUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold
      +      featureValueUpperBound <= threshold
      +    }
      +  }
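+  // Illustrative example: with split thresholds [1.0, 2.0, 3.0] and this.threshold = 2.0,
+  // binnedFeature = 1 (raw value <= 2.0) goes left, while binnedFeature = 3 (beyond the
+  // last split) goes right.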
      +
         override def equals(o: Any): Boolean = {
           o match {
             case other: ContinuousSplit =>
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
      new file mode 100644
      index 000000000000..488e8e4fb5dc
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
      @@ -0,0 +1,194 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.tree.impl
      +
      +import java.io.IOException
      +
      +import scala.collection.mutable
      +
+import org.apache.hadoop.fs.{FileSystem, Path}
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.ml.tree.{LearningNode, Split}
      +import org.apache.spark.mllib.tree.impl.BaggedPoint
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.storage.StorageLevel
      +
      +
      +/**
      + * This is used by the node id cache to find the child id that a data point would belong to.
      + * @param split Split information.
      + * @param nodeIndex The current node index of a data point that this will update.
      + */
      +private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {
      +
      +  /**
      +   * Determine a child node index based on the feature value and the split.
      +   * @param binnedFeature Binned feature value.
      +   * @param splits Split information to convert the bin indices to approximate feature values.
      +   * @return Child node index to update to.
      +   */
      +  def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
      +    if (split.shouldGoLeft(binnedFeature, splits)) {
      +      LearningNode.leftChildIndex(nodeIndex)
      +    } else {
      +      LearningNode.rightChildIndex(nodeIndex)
      +    }
      +  }
      +}
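+// Illustrative example: a point currently at node 5 is routed by NodeIndexUpdater(split, 5)
+// to node 10 (left) if split.shouldGoLeft(binnedFeature, splits) is true, or to node 11 (right).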
      +
      +/**
      + * Each TreePoint belongs to a particular node per tree.
      + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, all values should be 1, since every instance starts at its tree's root.
      + * The nodeIdsForInstances RDD needs to be updated at each iteration.
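+ * For example (illustrative), with 3 trees a row starts as Array(1, 1, 1); once the root of
+ * tree 0 has been split, that row becomes Array(2, 1, 1) or Array(3, 1, 1).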
      + * @param nodeIdsForInstances The initial values in the cache
+ *                           (should be an Array of all 1's, meaning the root nodes).
      + * @param checkpointInterval The checkpointing interval
+ *                           (how often the cache should be checkpointed).
      + */
      +private[spark] class NodeIdCache(
      +  var nodeIdsForInstances: RDD[Array[Int]],
      +  val checkpointInterval: Int) extends Logging {
      +
+  // Keep a reference to the previous node Ids for instances.
      +  // Because we will keep on re-persisting updated node Ids,
      +  // we want to unpersist the previous RDD.
      +  private var prevNodeIdsForInstances: RDD[Array[Int]] = null
      +
      +  // To keep track of the past checkpointed RDDs.
      +  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
      +  private var rddUpdateCount = 0
      +
      +  // Indicates whether we can checkpoint
      +  private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty
      +
      +  // FileSystem instance for deleting checkpoints as needed
      +  private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration)
      +
      +  /**
      +   * Update the node index values in the cache.
      +   * This updates the RDD and its lineage.
      +   * TODO: Passing bin information to executors seems unnecessary and costly.
      +   * @param data The RDD of training rows.
      +   * @param nodeIdUpdaters A map of node index updaters.
+   *                       The keys are the indices of the nodes that we want to update.
      +   * @param splits  Split information needed to find child node indices.
      +   */
      +  def updateNodeIndices(
      +      data: RDD[BaggedPoint[TreePoint]],
      +      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
      +      splits: Array[Array[Split]]): Unit = {
      +    if (prevNodeIdsForInstances != null) {
      +      // Unpersist the previous one if one exists.
      +      prevNodeIdsForInstances.unpersist()
      +    }
      +
      +    prevNodeIdsForInstances = nodeIdsForInstances
      +    nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
      +      var treeId = 0
      +      while (treeId < nodeIdUpdaters.length) {
      +        val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
      +        if (nodeIdUpdater != null) {
      +          val featureIndex = nodeIdUpdater.split.featureIndex
      +          val newNodeIndex = nodeIdUpdater.updateNodeIndex(
      +            binnedFeature = point.datum.binnedFeatures(featureIndex),
      +            splits = splits(featureIndex))
      +          ids(treeId) = newNodeIndex
      +        }
      +        treeId += 1
      +      }
      +      ids
      +    }
      +
      +    // Keep on persisting new ones.
      +    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
      +    rddUpdateCount += 1
      +
+    // Handle checkpointing if a checkpoint directory has been set.
      +    if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) {
      +      // Let's see if we can delete previous checkpoints.
      +      var canDelete = true
      +      while (checkpointQueue.size > 1 && canDelete) {
      +        // We can delete the oldest checkpoint iff
      +        // the next checkpoint actually exists in the file system.
      +        if (checkpointQueue(1).getCheckpointFile.isDefined) {
      +          val old = checkpointQueue.dequeue()
      +          // Since the old checkpoint is not deleted by Spark, we'll manually delete it here.
      +          try {
      +            fs.delete(new Path(old.getCheckpointFile.get), true)
      +          } catch {
      +            case e: IOException =>
      +              logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
      +                s" file: ${old.getCheckpointFile.get}")
      +          }
      +        } else {
      +          canDelete = false
      +        }
      +      }
      +
      +      nodeIdsForInstances.checkpoint()
      +      checkpointQueue.enqueue(nodeIdsForInstances)
      +    }
      +  }
      +
      +  /**
      +   * Call this after training is finished to delete any remaining checkpoints.
      +   */
      +  def deleteAllCheckpoints(): Unit = {
      +    while (checkpointQueue.nonEmpty) {
      +      val old = checkpointQueue.dequeue()
      +      if (old.getCheckpointFile.isDefined) {
      +        try {
      +          fs.delete(new Path(old.getCheckpointFile.get), true)
      +        } catch {
      +          case e: IOException =>
      +            logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
      +              s" file: ${old.getCheckpointFile.get}")
      +        }
      +      }
      +    }
+    if (prevNodeIdsForInstances != null) {
+      // Unpersist the previous one if one exists.
+      prevNodeIdsForInstances.unpersist()
+    }
+  }
      +}
      +
      +@DeveloperApi
      +private[spark] object NodeIdCache {
      +  /**
      +   * Initialize the node Id cache with initial node Id values.
      +   * @param data The RDD of training rows.
+   * @param numTrees The number of trees for which we want to create the cache.
+   * @param checkpointInterval The checkpointing interval
+   *                           (how often the cache should be checkpointed).
+   * @param initVal The initial values in the cache.
+   * @return A node Id cache containing an RDD of initial root node indices.
      +   */
      +  def init(
      +      data: RDD[BaggedPoint[TreePoint]],
      +      numTrees: Int,
      +      checkpointInterval: Int,
      +      initVal: Int = 1): NodeIdCache = {
      +    new NodeIdCache(
      +      data.map(_ => Array.fill[Int](numTrees)(initVal)),
      +      checkpointInterval)
      +  }
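+  // Typical use from RandomForest.run (illustrative sketch):
+  //   val nodeIdCache = NodeIdCache.init(baggedInput, numTrees, strategy.checkpointInterval)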
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
      new file mode 100644
      index 000000000000..4ac51a475474
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
      @@ -0,0 +1,1208 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.tree.impl
      +
      +import java.io.IOException
      +
      +import scala.collection.mutable
      +import scala.util.Random
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.ml.classification.DecisionTreeClassificationModel
      +import org.apache.spark.ml.regression.DecisionTreeRegressionModel
      +import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
      +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
      +  TimeTracker}
      +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
      +import org.apache.spark.mllib.tree.model.ImpurityStats
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.storage.StorageLevel
      +import org.apache.spark.util.collection.OpenHashMap
      +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
      +
      +
      +private[ml] object RandomForest extends Logging {
      +
      +  /**
      +   * Train a random forest.
      +   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
      +   * @return an unweighted set of trees
      +   */
      +  def run(
      +      input: RDD[LabeledPoint],
      +      strategy: OldStrategy,
      +      numTrees: Int,
      +      featureSubsetStrategy: String,
      +      seed: Long,
      +      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
      +
      +    val timer = new TimeTracker()
      +
      +    timer.start("total")
      +
      +    timer.start("init")
      +
      +    val retaggedInput = input.retag(classOf[LabeledPoint])
      +    val metadata =
      +      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
      +    logDebug("algo = " + strategy.algo)
      +    logDebug("numTrees = " + numTrees)
      +    logDebug("seed = " + seed)
      +    logDebug("maxBins = " + metadata.maxBins)
      +    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
      +    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
      +    logDebug("subsamplingRate = " + strategy.subsamplingRate)
      +
      +    // Find the splits and the corresponding bins (interval between the splits) using a sample
      +    // of the input data.
      +    timer.start("findSplitsBins")
      +    val splits = findSplits(retaggedInput, metadata)
      +    timer.stop("findSplitsBins")
      +    logDebug("numBins: feature: number of bins")
      +    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
      +      s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
      +    }.mkString("\n"))
      +
      +    // Bin feature values (TreePoint representation).
      +    // Cache input RDD for speedup during multiple passes.
      +    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)
      +
      +    val withReplacement = numTrees > 1
      +
      +    val baggedInput = BaggedPoint
      +      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
      +      .persist(StorageLevel.MEMORY_AND_DISK)
      +
      +    // depth of the decision tree
      +    val maxDepth = strategy.maxDepth
      +    require(maxDepth <= 30,
      +      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
      +
      +    // Max memory usage for aggregates
      +    // TODO: Calculate memory usage more precisely.
      +    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
      +    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
      +    val maxMemoryPerNode = {
      +      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
      +        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
      +        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
      +          .take(metadata.numFeaturesPerNode).map(_._2))
      +      } else {
      +        None
      +      }
      +      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
      +    }
      +    require(maxMemoryPerNode <= maxMemoryUsage,
      +      s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
      +        " which is too small for the given features." +
      +        s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
      +
      +    timer.stop("init")
      +
      +    /*
      +     * The main idea here is to perform group-wise training of the decision tree nodes thus
      +     * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
      +     * Each data sample is handled by a particular node (or it reaches a leaf and is not used
      +     * in lower levels).
      +     */
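+    // Illustrative example: a full binary tree of depth 5 has 63 nodes; if roughly 20 nodes fit
+    // into one group under maxMemoryInMB, the data is scanned about 4 times instead of 63.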
      +
      +    // Create an RDD of node Id cache.
      +    // At first, all the rows belong to the root nodes (node Id == 1).
      +    val nodeIdCache = if (strategy.useNodeIdCache) {
      +      Some(NodeIdCache.init(
      +        data = baggedInput,
      +        numTrees = numTrees,
      +        checkpointInterval = strategy.checkpointInterval,
      +        initVal = 1))
      +    } else {
      +      None
      +    }
      +
      +    // FIFO queue of nodes to train: (treeIndex, node)
      +    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
      +
      +    val rng = new Random()
      +    rng.setSeed(seed)
      +
      +    // Allocate and queue root nodes.
      +    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
      +    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
      +
      +    while (nodeQueue.nonEmpty) {
      +      // Collect some nodes to split, and choose features for each node (if subsampling).
      +      // Each group of nodes may come from one or multiple trees, and at multiple levels.
      +      val (nodesForGroup, treeToNodeToIndexInfo) =
      +        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
      +      // Sanity check (should never occur):
      +      assert(nodesForGroup.nonEmpty,
      +        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
      +
      +      // Choose node splits, and enqueue new nodes as needed.
      +      timer.start("findBestSplits")
      +      RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
      +        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
      +      timer.stop("findBestSplits")
      +    }
      +
      +    baggedInput.unpersist()
      +
      +    timer.stop("total")
      +
      +    logInfo("Internal timing for DecisionTree:")
      +    logInfo(s"$timer")
      +
      +    // Delete any remaining checkpoints used for node Id cache.
      +    if (nodeIdCache.nonEmpty) {
      +      try {
      +        nodeIdCache.get.deleteAllCheckpoints()
      +      } catch {
      +        case e: IOException =>
      +          logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
      +      }
      +    }
      +
      +    parentUID match {
      +      case Some(uid) =>
      +        if (strategy.algo == OldAlgo.Classification) {
      +          topNodes.map { rootNode =>
      +            new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
      +          }
      +        } else {
      +          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
      +        }
      +      case None =>
      +        if (strategy.algo == OldAlgo.Classification) {
      +          topNodes.map { rootNode =>
      +            new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
      +          }
      +        } else {
      +          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
      +        }
      +    }
      +  }
      +
      +  /**
      +   * Get the node index corresponding to this data point.
      +   * This function mimics prediction, passing an example from the root node down to a leaf
      +   * or unsplit node; that node's index is returned.
      +   *
      +   * @param node  Node in tree from which to classify the given data point.
      +   * @param binnedFeatures  Binned feature vector for data point.
      +   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
      +   * @return  Leaf index if the data point reaches a leaf.
      +   *          Otherwise, last node reachable in tree matching this example.
      +   *          Note: This is the global node index, i.e., the index used in the tree.
      +   *                This index is different from the index used during training a particular
      +   *                group of nodes on one call to [[findBestSplits()]].
      +   */
      +  private def predictNodeIndex(
      +      node: LearningNode,
      +      binnedFeatures: Array[Int],
      +      splits: Array[Array[Split]]): Int = {
      +    if (node.isLeaf || node.split.isEmpty) {
      +      node.id
      +    } else {
      +      val split = node.split.get
      +      val featureIndex = split.featureIndex
      +      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
      +      if (node.leftChild.isEmpty) {
      +        // Not yet split. Return index from next layer of nodes to train
      +        if (splitLeft) {
      +          LearningNode.leftChildIndex(node.id)
      +        } else {
      +          LearningNode.rightChildIndex(node.id)
      +        }
      +      } else {
      +        if (splitLeft) {
      +          predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
      +        } else {
      +          predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
      +        }
      +      }
      +    }
      +  }
      +
      +  /**
      +   * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
      +   *
      +   * For ordered features, a single bin is updated.
      +   * For unordered features, bins correspond to subsets of categories; either the left or right bin
      +   * for each subset is updated.
      +   *
      +   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
      +   *             each (feature, bin).
      +   * @param treePoint  Data point being aggregated.
      +   * @param splits possible splits indexed (numFeatures)(numSplits)
      +   * @param unorderedFeatures  Set of indices of unordered features.
      +   * @param instanceWeight  Weight (importance) of instance in dataset.
      +   */
      +  private def mixedBinSeqOp(
      +      agg: DTStatsAggregator,
      +      treePoint: TreePoint,
      +      splits: Array[Array[Split]],
      +      unorderedFeatures: Set[Int],
      +      instanceWeight: Double,
      +      featuresForNode: Option[Array[Int]]): Unit = {
      +    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
      +      // Use subsampled features
      +      featuresForNode.get.length
      +    } else {
      +      // Use all features
      +      agg.metadata.numFeatures
      +    }
      +    // Iterate over features.
      +    var featureIndexIdx = 0
      +    while (featureIndexIdx < numFeaturesPerNode) {
      +      val featureIndex = if (featuresForNode.nonEmpty) {
      +        featuresForNode.get.apply(featureIndexIdx)
      +      } else {
      +        featureIndexIdx
      +      }
      +      if (unorderedFeatures.contains(featureIndex)) {
      +        // Unordered feature
      +        val featureValue = treePoint.binnedFeatures(featureIndex)
      +        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
      +          agg.getLeftRightFeatureOffsets(featureIndexIdx)
      +        // Update the left or right bin for each split.
      +        val numSplits = agg.metadata.numSplits(featureIndex)
      +        val featureSplits = splits(featureIndex)
      +        var splitIndex = 0
      +        while (splitIndex < numSplits) {
      +          if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
      +            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
      +          } else {
      +            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
      +          }
      +          splitIndex += 1
      +        }
      +      } else {
      +        // Ordered feature
      +        val binIndex = treePoint.binnedFeatures(featureIndex)
      +        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
      +      }
      +      featureIndexIdx += 1
      +    }
      +  }
      +
      +  /**
      +   * Helper for binSeqOp, for regression and for classification with only ordered features.
      +   *
      +   * For each feature, the sufficient statistics of one bin are updated.
      +   *
      +   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
      +   *             each (feature, bin).
      +   * @param treePoint  Data point being aggregated.
      +   * @param instanceWeight  Weight (importance) of instance in dataset.
      +   */
      +  private def orderedBinSeqOp(
      +      agg: DTStatsAggregator,
      +      treePoint: TreePoint,
      +      instanceWeight: Double,
      +      featuresForNode: Option[Array[Int]]): Unit = {
      +    val label = treePoint.label
      +
      +    // Iterate over features.
      +    if (featuresForNode.nonEmpty) {
      +      // Use subsampled features
      +      var featureIndexIdx = 0
      +      while (featureIndexIdx < featuresForNode.get.length) {
      +        val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
      +        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
      +        featureIndexIdx += 1
      +      }
      +    } else {
      +      // Use all features
      +      val numFeatures = agg.metadata.numFeatures
      +      var featureIndex = 0
      +      while (featureIndex < numFeatures) {
      +        val binIndex = treePoint.binnedFeatures(featureIndex)
      +        agg.update(featureIndex, binIndex, label, instanceWeight)
      +        featureIndex += 1
      +      }
      +    }
      +  }
      +
      +  /**
      +   * Given a group of nodes, this finds the best split for each node.
      +   *
      +   * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
      +   * @param metadata Learning and dataset metadata
      +   * @param topNodes Root node for each tree.  Used for matching instances with nodes.
      +   * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
      +   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
      +   *                              where nodeIndexInfo stores the index in the group and the
      +   *                              feature subsets (if using feature subsets).
      +   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
      +   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
      +   *                   Updated with new non-leaf nodes which are created.
      +   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
      +   *                    each value in the array is the data point's node Id
      +   *                    for a corresponding tree. This is used to prevent the need
      +   *                    to pass the entire tree to the executors during
      +   *                    the node stat aggregation phase.
      +   */
      +  private[tree] def findBestSplits(
      +      input: RDD[BaggedPoint[TreePoint]],
      +      metadata: DecisionTreeMetadata,
      +      topNodes: Array[LearningNode],
      +      nodesForGroup: Map[Int, Array[LearningNode]],
      +      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
      +      splits: Array[Array[Split]],
      +      nodeQueue: mutable.Queue[(Int, LearningNode)],
      +      timer: TimeTracker = new TimeTracker,
      +      nodeIdCache: Option[NodeIdCache] = None): Unit = {
      +
      +    /*
      +     * The high-level descriptions of the best split optimizations are noted here.
      +     *
      +     * *Group-wise training*
      +     * We perform bin calculations for groups of nodes to reduce the number of
      +     * passes over the data.  Each iteration requires more computation and storage,
      +     * but saves several iterations over the data.
      +     *
      +     * *Bin-wise computation*
      +     * We use a bin-wise best split computation strategy instead of a straightforward best split
      +     * computation strategy. Instead of analyzing each sample for contribution to the left/right
      +     * child node impurity of every split, we first categorize each feature of a sample into a
      +     * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
      +     * to calculate information gain for each split.
      +     *
      +     * *Aggregation over partitions*
      +     * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
      +     * the number of splits in advance. Thus, we store the aggregates (at the appropriate
      +     * indices) in a single array for all bins and rely upon the RDD aggregate method to
      +     * drastically reduce the communication overhead.
      +     */
      +
      +    // numNodes:  Number of nodes in this group
      +    val numNodes = nodesForGroup.values.map(_.length).sum
      +    logDebug("numNodes = " + numNodes)
      +    logDebug("numFeatures = " + metadata.numFeatures)
      +    logDebug("numClasses = " + metadata.numClasses)
      +    logDebug("isMulticlass = " + metadata.isMulticlass)
      +    logDebug("isMulticlassWithCategoricalFeatures = " +
      +      metadata.isMulticlassWithCategoricalFeatures)
      +    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
      +
      +    /**
      +     * Performs a sequential aggregation over a partition for a particular tree and node.
      +     *
      +     * For each feature, the aggregate sufficient statistics are updated for the relevant
      +     * bins.
      +     *
      +     * @param treeIndex Index of the tree that we want to perform aggregation for.
      +     * @param nodeInfo The node info for the tree node.
      +     * @param agg Array storing aggregate calculation, with a set of sufficient statistics
      +     *            for each (node, feature, bin).
      +     * @param baggedPoint Data point being aggregated.
      +     */
      +    def nodeBinSeqOp(
      +        treeIndex: Int,
      +        nodeInfo: NodeIndexInfo,
      +        agg: Array[DTStatsAggregator],
      +        baggedPoint: BaggedPoint[TreePoint]): Unit = {
      +      if (nodeInfo != null) {
      +        val aggNodeIndex = nodeInfo.nodeIndexInGroup
      +        val featuresForNode = nodeInfo.featureSubset
      +        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
      +        if (metadata.unorderedFeatures.isEmpty) {
      +          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
      +        } else {
      +          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
      +            metadata.unorderedFeatures, instanceWeight, featuresForNode)
      +        }
      +      }
      +    }
      +
      +    /**
      +     * Performs a sequential aggregation over a partition.
      +     *
      +     * Each data point contributes to one node. For each feature,
      +     * the aggregate sufficient statistics are updated for the relevant bins.
      +     *
      +     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
      +     *             each (node, feature, bin).
      +     * @param baggedPoint   Data point being aggregated.
      +     * @return  agg
      +     */
      +    def binSeqOp(
      +        agg: Array[DTStatsAggregator],
      +        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
      +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
      +        val nodeIndex =
      +          predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
      +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
      +      }
      +      agg
      +    }
      +
      +    /**
      +     * Do the same thing as binSeqOp, but with nodeIdCache.
      +     */
      +    def binSeqOpWithNodeIdCache(
      +        agg: Array[DTStatsAggregator],
      +        dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
      +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
      +        val baggedPoint = dataPoint._1
      +        val nodeIdCache = dataPoint._2
      +        val nodeIndex = nodeIdCache(treeIndex)
      +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
      +      }
      +
      +      agg
      +    }
      +
      +    /**
+     * Get a map from node index in group to feature indices, which is a shortcut for finding
+     * the feature indices for a node given its node index in the group.
      +     */
      +    def getNodeToFeatures(
      +        treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
      +      if (!metadata.subsamplingFeatures) {
      +        None
      +      } else {
      +        val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
      +        treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
      +          nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
      +            assert(nodeIndexInfo.featureSubset.isDefined)
      +            mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
      +          }
      +        }
      +        Some(mutableNodeToFeatures.toMap)
      +      }
      +    }
      +
      +    // array of nodes to train indexed by node index in group
      +    val nodes = new Array[LearningNode](numNodes)
      +    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
      +      nodesForTree.foreach { node =>
      +        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
      +      }
      +    }
      +
      +    // Calculate best splits for all nodes in the group
      +    timer.start("chooseSplits")
      +
+    // In each partition, iterate over all instances and compute aggregate stats for each node,
+    // yielding a (nodeIndex, nodeAggregateStats) pair for each node.
+    // After a `reduceByKey` operation, the stats for each node are shuffled to a particular
+    // partition and combined there; the best split for each node is then found on that partition.
+    // Finally, only the best splits are collected to the driver to construct the decision tree.
      +    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
      +    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
      +
+    val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
      +      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats;
+        // each node will have one nodeStatsAggregator.
      +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      +          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
      +            Some(nodeToFeatures(nodeIndex))
      +          }
      +          new DTStatsAggregator(metadata, featuresForNode)
      +        }
      +
+        // iterate over all instances in the current partition and update aggregate stats
+        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+        // transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with those from other partitions using `reduceByKey`
      +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      +      }
      +    } else {
      +      input.mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats;
+        // each node will have one nodeStatsAggregator.
      +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      +          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
      +            Some(nodeToFeatures(nodeIndex))
      +          }
      +          new DTStatsAggregator(metadata, featuresForNode)
      +        }
      +
+        // iterate over all instances in the current partition and update aggregate stats
+        points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+        // transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with those from other partitions using `reduceByKey`
      +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      +      }
      +    }
      +
      +    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
      +      case (nodeIndex, aggStats) =>
      +        val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
      +          Some(nodeToFeatures(nodeIndex))
      +        }
      +
      +        // find best split for each node
      +        val (split: Split, stats: ImpurityStats) =
      +          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
      +        (nodeIndex, (split, stats))
      +    }.collectAsMap()
      +
      +    timer.stop("chooseSplits")
      +
      +    val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
      +      Array.fill[mutable.Map[Int, NodeIndexUpdater]](
      +        metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
      +    } else {
      +      null
      +    }
      +    // Iterate over all nodes in this group.
      +    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
      +      nodesForTree.foreach { node =>
      +        val nodeIndex = node.id
      +        val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
      +        val aggNodeIndex = nodeInfo.nodeIndexInGroup
      +        val (split: Split, stats: ImpurityStats) =
      +          nodeToBestSplits(aggNodeIndex)
      +        logDebug("best split = " + split)
      +
      +        // Extract info for this node.  Create children if not leaf.
      +        val isLeaf =
      +          (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
      +        node.isLeaf = isLeaf
      +        node.stats = stats
      +        logDebug("Node = " + node)
      +
      +        if (!isLeaf) {
      +          node.split = Some(split)
      +          val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
      +          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
      +          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
      +          node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
      +            leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
      +          node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
      +            rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
      +
      +          if (nodeIdCache.nonEmpty) {
      +            val nodeIndexUpdater = NodeIndexUpdater(
      +              split = split,
      +              nodeIndex = nodeIndex)
      +            nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
      +          }
      +
      +          // enqueue left child and right child if they are not leaves
      +          if (!leftChildIsLeaf) {
      +            nodeQueue.enqueue((treeIndex, node.leftChild.get))
      +          }
      +          if (!rightChildIsLeaf) {
      +            nodeQueue.enqueue((treeIndex, node.rightChild.get))
      +          }
      +
      +          logDebug("leftChildIndex = " + node.leftChild.get.id +
      +            ", impurity = " + stats.leftImpurity)
      +          logDebug("rightChildIndex = " + node.rightChild.get.id +
      +            ", impurity = " + stats.rightImpurity)
      +        }
      +      }
      +    }
      +
      +    if (nodeIdCache.nonEmpty) {
      +      // Update the cache if needed.
      +      nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits)
      +    }
      +  }
      +
      +  /**
+   * Calculate the impurity statistics for a given (feature, split) based upon left/right aggregates.
+   * @param stats the reusable impurity statistics for all of this feature's splits;
+   *              only 'impurity' and 'impurityCalculator' are valid between iterations
      +   * @param leftImpurityCalculator left node aggregates for this (feature, split)
      +   * @param rightImpurityCalculator right node aggregate for this (feature, split)
      +   * @param metadata learning and dataset metadata for DecisionTree
      +   * @return Impurity statistics for this (feature, split)
      +   */
      +  private def calculateImpurityStats(
      +      stats: ImpurityStats,
      +      leftImpurityCalculator: ImpurityCalculator,
      +      rightImpurityCalculator: ImpurityCalculator,
      +      metadata: DecisionTreeMetadata): ImpurityStats = {
      +
      +    val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
      +      leftImpurityCalculator.copy.add(rightImpurityCalculator)
      +    } else {
      +      stats.impurityCalculator
      +    }
      +
      +    val impurity: Double = if (stats == null) {
      +      parentImpurityCalculator.calculate()
      +    } else {
      +      stats.impurity
      +    }
      +
      +    val leftCount = leftImpurityCalculator.count
      +    val rightCount = rightImpurityCalculator.count
      +
      +    val totalCount = leftCount + rightCount
      +
      +    // If left child or right child doesn't satisfy minimum instances per node,
      +    // then this split is invalid, return invalid information gain stats.
      +    if ((leftCount < metadata.minInstancesPerNode) ||
      +      (rightCount < metadata.minInstancesPerNode)) {
      +      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
      +    }
      +
      +    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
      +    val rightImpurity = rightImpurityCalculator.calculate()
      +
      +    val leftWeight = leftCount / totalCount.toDouble
      +    val rightWeight = rightCount / totalCount.toDouble
      +
      +    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
      +
      +    // if information gain doesn't satisfy minimum information gain,
      +    // then this split is invalid, return invalid information gain stats.
      +    if (gain < metadata.minInfoGain) {
      +      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
      +    }
      +
      +    new ImpurityStats(gain, impurity, parentImpurityCalculator,
      +      leftImpurityCalculator, rightImpurityCalculator)
      +  }
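+  // Worked example of the gain formula above (illustrative numbers): a parent with impurity 0.5
+  // over 10 instances, a left child with 6 instances and impurity 0.278, and a right child with
+  // 4 instances and impurity 0.375 give gain = 0.5 - 0.6 * 0.278 - 0.4 * 0.375 = 0.1832.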
      +
      +  /**
      +   * Find the best split for a node.
      +   * @param binAggregates Bin statistics.
+   * @return tuple of (best split, impurity statistics for that split)
      +   */
      +  private def binsToBestSplit(
      +      binAggregates: DTStatsAggregator,
      +      splits: Array[Array[Split]],
      +      featuresForNode: Option[Array[Int]],
      +      node: LearningNode): (Split, ImpurityStats) = {
      +
+    // For the root node, impurity stats are computed from the aggregates below; for other nodes,
+    // we reuse the stats recorded when the node was created by its parent's split.
+    val level = LearningNode.indexToLevel(node.id)
+    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
      +      null
      +    } else {
      +      node.stats
      +    }
      +
      +    // For each (feature, split), calculate the gain, and select the best (feature, split).
      +    val (bestSplit, bestSplitStats) =
      +      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
      +        val featureIndex = if (featuresForNode.nonEmpty) {
      +          featuresForNode.get.apply(featureIndexIdx)
      +        } else {
      +          featureIndexIdx
      +        }
      +        val numSplits = binAggregates.metadata.numSplits(featureIndex)
      +        if (binAggregates.metadata.isContinuous(featureIndex)) {
      +          // Cumulative sum (scanLeft) of bin statistics.
      +          // Afterwards, binAggregates for a bin is the sum of aggregates for
      +          // that bin + all preceding bins.
      +          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
      +          var splitIndex = 0
      +          while (splitIndex < numSplits) {
      +            binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
      +            splitIndex += 1
      +          }
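+          // After this scan, the stats stored at split index s cover bins 0..s, i.e. the left
+          // child of the candidate split at s; the right child is obtained below by subtracting
+          // them from the totals stored at index numSplits.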
      +          // Find best split.
      +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
      +            Range(0, numSplits).map { case splitIdx =>
      +              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
      +              val rightChildStats =
      +                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
      +              rightChildStats.subtract(leftChildStats)
      +              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
      +                leftChildStats, rightChildStats, binAggregates.metadata)
      +              (splitIdx, gainAndImpurityStats)
      +            }.maxBy(_._2.gain)
      +          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      +        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
      +          // Unordered categorical feature
      +          val (leftChildOffset, rightChildOffset) =
      +            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
      +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
      +            Range(0, numSplits).map { splitIndex =>
      +              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
      +              val rightChildStats =
      +                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
      +              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
      +                leftChildStats, rightChildStats, binAggregates.metadata)
      +              (splitIndex, gainAndImpurityStats)
      +            }.maxBy(_._2.gain)
      +          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      +        } else {
      +          // Ordered categorical feature
      +          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
      +          val numCategories = binAggregates.metadata.numBins(featureIndex)
      +
      +          /* Each bin is one category (feature value).
      +           * The bins are ordered based on centroidForCategories, and this ordering determines which
      +           * splits are considered.  (With K categories, we consider K - 1 possible splits.)
      +           *
      +           * centroidForCategories is a list: (category, centroid)
      +           */
      +          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
      +            // For categorical variables in multiclass classification,
      +            // the bins are ordered by the impurity of their corresponding labels.
      +            Range(0, numCategories).map { case featureValue =>
      +              val categoryStats =
      +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
      +              val centroid = if (categoryStats.count != 0) {
      +                categoryStats.calculate()
      +              } else {
      +                Double.MaxValue
      +              }
      +              (featureValue, centroid)
      +            }
      +          } else { // regression or binary classification
      +            // For categorical variables in regression and binary classification,
      +            // the bins are ordered by the centroid of their corresponding labels.
      +            Range(0, numCategories).map { case featureValue =>
      +              val categoryStats =
      +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
      +              val centroid = if (categoryStats.count != 0) {
      +                categoryStats.predict
      +              } else {
      +                Double.MaxValue
      +              }
      +              (featureValue, centroid)
      +            }
      +          }
      +
      +          logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
      +
      +          // bins sorted by centroids
      +          val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
      +
      +          logDebug("Sorted centroids for categorical variable = " +
      +            categoriesSortedByCentroid.mkString(","))
      +
      +          // Cumulative sum (scanLeft) of bin statistics.
      +          // Afterwards, binAggregates for a bin is the sum of aggregates for
      +          // that bin + all preceding bins.
      +          var splitIndex = 0
      +          while (splitIndex < numSplits) {
      +            val currentCategory = categoriesSortedByCentroid(splitIndex)._1
      +            val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
      +            binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
      +            splitIndex += 1
      +          }
      +          // lastCategory = index of bin with total aggregates for this (node, feature)
      +          val lastCategory = categoriesSortedByCentroid.last._1
      +          // Find best split.
      +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
      +            Range(0, numSplits).map { splitIndex =>
      +              val featureValue = categoriesSortedByCentroid(splitIndex)._1
      +              val leftChildStats =
      +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
      +              val rightChildStats =
      +                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
      +              rightChildStats.subtract(leftChildStats)
      +              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
      +                leftChildStats, rightChildStats, binAggregates.metadata)
      +              (splitIndex, gainAndImpurityStats)
      +            }.maxBy(_._2.gain)
      +          val categoriesForSplit =
      +            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
      +          val bestFeatureSplit =
      +            new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
      +          (bestFeatureSplit, bestFeatureGainStats)
      +        }
      +      }.maxBy(_._2.gain)
      +
      +    (bestSplit, bestSplitStats)
      +  }
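+  // Illustrative note: for a continuous feature, the merge loop above turns the bins into
+  // prefix sums, so bin k holds the aggregated statistics of bins 0..k. The left child for
+  // split k is then just bin k, and the right child is the total (the last bin) minus bin k,
+  // which is what the rightChildStats.subtract(leftChildStats) call computes.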
      +
      +  /**
      +   * Returns splits and bins for decision tree calculation.
      +   * Continuous and categorical features are handled differently.
      +   *
      +   * Continuous features:
      +   *   For each feature, there are numBins - 1 possible splits representing the possible binary
      +   *   decisions at each node in the tree.
      +   *   This finds locations (feature values) for splits using a subsample of the data.
      +   *
      +   * Categorical features:
      +   *   For each feature, there is 1 bin per split.
      +   *   Splits and bins are handled in 2 ways:
      +   *   (a) "unordered features"
      +   *       For multiclass classification with a low-arity feature
      +   *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
      +   *       the feature is split based on subsets of categories.
      +   *   (b) "ordered features"
      +   *       For regression and binary classification,
      +   *       and for multiclass classification with a high-arity feature,
      +   *       there is one bin per category.
      +   *
      +   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
      +   * @param metadata Learning and dataset metadata
+   * @return Splits, an Array of [[Split]] of size (numFeatures, numSplits).
      +   */
      +  protected[tree] def findSplits(
      +      input: RDD[LabeledPoint],
      +      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
      +
      +    logDebug("isMulticlass = " + metadata.isMulticlass)
      +
      +    val numFeatures = metadata.numFeatures
      +
      +    // Sample the input only if there are continuous features.
      +    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
      +    val sampledInput = if (hasContinuousFeatures) {
      +      // Calculate the number of samples for approximate quantile calculation.
      +      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
      +      val fraction = if (requiredSamples < metadata.numExamples) {
      +        requiredSamples.toDouble / metadata.numExamples
      +      } else {
      +        1.0
      +      }
      +      logDebug("fraction of data used for calculating quantiles = " + fraction)
      +      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
      +    } else {
      +      new Array[LabeledPoint](0)
      +    }
      +
      +    val splits = new Array[Array[Split]](numFeatures)
      +
      +    // Find all splits.
      +    // Iterate over all features.
      +    var featureIndex = 0
      +    while (featureIndex < numFeatures) {
      +      if (metadata.isContinuous(featureIndex)) {
      +        val featureSamples = sampledInput.map(_.features(featureIndex))
      +        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
      +
      +        val numSplits = featureSplits.length
      +        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
      +        splits(featureIndex) = new Array[Split](numSplits)
      +
      +        var splitIndex = 0
      +        while (splitIndex < numSplits) {
      +          val threshold = featureSplits(splitIndex)
      +          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
      +          splitIndex += 1
      +        }
      +      } else {
      +        // Categorical feature
      +        if (metadata.isUnordered(featureIndex)) {
      +          val numSplits = metadata.numSplits(featureIndex)
      +          val featureArity = metadata.featureArity(featureIndex)
      +          // TODO: Use an implicit representation mapping each category to a subset of indices.
      +          //       I.e., track indices such that we can calculate the set of bins for which
      +          //       feature value x splits to the left.
      +          // Unordered features
+          // 2^(featureArity - 1) - 1 combinations
      +          splits(featureIndex) = new Array[Split](numSplits)
      +          var splitIndex = 0
      +          while (splitIndex < numSplits) {
      +            val categories: List[Double] =
      +              extractMultiClassCategories(splitIndex + 1, featureArity)
      +            splits(featureIndex)(splitIndex) =
      +              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
      +            splitIndex += 1
      +          }
      +        } else {
      +          // Ordered features
      +          //   Bins correspond to feature values, so we do not need to compute splits or bins
      +          //   beforehand.  Splits are constructed as needed during training.
      +          splits(featureIndex) = new Array[Split](0)
      +        }
      +      }
      +      featureIndex += 1
      +    }
      +    splits
      +  }
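+  // Illustrative sketch of the returned structure (hypothetical thresholds): for a dataset
+  // with one continuous feature (index 0) and one ordered categorical feature (index 1),
+  // findSplits could return
+  //   Array(
+  //     Array(new ContinuousSplit(0, 0.5), new ContinuousSplit(0, 1.5)), // sampled thresholds
+  //     Array.empty[Split])  // ordered categorical: splits are constructed during training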
      +
      +  /**
+   * Extracts the list of eligible categories given an index by reading off the positions of
+   * the ones in the binary representation of the input. For example, if the binary
+   * representation of a number is 01101 (13), the output list is (3.0, 2.0, 0.0).
+   * maxFeatureValue gives the number of rightmost digits that are tested for ones.
      +   */
      +  private[tree] def extractMultiClassCategories(
      +      input: Int,
      +      maxFeatureValue: Int): List[Double] = {
      +    var categories = List[Double]()
      +    var j = 0
      +    var bitShiftedInput = input
      +    while (j < maxFeatureValue) {
      +      if (bitShiftedInput % 2 != 0) {
      +        // updating the list of categories.
      +        categories = j.toDouble :: categories
      +      }
      +      // Right shift by one
      +      bitShiftedInput = bitShiftedInput >> 1
      +      j += 1
      +    }
      +    categories
      +  }
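+  // Worked example: with input = 13 (binary 1101) and maxFeatureValue = 4, the ones sit at
+  // bit positions 0, 2 and 3, so extractMultiClassCategories(13, 4) returns List(3.0, 2.0, 0.0).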
      +
      +  /**
      +   * Find splits for a continuous feature
      +   * NOTE: Returned number of splits is set based on `featureSamples` and
      +   *       could be different from the specified `numSplits`.
      +   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
      +   * @param featureSamples feature values of each sample
      +   * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numBins` will be changed accordingly
      +   *                       if there are not enough splits to be found
      +   * @param featureIndex feature index to find splits
      +   * @return array of splits
      +   */
      +  private[tree] def findSplitsForContinuousFeature(
      +      featureSamples: Array[Double],
      +      metadata: DecisionTreeMetadata,
      +      featureIndex: Int): Array[Double] = {
      +    require(metadata.isContinuous(featureIndex),
      +      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
      +
      +    val splits = {
      +      val numSplits = metadata.numSplits(featureIndex)
      +
      +      // get count for each distinct value
      +      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
      +        m + ((x, m.getOrElse(x, 0) + 1))
      +      }
      +      // sort distinct values
      +      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
      +
+      // if the number of possible splits is not enough or just enough, return all possible splits
      +      val possibleSplits = valueCounts.length
      +      if (possibleSplits <= numSplits) {
      +        valueCounts.map(_._1)
      +      } else {
      +        // stride between splits
      +        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
      +        logDebug("stride = " + stride)
      +
+        // iterate over `valueCounts` to find splits
      +        val splitsBuilder = mutable.ArrayBuilder.make[Double]
      +        var index = 1
      +        // currentCount: sum of counts of values that have been visited
      +        var currentCount = valueCounts(0)._2
      +        // targetCount: target value for `currentCount`.
+        // If `currentCount` is the closest value to `targetCount`,
+        // then the current value is a split threshold.
+        // After finding a split threshold, `targetCount` is increased by stride.
      +        var targetCount = stride
      +        while (index < valueCounts.length) {
      +          val previousCount = currentCount
      +          currentCount += valueCounts(index)._2
      +          val previousGap = math.abs(previousCount - targetCount)
      +          val currentGap = math.abs(currentCount - targetCount)
+          // If adding the count of the current value to currentCount
+          // makes the gap between currentCount and targetCount larger,
+          // then the previous value is a split threshold.
      +          if (previousGap < currentGap) {
      +            splitsBuilder += valueCounts(index - 1)._1
      +            targetCount += stride
      +          }
      +          index += 1
      +        }
      +
      +        splitsBuilder.result()
      +      }
      +    }
      +
      +    // TODO: Do not fail; just ignore the useless feature.
      +    assert(splits.length > 0,
      +      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
      +        "  Please remove this feature and then try again.")
      +    // set number of splits accordingly
      +    metadata.setNumSplits(featureIndex, splits.length)
      +
      +    splits
      +  }
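+  // Worked example (hypothetical samples): featureSamples = Array(1.0, 1.0, 2.0, 3.0, 3.0, 3.0)
+  // with numSplits = 2 gives valueCounts = [(1.0, 2), (2.0, 1), (3.0, 3)] and stride = 6 / 3 = 2.0.
+  // The cumulative count 2 (at value 1.0) is closer to the first target of 2.0 than 3 is, and the
+  // cumulative count 3 (at value 2.0) is closer to the second target of 4.0 than 6 is, so the
+  // returned thresholds are Array(1.0, 2.0).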
      +
      +  private[tree] class NodeIndexInfo(
      +      val nodeIndexInGroup: Int,
      +      val featureSubset: Option[Array[Int]]) extends Serializable
      +
      +  /**
      +   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
      +   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
      +   * will be needed; this allows an adaptive number of nodes since different nodes may require
      +   * different amounts of memory (if featureSubsetStrategy is not "all").
      +   *
      +   * @param nodeQueue  Queue of nodes to split.
      +   * @param maxMemoryUsage  Bound on size of aggregate statistics.
      +   * @return  (nodesForGroup, treeToNodeToIndexInfo).
      +   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
      +   *
+   *          treeToNodeToIndexInfo holds the indices of selected features for each node:
      +   *            treeIndex --> (global) node index --> (node index in group, feature indices).
      +   *          The (global) node index is the index in the tree; the node index in group is the
      +   *           index in [0, numNodesInGroup) of the node in this group.
      +   *          The feature indices are None if not subsampling features.
      +   */
      +  private[tree] def selectNodesToSplit(
      +      nodeQueue: mutable.Queue[(Int, LearningNode)],
      +      maxMemoryUsage: Long,
      +      metadata: DecisionTreeMetadata,
      +      rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
      +    // Collect some nodes to split:
      +    //  nodesForGroup(treeIndex) = nodes to split
      +    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
      +    val mutableTreeToNodeToIndexInfo =
      +      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
      +    var memUsage: Long = 0L
      +    var numNodesInGroup = 0
      +    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
      +      val (treeIndex, node) = nodeQueue.head
      +      // Choose subset of features for node (if subsampling).
      +      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
      +        Some(SamplingUtils.reservoirSampleAndCount(Range(0,
      +          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
      +      } else {
      +        None
      +      }
      +      // Check if enough memory remains to add this node to the group.
      +      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
      +      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
      +        nodeQueue.dequeue()
      +        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
      +          node
      +        mutableTreeToNodeToIndexInfo
      +          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
      +          = new NodeIndexInfo(numNodesInGroup, featureSubset)
      +      }
      +      numNodesInGroup += 1
      +      memUsage += nodeMemUsage
      +    }
      +    // Convert mutable maps to immutable ones.
      +    val nodesForGroup: Map[Int, Array[LearningNode]] =
      +      mutableNodesForGroup.mapValues(_.toArray).toMap
      +    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
      +    (nodesForGroup, treeToNodeToIndexInfo)
      +  }
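+  // Illustrative sketch (hypothetical sizes): with maxMemoryUsage = 256 units and three queued
+  // nodes whose aggregates need 100, 100 and 120 units, the first two nodes are dequeued and
+  // grouped (100 + 100 <= 256); the third would push usage to 320, so it stays in the queue
+  // for a later group and the loop exits once memUsage is no longer below the bound.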
      +
      +  /**
      +   * Get the number of values to be stored for this node in the bin aggregates.
      +   * @param featureSubset  Indices of features which may be split at this node.
      +   *                       If None, then use all features.
      +   */
      +  private def aggregateSizeForNode(
      +      metadata: DecisionTreeMetadata,
      +      featureSubset: Option[Array[Int]]): Long = {
      +    val totalBins = if (featureSubset.nonEmpty) {
      +      featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
      +    } else {
      +      metadata.numBins.map(_.toLong).sum
      +    }
      +    if (metadata.isClassification) {
      +      metadata.numClasses * totalBins
      +    } else {
      +      3 * totalBins
      +    }
      +  }
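+  // Worked example (hypothetical metadata): with numBins = Array(32, 10) and no feature
+  // subsampling, totalBins = 42. A 4-class classifier then stores 4 * 42 = 168 values for the
+  // node, while a regression tree stores 3 * 42 = 126 (e.g. count, sum and sum of squares per
+  // bin for variance); selectNodesToSplit multiplies this by 8 bytes per value when budgeting.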
      +
      +  /**
      +   * Given a Random Forest model, compute the importance of each feature.
      +   * This generalizes the idea of "Gini" importance to other losses,
      +   * following the explanation of Gini importance from "Random Forests" documentation
      +   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
      +   *
      +   * This feature importance is calculated as follows:
      +   *  - Average over trees:
      +   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
      +   *       where gain is scaled by the number of instances passing through node
      +   *     - Normalize importances for tree based on total number of training instances used
      +   *       to build tree.
      +   *  - Normalize feature importance vector to sum to 1.
      +   *
      +   * Note: This should not be used with Gradient-Boosted Trees.  It only makes sense for
      +   *       independently trained trees.
      +   * @param trees  Unweighted forest of trees
      +   * @param numFeatures  Number of features in model (even if not all are explicitly used by
      +   *                     the model).
      +   *                     If -1, then numFeatures is set based on the max feature index in all trees.
      +   * @return  Feature importance values, of length numFeatures.
      +   */
      +  private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
      +    val totalImportances = new OpenHashMap[Int, Double]()
      +    trees.foreach { tree =>
      +      // Aggregate feature importance vector for this tree
      +      val importances = new OpenHashMap[Int, Double]()
      +      computeFeatureImportance(tree.rootNode, importances)
      +      // Normalize importance vector for this tree, and add it to total.
      +      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
      +      val treeNorm = importances.map(_._2).sum
      +      if (treeNorm != 0) {
      +        importances.foreach { case (idx, impt) =>
      +          val normImpt = impt / treeNorm
      +          totalImportances.changeValue(idx, normImpt, _ + normImpt)
      +        }
      +      }
      +    }
      +    // Normalize importances
      +    normalizeMapValues(totalImportances)
      +    // Construct vector
      +    val d = if (numFeatures != -1) {
      +      numFeatures
      +    } else {
      +      // Find max feature index used in trees
      +      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
      +      maxFeatureIndex + 1
      +    }
      +    if (d == 0) {
      +      assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
      +        s" importance: No splits in forest, but some non-zero importances.")
      +    }
      +    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
      +    Vectors.sparse(d, indices.toArray, values.toArray)
      +  }
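+  // Worked example (hypothetical trees): suppose tree A splits only on feature 0 with scaled
+  // gain 2.0 (per-tree importances {0 -> 1.0} after normalization) and tree B splits on
+  // features 0 and 1 with scaled gains 1.0 and 3.0 ({0 -> 0.25, 1 -> 0.75}). The totals
+  // {0 -> 1.25, 1 -> 0.75} are normalized to sum to 1, giving
+  // Vectors.sparse(2, Array(0, 1), Array(0.625, 0.375)).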
      +
      +  /**
      +   * Recursive method for computing feature importances for one tree.
      +   * This walks down the tree, adding to the importance of 1 feature at each node.
      +   * @param node  Current node in recursion
      +   * @param importances  Aggregate feature importances, modified by this method
      +   */
      +  private[impl] def computeFeatureImportance(
      +      node: Node,
      +      importances: OpenHashMap[Int, Double]): Unit = {
      +    node match {
      +      case n: InternalNode =>
      +        val feature = n.split.featureIndex
      +        val scaledGain = n.gain * n.impurityStats.count
      +        importances.changeValue(feature, scaledGain, _ + scaledGain)
      +        computeFeatureImportance(n.leftChild, importances)
      +        computeFeatureImportance(n.rightChild, importances)
      +      case n: LeafNode =>
      +        // do nothing
      +    }
      +  }
      +
      +  /**
      +   * Normalize the values of this map to sum to 1, in place.
      +   * If all values are 0, this method does nothing.
      +   * @param map  Map with non-negative values.
      +   */
      +  private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
      +    val total = map.map(_._2).sum
      +    if (total != 0) {
      +      val keys = map.iterator.map(_._1).toArray
      +      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
      +    }
      +  }
      +
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
      new file mode 100644
      index 000000000000..9fa27e5e1f72
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
      @@ -0,0 +1,134 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.tree.impl
      +
      +import org.apache.spark.ml.tree.{ContinuousSplit, Split}
      +import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
      +import org.apache.spark.rdd.RDD
      +
      +
      +/**
      + * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsample of the data as follows:
      + *  (a) Continuous features are binned into ranges.
      + *  (b) Unordered categorical features are binned based on subsets of feature values.
      + *      "Unordered categorical features" are categorical features with low arity used in
      + *      multiclass classification.
      + *  (c) Ordered categorical features are binned based on feature values.
      + *      "Ordered categorical features" are categorical features with high arity,
      + *      or any categorical feature used in regression or binary classification.
      + *
      + * @param label  Label from LabeledPoint
      + * @param binnedFeatures  Binned feature values.
      + *                        Same length as LabeledPoint.features, but values are bin indices.
      + */
      +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
      +  extends Serializable {
      +}
      +
      +private[spark] object TreePoint {
      +
      +  /**
      +   * Convert an input dataset into its TreePoint representation,
      +   * binning feature values in preparation for DecisionTree training.
      +   * @param input     Input dataset.
      +   * @param splits    Splits for features, of size (numFeatures, numSplits).
      +   * @param metadata  Learning and dataset metadata
      +   * @return  TreePoint dataset representation
      +   */
      +  def convertToTreeRDD(
      +      input: RDD[LabeledPoint],
      +      splits: Array[Array[Split]],
      +      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
      +    // Construct arrays for featureArity for efficiency in the inner loop.
      +    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
      +    var featureIndex = 0
      +    while (featureIndex < metadata.numFeatures) {
      +      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
      +      featureIndex += 1
      +    }
      +    val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
      +      if (arity == 0) {
      +        splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
      +      } else {
      +        Array.empty[Double]
      +      }
      +    }
      +    input.map { x =>
      +      TreePoint.labeledPointToTreePoint(x, thresholds, featureArity)
      +    }
      +  }
      +
      +  /**
      +   * Convert one LabeledPoint into its TreePoint representation.
      +   * @param thresholds  For each feature, split thresholds for continuous features,
      +   *                    empty for categorical features.
      +   * @param featureArity  Array indexed by feature, with value 0 for continuous and numCategories
      +   *                      for categorical features.
      +   */
      +  private def labeledPointToTreePoint(
      +      labeledPoint: LabeledPoint,
      +      thresholds: Array[Array[Double]],
      +      featureArity: Array[Int]): TreePoint = {
      +    val numFeatures = labeledPoint.features.size
      +    val arr = new Array[Int](numFeatures)
      +    var featureIndex = 0
      +    while (featureIndex < numFeatures) {
      +      arr(featureIndex) =
      +        findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
      +      featureIndex += 1
      +    }
      +    new TreePoint(labeledPoint.label, arr)
      +  }
      +
      +  /**
      +   * Find discretized value for one (labeledPoint, feature).
      +   *
      +   * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
      +   *       (mllib) tree API.  We want to maintain the same behavior as the old tree API.
      +   *
      +   * @param featureArity  0 for continuous features; number of categories for categorical features.
      +   */
      +  private def findBin(
      +      featureIndex: Int,
      +      labeledPoint: LabeledPoint,
      +      featureArity: Int,
      +      thresholds: Array[Double]): Int = {
      +    val featureValue = labeledPoint.features(featureIndex)
      +
      +    if (featureArity == 0) {
      +      val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
      +      if (idx >= 0) {
      +        idx
      +      } else {
      +        -idx - 1
      +      }
      +    } else {
      +      // Categorical feature bins are indexed by feature values.
      +      if (featureValue < 0 || featureValue >= featureArity) {
      +        throw new IllegalArgumentException(
      +          s"DecisionTree given invalid data:" +
      +            s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
      +            s" but a data point gives it value $featureValue.\n" +
      +            "  Bad data point: " + labeledPoint.toString)
      +      }
      +      featureValue.toInt
      +    }
      +  }
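+
+  // Worked example (hypothetical thresholds): for a continuous feature with
+  // thresholds = Array(0.5, 1.5), the value 1.0 is not an exact threshold, binarySearch returns
+  // -2, and the bin index is -(-2) - 1 = 1; an exact match such as 0.5 maps to bin 0.
+  // For a categorical feature with arity 3, valid values 0.0, 1.0 and 2.0 map directly to
+  // bins 0, 1 and 2.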
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
      index 1929f9d02156..b77191156f68 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
      @@ -17,6 +17,7 @@
       
       package org.apache.spark.ml.tree
       
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
       
       /**
        * Abstraction for Decision Tree models.
      @@ -52,6 +53,12 @@ private[ml] trait DecisionTreeModel {
           val header = toString + "\n"
           header + rootNode.subtreeToString(2)
         }
      +
      +  /**
      +   * Trace down the tree, and return the largest feature index used in any split.
      +   * @return  Max feature index used in a split, or -1 if there are no splits (single leaf node).
      +   */
      +  private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
       }
       
       /**
      @@ -70,6 +77,10 @@ private[ml] trait TreeEnsembleModel {
         /** Weights for each tree, zippable with [[trees]] */
         def treeWeights: Array[Double]
       
      +  /** Weights used by the python wrappers. */
      +  // Note: An array cannot be returned directly due to serialization problems.
      +  private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
      +
         /** Summary of the model */
         override def toString: String = {
           // Implementing classes should generally override this method to be more descriptive.
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
      index a0c5238d966b..d29f5253c9c3 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
      @@ -17,9 +17,10 @@
       
       package org.apache.spark.ml.tree
       
      +import org.apache.spark.ml.classification.ClassifierParams
       import org.apache.spark.ml.PredictorParams
       import org.apache.spark.ml.param._
      -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
      +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
       import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
       import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
      @@ -29,7 +30,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
        *
        * Note: Marked as private and DeveloperApi since this may be made public in the future.
        */
      -private[ml] trait DecisionTreeParams extends PredictorParams {
      +private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval {
       
         /**
          * Maximum depth of the tree (>= 0).
      @@ -95,21 +96,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
           " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
           " trees.")
       
      -  /**
      -   * Specifies how often to checkpoint the cached node IDs.
      -   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
      -   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
      -   * [[org.apache.spark.SparkContext]].
      -   * Must be >= 1.
      -   * (default = 10)
      -   * @group expertParam
      -   */
      -  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
      -    " how often to checkpoint the cached node IDs.  E.g. 10 means that the cache will get" +
      -    " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
      -    " checkpoint directory is set in the SparkContext. Must be >= 1.",
      -    ParamValidators.gtEq(1))
      -
         setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
           maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
       
      @@ -149,12 +135,17 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
         /** @group expertGetParam */
         final def getCacheNodeIds: Boolean = $(cacheNodeIds)
       
      -  /** @group expertSetParam */
      +  /**
      +   * Specifies how often to checkpoint the cached node IDs.
      +   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
      +   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
      +   * [[org.apache.spark.SparkContext]].
      +   * Must be >= 1.
      +   * (default = 10)
      +   * @group expertSetParam
      +   */
         def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
       
      -  /** @group expertGetParam */
      -  final def getCheckpointInterval: Int = $(checkpointInterval)
      -
         /** (private[ml]) Create a Strategy instance to use with the old API. */
         private[ml] def getOldStrategy(
             categoricalFeatures: Map[Int, Int],
      @@ -162,7 +153,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
             oldAlgo: OldAlgo.Algo,
             oldImpurity: OldImpurity,
             subsamplingRate: Double): OldStrategy = {
      -    val strategy = OldStrategy.defaultStategy(oldAlgo)
      +    val strategy = OldStrategy.defaultStrategy(oldAlgo)
           strategy.impurity = oldImpurity
           strategy.checkpointInterval = getCheckpointInterval
           strategy.maxBins = getMaxBins
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
      index e2444ab65b43..0679bfd0f3ff 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
      @@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType
       /**
        * Params for [[CrossValidator]] and [[CrossValidatorModel]].
        */
      -private[ml] trait CrossValidatorParams extends Params {
      -
      -  /**
      -   * param for the estimator to be cross-validated
      -   * @group param
      -   */
      -  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
      -
      -  /** @group getParam */
      -  def getEstimator: Estimator[_] = $(estimator)
      -
      -  /**
      -   * param for estimator param maps
      -   * @group param
      -   */
      -  val estimatorParamMaps: Param[Array[ParamMap]] =
      -    new Param(this, "estimatorParamMaps", "param maps for the estimator")
      -
      -  /** @group getParam */
      -  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
      -
      -  /**
      -   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
      -   * metric
      -   * @group param
      -   */
      -  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
      -    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
      -
      -  /** @group getParam */
      -  def getEvaluator: Evaluator = $(evaluator)
      -
      +private[ml] trait CrossValidatorParams extends ValidatorParams {
         /**
          * Param for number of folds for cross validation.  Must be >= 2.
          * Default: 3
      @@ -131,7 +100,9 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
           }
           f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
           logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
      -    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
      +    val (bestMetric, bestIndex) =
      +      if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
      +      else metrics.zipWithIndex.minBy(_._1)
           logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
           logInfo(s"Best cross-validation metric: $bestMetric.")
           val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
      @@ -191,6 +162,6 @@ class CrossValidatorModel private[ml] (
             uid,
             bestModel.copy(extra).asInstanceOf[Model[_]],
             avgMetrics.clone())
      -    copyValues(copied, extra)
      +    copyValues(copied, extra).setParent(parent)
         }
       }
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
      new file mode 100644
      index 000000000000..73a14b831015
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
      @@ -0,0 +1,170 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.tuning
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.annotation.Experimental
      +import org.apache.spark.ml.evaluation.Evaluator
      +import org.apache.spark.ml.{Estimator, Model}
      +import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
      +import org.apache.spark.ml.util.Identifiable
      +import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.types.StructType
      +
      +/**
      + * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
      + */
      +private[ml] trait TrainValidationSplitParams extends ValidatorParams {
      +  /**
      +   * Param for ratio between train and validation data. Must be between 0 and 1.
      +   * Default: 0.75
      +   * @group param
      +   */
      +  val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
      +    "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
      +
      +  /** @group getParam */
      +  def getTrainRatio: Double = $(trainRatio)
      +
      +  setDefault(trainRatio -> 0.75)
      +}
      +
      +/**
      + * :: Experimental ::
      + * Validation for hyper-parameter tuning.
      + * Randomly splits the input dataset into train and validation sets,
+ * and uses the evaluation metric on the validation set to select the best model.
      + * Similar to [[CrossValidator]], but only splits the set once.
      + */
      +@Experimental
      +class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
      +  with TrainValidationSplitParams with Logging {
      +
      +  def this() = this(Identifiable.randomUID("tvs"))
      +
      +  /** @group setParam */
      +  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
      +
      +  /** @group setParam */
      +  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
      +
      +  /** @group setParam */
      +  def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
      +
      +  /** @group setParam */
      +  def setTrainRatio(value: Double): this.type = set(trainRatio, value)
      +
      +  override def fit(dataset: DataFrame): TrainValidationSplitModel = {
      +    val schema = dataset.schema
      +    transformSchema(schema, logging = true)
      +    val sqlCtx = dataset.sqlContext
      +    val est = $(estimator)
      +    val eval = $(evaluator)
      +    val epm = $(estimatorParamMaps)
      +    val numModels = epm.length
      +    val metrics = new Array[Double](epm.length)
      +
      +    val Array(training, validation) =
      +      dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
      +    val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
      +    val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
      +
      +    // multi-model training
      +    logDebug(s"Train split with multiple sets of parameters.")
      +    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
      +    trainingDataset.unpersist()
      +    var i = 0
      +    while (i < numModels) {
      +      // TODO: duplicate evaluator to take extra params from input
      +      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
      +      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
      +      metrics(i) += metric
      +      i += 1
      +    }
      +    validationDataset.unpersist()
      +
      +    logInfo(s"Train validation split metrics: ${metrics.toSeq}")
      +    val (bestMetric, bestIndex) =
      +      if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
      +      else metrics.zipWithIndex.minBy(_._1)
      +    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
      +    logInfo(s"Best train validation split metric: $bestMetric.")
      +    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
      +    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    $(estimator).transformSchema(schema)
      +  }
      +
      +  override def validateParams(): Unit = {
      +    super.validateParams()
      +    val est = $(estimator)
      +    for (paramMap <- $(estimatorParamMaps)) {
      +      est.copy(paramMap).validateParams()
      +    }
      +  }
      +
      +  override def copy(extra: ParamMap): TrainValidationSplit = {
      +    val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
      +    if (copied.isDefined(estimator)) {
      +      copied.setEstimator(copied.getEstimator.copy(extra))
      +    }
      +    if (copied.isDefined(evaluator)) {
      +      copied.setEvaluator(copied.getEvaluator.copy(extra))
      +    }
      +    copied
      +  }
      +}
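+
+// Minimal usage sketch (assumes an existing estimator `lr`, an Evaluator `eval` and a grid of
+// ParamMaps `grid`, e.g. built with ParamGridBuilder; the names are illustrative only):
+//   val tvs = new TrainValidationSplit()
+//     .setEstimator(lr)
+//     .setEstimatorParamMaps(grid)
+//     .setEvaluator(eval)
+//     .setTrainRatio(0.8)
+//   val model = tvs.fit(trainingDF)  // refits the best ParamMap on the full dataset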
      +
      +/**
      + * :: Experimental ::
      + * Model from train validation split.
      + *
      + * @param uid Id.
      + * @param bestModel Estimator determined best model.
      + * @param validationMetrics Evaluated validation metrics.
      + */
      +@Experimental
      +class TrainValidationSplitModel private[ml] (
      +    override val uid: String,
      +    val bestModel: Model[_],
      +    val validationMetrics: Array[Double])
      +  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
      +
      +  override def validateParams(): Unit = {
      +    bestModel.validateParams()
      +  }
      +
      +  override def transform(dataset: DataFrame): DataFrame = {
      +    transformSchema(dataset.schema, logging = true)
      +    bestModel.transform(dataset)
      +  }
      +
      +  override def transformSchema(schema: StructType): StructType = {
      +    bestModel.transformSchema(schema)
      +  }
      +
      +  override def copy(extra: ParamMap): TrainValidationSplitModel = {
+    val copied = new TrainValidationSplitModel(
      +      uid,
      +      bestModel.copy(extra).asInstanceOf[Model[_]],
      +      validationMetrics.clone())
      +    copyValues(copied, extra)
      +  }
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
      new file mode 100644
      index 000000000000..8897ab0825ac
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
      @@ -0,0 +1,60 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.tuning
      +
      +import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.ml.Estimator
      +import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.{Param, ParamMap, Params}
      +
      +/**
      + * :: DeveloperApi ::
      + * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
      + */
      +@DeveloperApi
      +private[ml] trait ValidatorParams extends Params {
      +
      +  /**
      +   * param for the estimator to be validated
      +   * @group param
      +   */
      +  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
      +
      +  /** @group getParam */
      +  def getEstimator: Estimator[_] = $(estimator)
      +
      +  /**
      +   * param for estimator param maps
      +   * @group param
      +   */
      +  val estimatorParamMaps: Param[Array[ParamMap]] =
      +    new Param(this, "estimatorParamMaps", "param maps for the estimator")
      +
      +  /** @group getParam */
      +  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
      +
      +  /**
      +   * param for the evaluator used to select hyper-parameters that maximize the validated metric
      +   * @group param
      +   */
      +  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
      +    "evaluator used to select hyper-parameters that maximize the validated metric")
      +
      +  /** @group getParam */
      +  def getEvaluator: Evaluator = $(evaluator)
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
      index ddd34a54503a..bd213e7362e9 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
      @@ -19,11 +19,19 @@ package org.apache.spark.ml.util
       
       import java.util.UUID
       
      +import org.apache.spark.annotation.DeveloperApi
      +
       
       /**
      + * :: DeveloperApi ::
      + *
        * Trait for an object with an immutable unique ID that identifies itself and its derivatives.
      + *
      + * WARNING: There have not yet been final discussions on this API, so it may be broken in future
      + *          releases.
        */
      -private[spark] trait Identifiable {
      +@DeveloperApi
      +trait Identifiable {
       
         /**
          * An immutable unique ID for the object and its derivatives.
      @@ -33,7 +41,11 @@ private[spark] trait Identifiable {
         override def toString: String = uid
       }
       
      -private[spark] object Identifiable {
      +/**
      + * :: DeveloperApi ::
      + */
      +@DeveloperApi
      +object Identifiable {
       
         /**
          * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars.
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
      index 2a1db90f2ca2..96a38a3bde96 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
      @@ -20,11 +20,12 @@ package org.apache.spark.ml.util
       import scala.collection.immutable.HashMap
       
       import org.apache.spark.ml.attribute._
      +import org.apache.spark.mllib.linalg.VectorUDT
       import org.apache.spark.sql.types.StructField
       
       
       /**
      - * Helper utilities for tree-based algorithms
      + * Helper utilities for algorithms using ML metadata
        */
       private[spark] object MetadataUtils {
       
      @@ -74,4 +75,20 @@ private[spark] object MetadataUtils {
           }
         }
       
      +  /**
      +   * Takes a Vector column and a list of feature names, and returns the corresponding list of
      +   * feature indices in the column, in order.
      +   * @param col  Vector column which must have feature names specified via attributes
      +   * @param names  List of feature names
      +   */
      +  def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = {
      +    require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col"
      +      + s" to be Vector type, but it was type ${col.dataType} instead.")
      +    val inputAttr = AttributeGroup.fromStructField(col)
      +    names.map { name =>
      +      require(inputAttr.hasAttr(name),
      +        s"getFeatureIndicesFromNames found no feature with name $name in column $col.")
      +      inputAttr.getAttr(name).index.get
      +    }
      +  }
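+
+  // Example (hypothetical attribute names): if the "features" Vector column carries attributes
+  // named ("age", "income", "height"), then
+  // getFeatureIndicesFromNames(schema("features"), Array("height", "age")) returns Array(2, 0).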
       }
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
      index 7cd53c6d7ef7..76f651488aef 100644
      --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
      +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
      @@ -32,10 +32,15 @@ private[spark] object SchemaUtils {
          * @param colName  column name
          * @param dataType  required column data type
          */
      -  def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
      +  def checkColumnType(
      +      schema: StructType,
      +      colName: String,
      +      dataType: DataType,
      +      msg: String = ""): Unit = {
           val actualDataType = schema(colName).dataType
      +    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
           require(actualDataType.equals(dataType),
      -      s"Column $colName must be of type $dataType but was actually $actualDataType.")
      +      s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
         }
       
         /**
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
      new file mode 100644
      index 000000000000..8d4174124b5c
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
      @@ -0,0 +1,153 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.util
      +
      +import scala.collection.mutable
      +
      +import org.apache.spark.{Accumulator, SparkContext}
      +
      +/**
      + * Abstract class for stopwatches.
      + */
      +private[spark] abstract class Stopwatch extends Serializable {
      +
      +  @transient private var running: Boolean = false
      +  private var startTime: Long = _
      +
      +  /**
      +   * Name of the stopwatch.
      +   */
      +  val name: String
      +
      +  /**
      +   * Starts the stopwatch.
      +   * Throws an exception if the stopwatch is already running.
      +   */
      +  def start(): Unit = {
      +    assume(!running, "start() called but the stopwatch is already running.")
      +    running = true
      +    startTime = now
      +  }
      +
      +  /**
      +   * Stops the stopwatch and returns the duration of the last session in milliseconds.
      +   * Throws an exception if the stopwatch is not running.
      +   */
      +  def stop(): Long = {
      +    assume(running, "stop() called but the stopwatch is not running.")
      +    val duration = now - startTime
      +    add(duration)
      +    running = false
      +    duration
      +  }
      +
      +  /**
      +   * Checks whether the stopwatch is running.
      +   */
      +  def isRunning: Boolean = running
      +
      +  /**
      +   * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
      +   * is running.
      +   */
      +  def elapsed(): Long
      +
      +  override def toString: String = s"$name: ${elapsed()}ms"
      +
      +  /**
      +   * Gets the current time in milliseconds.
      +   */
      +  protected def now: Long = System.currentTimeMillis()
      +
      +  /**
      +   * Adds input duration to total elapsed time.
      +   */
      +  protected def add(duration: Long): Unit
      +}
      +
      +/**
      + * A local [[Stopwatch]].
      + */
      +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
      +
      +  private var elapsedTime: Long = 0L
      +
      +  override def elapsed(): Long = elapsedTime
      +
      +  override protected def add(duration: Long): Unit = {
      +    elapsedTime += duration
      +  }
      +}
      +
      +/**
      + * A distributed [[Stopwatch]] using Spark accumulator.
      + * @param sc SparkContext
      + */
      +private[spark] class DistributedStopwatch(
      +    sc: SparkContext,
      +    override val name: String) extends Stopwatch {
      +
      +  private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
      +
      +  override def elapsed(): Long = elapsedTime.value
      +
      +  override protected def add(duration: Long): Unit = {
      +    elapsedTime += duration
      +  }
      +}
      +
      +/**
+ * A container for multiple stopwatches, both local and distributed.
      + * @param sc SparkContext
      + */
      +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
      +
      +  private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
      +
      +  /**
      +   * Adds a local stopwatch.
      +   * @param name stopwatch name
      +   */
      +  def addLocal(name: String): this.type = {
      +    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
      +    stopwatches(name) = new LocalStopwatch(name)
      +    this
      +  }
      +
      +  /**
      +   * Adds a distributed stopwatch.
      +   * @param name stopwatch name
      +   */
      +  def addDistributed(name: String): this.type = {
      +    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
      +    stopwatches(name) = new DistributedStopwatch(sc, name)
      +    this
      +  }
      +
      +  /**
      +   * Gets a stopwatch.
      +   * @param name stopwatch name
      +   */
      +  def apply(name: String): Stopwatch = stopwatches(name)
      +
      +  override def toString: String = {
      +    stopwatches.values.toArray.sortBy(_.name)
      +      .map(c => s"  $c")
      +      .mkString("{\n", ",\n", "\n}")
      +  }
      +}
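For reference, a minimal sketch (not part of the patch) of how the new `MultiStopwatch` might be driven from code inside `org.apache.spark`; `sc` is assumed to be an existing `SparkContext` and the `Thread.sleep` calls stand in for real work:

```scala
// Sketch only: the stopwatch classes are private[spark], so this would live in Spark code.
val timers = new MultiStopwatch(sc)
  .addLocal("init")          // timed on the driver only
  .addDistributed("train")   // backed by a named accumulator, aggregated from executors

timers("init").start()
Thread.sleep(50)             // placeholder for driver-side setup
timers("init").stop()        // returns the duration of this session in ms

timers("train").start()
Thread.sleep(100)            // placeholder for a training pass
timers("train").stop()

println(timers)              // e.g. "{\n  init: 50ms,\n  train: 100ms\n}"
```

Because `DistributedStopwatch` uses a named accumulator, time recorded on executors is folded into the same total that the driver prints.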
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
      new file mode 100644
      index 000000000000..0ec88ef77d69
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
      @@ -0,0 +1,53 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.api.python
      +
      +import java.util.{List => JList}
      +
      +import scala.collection.JavaConverters._
      +import scala.collection.mutable.ArrayBuffer
      +
      +import org.apache.spark.SparkContext
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
+import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
      +
      +/**
+ * Wrapper around GaussianMixtureModel to provide helper methods in Python
+ */
      +private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
      +  val weights: Vector = Vectors.dense(model.weights)
      +  val k: Int = weights.size
      +
      +  /**
+   * Returns gaussians as a list of Vectors and Matrices corresponding to each MultivariateGaussian
+   */
      +  val gaussians: JList[Object] = {
      +    val modelGaussians = model.gaussians
      +    var i = 0
+    val mu = ArrayBuffer.empty[Vector]
+    val sigma = ArrayBuffer.empty[Matrix]
      +    while (i < k) {
      +      mu += modelGaussians(i).mu
      +      sigma += modelGaussians(i).sigma
      +      i += 1
      +    }
      +    List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
      +  }
      +
      +  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
      +}
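The wrapper above flattens a trained model into Py4J-friendly values. An illustrative sketch follows (conceptual only, since the class is `private[python]`); a trained `GaussianMixtureModel` named `gmm`, a `SparkContext` named `sc`, and the save path are all assumptions:

```scala
// Conceptual sketch of what the Python side reads back from the wrapper.
val wrapper = new GaussianMixtureModelWrapper(gmm)
wrapper.weights                      // dense Vector of the k component weights
wrapper.k                            // number of mixture components
wrapper.gaussians                    // java.util.List(Array of mu vectors, Array of sigma matrices)
wrapper.save(sc, "/tmp/gmm-model")   // hypothetical output path
```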
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala
      new file mode 100644
      index 000000000000..bc6041b22173
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala
      @@ -0,0 +1,32 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.api.python
      +
+import org.apache.spark.mllib.clustering.PowerIterationClusteringModel
+import org.apache.spark.rdd.RDD
      +
      +/**
+ * A wrapper around PowerIterationClusteringModel to provide helper methods for Python
      + */
      +private[python] class PowerIterationClusteringModelWrapper(model: PowerIterationClusteringModel)
      +  extends PowerIterationClusteringModel(model.k, model.assignments) {
      +
      +  def getAssignments: RDD[Array[Any]] = {
      +    model.assignments.map(x => Array(x.id, x.cluster))
      +  }
      +}
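A short sketch of what the Python side receives from this wrapper (illustrative only, since the class is `private[python]`); the trained `PowerIterationClusteringModel` named `picModel` is an assumption:

```scala
// Each assignment is flattened into Array(id, cluster) so it can be pickled row by row.
val wrapped = new PowerIterationClusteringModelWrapper(picModel)
wrapped.getAssignments.take(2)   // e.g. Array(Array(0L, 1), Array(1L, 0))
```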
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
      index 634d56d08d17..69ce7f50709a 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
      @@ -28,6 +28,7 @@ import scala.reflect.ClassTag
       
       import net.razorvine.pickle._
       
      +import org.apache.spark.SparkContext
       import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
       import org.apache.spark.api.python.SerDeUtil
       import org.apache.spark.mllib.classification._
      @@ -36,13 +37,14 @@ import org.apache.spark.mllib.evaluation.RankingMetrics
       import org.apache.spark.mllib.feature._
       import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
       import org.apache.spark.mllib.linalg._
      +import org.apache.spark.mllib.linalg.distributed._
       import org.apache.spark.mllib.optimization._
       import org.apache.spark.mllib.random.{RandomRDDs => RG}
       import org.apache.spark.mllib.recommendation._
       import org.apache.spark.mllib.regression._
       import org.apache.spark.mllib.stat.correlation.CorrelationNames
       import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
      -import org.apache.spark.mllib.stat.test.ChiSqTestResult
      +import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult}
       import org.apache.spark.mllib.stat.{
         KernelDensity, MultivariateStatisticalSummary, Statistics}
       import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
      @@ -51,8 +53,9 @@ import org.apache.spark.mllib.tree.loss.Losses
       import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
       import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
       import org.apache.spark.mllib.util.MLUtils
      +import org.apache.spark.mllib.util.LinearDataGenerator
       import org.apache.spark.rdd.RDD
      -import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.{DataFrame, Row, SQLContext}
       import org.apache.spark.storage.StorageLevel
       import org.apache.spark.util.Utils
       
      @@ -74,6 +77,15 @@ private[python] class PythonMLLibAPI extends Serializable {
             minPartitions: Int): JavaRDD[LabeledPoint] =
           MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
       
      +  /**
      +   * Loads and serializes vectors saved with `RDD#saveAsTextFile`.
      +   * @param jsc Java SparkContext
      +   * @param path file or directory path in any Hadoop-supported file system URI
+   * @return serialized vectors in an RDD
      +   */
      +  def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] =
      +    MLUtils.loadVectors(jsc.sc, path)
      +
         private def trainRegressionModel(
             learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
             data: JavaRDD[LabeledPoint],
      @@ -120,7 +132,8 @@ private[python] class PythonMLLibAPI extends Serializable {
             regParam: Double,
             regType: String,
             intercept: Boolean,
      -      validateData: Boolean): JList[Object] = {
      +      validateData: Boolean,
      +      convergenceTol: Double): JList[Object] = {
           val lrAlg = new LinearRegressionWithSGD()
           lrAlg.setIntercept(intercept)
             .setValidateData(validateData)
      @@ -129,6 +142,7 @@ private[python] class PythonMLLibAPI extends Serializable {
             .setRegParam(regParam)
             .setStepSize(stepSize)
             .setMiniBatchFraction(miniBatchFraction)
      +      .setConvergenceTol(convergenceTol)
           lrAlg.optimizer.setUpdater(getUpdaterFromString(regType))
           trainRegressionModel(
             lrAlg,
      @@ -147,7 +161,8 @@ private[python] class PythonMLLibAPI extends Serializable {
             miniBatchFraction: Double,
             initialWeights: Vector,
             intercept: Boolean,
      -      validateData: Boolean): JList[Object] = {
      +      validateData: Boolean,
      +      convergenceTol: Double): JList[Object] = {
           val lassoAlg = new LassoWithSGD()
           lassoAlg.setIntercept(intercept)
             .setValidateData(validateData)
      @@ -156,6 +171,7 @@ private[python] class PythonMLLibAPI extends Serializable {
             .setRegParam(regParam)
             .setStepSize(stepSize)
             .setMiniBatchFraction(miniBatchFraction)
      +      .setConvergenceTol(convergenceTol)
           trainRegressionModel(
             lassoAlg,
             data,
      @@ -173,7 +189,8 @@ private[python] class PythonMLLibAPI extends Serializable {
             miniBatchFraction: Double,
             initialWeights: Vector,
             intercept: Boolean,
      -      validateData: Boolean): JList[Object] = {
      +      validateData: Boolean,
      +      convergenceTol: Double): JList[Object] = {
           val ridgeAlg = new RidgeRegressionWithSGD()
           ridgeAlg.setIntercept(intercept)
             .setValidateData(validateData)
      @@ -182,6 +199,7 @@ private[python] class PythonMLLibAPI extends Serializable {
             .setRegParam(regParam)
             .setStepSize(stepSize)
             .setMiniBatchFraction(miniBatchFraction)
      +      .setConvergenceTol(convergenceTol)
           trainRegressionModel(
             ridgeAlg,
             data,
      @@ -200,7 +218,8 @@ private[python] class PythonMLLibAPI extends Serializable {
             initialWeights: Vector,
             regType: String,
             intercept: Boolean,
      -      validateData: Boolean): JList[Object] = {
      +      validateData: Boolean,
      +      convergenceTol: Double): JList[Object] = {
           val SVMAlg = new SVMWithSGD()
           SVMAlg.setIntercept(intercept)
             .setValidateData(validateData)
      @@ -209,6 +228,7 @@ private[python] class PythonMLLibAPI extends Serializable {
             .setRegParam(regParam)
             .setStepSize(stepSize)
             .setMiniBatchFraction(miniBatchFraction)
      +      .setConvergenceTol(convergenceTol)
           SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType))
           trainRegressionModel(
             SVMAlg,
      @@ -228,7 +248,8 @@ private[python] class PythonMLLibAPI extends Serializable {
             regParam: Double,
             regType: String,
             intercept: Boolean,
      -      validateData: Boolean): JList[Object] = {
      +      validateData: Boolean,
      +      convergenceTol: Double): JList[Object] = {
           val LogRegAlg = new LogisticRegressionWithSGD()
           LogRegAlg.setIntercept(intercept)
             .setValidateData(validateData)
      @@ -237,6 +258,7 @@ private[python] class PythonMLLibAPI extends Serializable {
             .setRegParam(regParam)
             .setStepSize(stepSize)
             .setMiniBatchFraction(miniBatchFraction)
      +      .setConvergenceTol(convergenceTol)
           LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
           trainRegressionModel(
             LogRegAlg,
      @@ -277,7 +299,7 @@ private[python] class PythonMLLibAPI extends Serializable {
         /**
          * Java stub for NaiveBayes.train()
          */
      -  def trainNaiveBayes(
      +  def trainNaiveBayesModel(
             data: JavaRDD[LabeledPoint],
             lambda: Double): JList[Object] = {
           val model = NaiveBayes.train(data.rdd, lambda)
      @@ -345,7 +367,7 @@ private[python] class PythonMLLibAPI extends Serializable {
          * Java stub for Python mllib GaussianMixture.run()
          * Returns a list containing weights, mean and covariance of each mixture component.
          */
      -  def trainGaussianMixture(
      +  def trainGaussianMixtureModel(
             data: JavaRDD[Vector],
             k: Int,
             convergenceTol: Double,
      @@ -353,7 +375,7 @@ private[python] class PythonMLLibAPI extends Serializable {
             seed: java.lang.Long,
             initialModelWeights: java.util.ArrayList[Double],
             initialModelMu: java.util.ArrayList[Vector],
      -      initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
      +      initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
           val gmmAlg = new GaussianMixture()
             .setK(k)
             .setConvergenceTol(convergenceTol)
      @@ -371,16 +393,7 @@ private[python] class PythonMLLibAPI extends Serializable {
           if (seed != null) gmmAlg.setSeed(seed)
       
           try {
      -      val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
      -      var wt = ArrayBuffer.empty[Double]
      -      var mu = ArrayBuffer.empty[Vector]
      -      var sigma = ArrayBuffer.empty[Matrix]
      -      for (i <- 0 until model.k) {
      -          wt += model.weights(i)
      -          mu += model.gaussians(i).mu
      -          sigma += model.gaussians(i).sigma
      -      }
      -      List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
      +      new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
           } finally {
             data.rdd.unpersist(blocking = false)
           }
      @@ -405,6 +418,33 @@ private[python] class PythonMLLibAPI extends Serializable {
             model.predictSoft(data).map(Vectors.dense)
         }
       
      +  /**
      +   * Java stub for Python mllib PowerIterationClustering.run(). This stub returns a
      +   * handle to the Java object instead of the content of the Java object.  Extra care
      +   * needs to be taken in the Python code to ensure it gets freed on exit; see the
      +   * Py4J documentation.
      +   * @param data an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix.
      +   * @param k number of clusters.
      +   * @param maxIterations maximum number of iterations of the power iteration loop.
      +   * @param initMode the initialization mode. This can be either "random" to use
      +   *                 a random vector as vertex properties, or "degree" to use
      +   *                 normalized sum similarities. Default: random.
      +   */
      +  def trainPowerIterationClusteringModel(
      +      data: JavaRDD[Vector],
      +      k: Int,
      +      maxIterations: Int,
      +      initMode: String): PowerIterationClusteringModel = {
      +
      +    val pic = new PowerIterationClustering()
      +      .setK(k)
      +      .setMaxIterations(maxIterations)
      +      .setInitializationMode(initMode)
      +
      +    val model = pic.run(data.rdd.map(v => (v(0).toLong, v(1).toLong, v(2))))
      +    new PowerIterationClusteringModelWrapper(model)
      +  }
      +
         /**
          * Java stub for Python mllib ALS.train().  This stub returns a handle
          * to the Java object instead of the content of the Java object.  Extra care
      @@ -464,6 +504,39 @@ private[python] class PythonMLLibAPI extends Serializable {
           new MatrixFactorizationModelWrapper(model)
         }
       
      +  /**
      +   * Java stub for Python mllib LDA.run()
      +   */
      +  def trainLDAModel(
      +      data: JavaRDD[java.util.List[Any]],
      +      k: Int,
      +      maxIterations: Int,
      +      docConcentration: Double,
      +      topicConcentration: Double,
      +      seed: java.lang.Long,
      +      checkpointInterval: Int,
      +      optimizer: String): LDAModel = {
      +    val algo = new LDA()
      +      .setK(k)
      +      .setMaxIterations(maxIterations)
      +      .setDocConcentration(docConcentration)
      +      .setTopicConcentration(topicConcentration)
      +      .setCheckpointInterval(checkpointInterval)
      +      .setOptimizer(optimizer)
      +
      +    if (seed != null) algo.setSeed(seed)
      +
      +    val documents = data.rdd.map(_.asScala.toArray).map { r =>
      +      r(0) match {
      +        case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
      +        case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
+        case _ => throw new IllegalArgumentException("Input document IDs must be of type Integer or Long.")
      +      }
      +    }
      +    algo.run(documents)
      +  }
      +
      +
         /**
          * Java stub for Python mllib FPGrowth.train().  This stub returns a handle
          * to the Java object instead of the content of the Java object.  Extra care
      @@ -552,7 +625,7 @@ private[python] class PythonMLLibAPI extends Serializable {
          * @param seed initial seed for random generator
          * @return A handle to java Word2VecModelWrapper instance at python side
          */
      -  def trainWord2Vec(
      +  def trainWord2VecModel(
             dataJRDD: JavaRDD[java.util.ArrayList[String]],
             vectorSize: Int,
             learningRate: Double,
      @@ -604,6 +677,8 @@ private[python] class PythonMLLibAPI extends Serializable {
           def getVectors: JMap[String, JList[Float]] = {
             model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
           }
      +
      +    def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
         }
       
         /**
      @@ -696,12 +771,14 @@ private[python] class PythonMLLibAPI extends Serializable {
             lossStr: String,
             numIterations: Int,
             learningRate: Double,
      -      maxDepth: Int): GradientBoostedTreesModel = {
      +      maxDepth: Int,
      +      maxBins: Int): GradientBoostedTreesModel = {
           val boostingStrategy = BoostingStrategy.defaultParams(algoStr)
           boostingStrategy.setLoss(Losses.fromString(lossStr))
           boostingStrategy.setNumIterations(numIterations)
           boostingStrategy.setLearningRate(learningRate)
           boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
      +    boostingStrategy.treeStrategy.setMaxBins(maxBins)
           boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap
       
           val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
      @@ -970,7 +1047,7 @@ private[python] class PythonMLLibAPI extends Serializable {
         def estimateKernelDensity(
             sample: JavaRDD[Double],
             bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
      -    return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
      +    new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
             points.asScala.toArray)
         }
       
      @@ -989,6 +1066,122 @@ private[python] class PythonMLLibAPI extends Serializable {
             List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava
         }
       
      +  /**
      +   * Wrapper around the generateLinearInput method of LinearDataGenerator.
      +   */
      +  def generateLinearInputWrapper(
      +      intercept: Double,
      +      weights: JList[Double],
      +      xMean: JList[Double],
      +      xVariance: JList[Double],
      +      nPoints: Int,
      +      seed: Int,
      +      eps: Double): Array[LabeledPoint] = {
      +    LinearDataGenerator.generateLinearInput(
      +      intercept, weights.asScala.toArray, xMean.asScala.toArray,
      +      xVariance.asScala.toArray, nPoints, seed, eps).toArray
      +  }
      +
      +  /**
      +   * Wrapper around the generateLinearRDD method of LinearDataGenerator.
      +   */
      +  def generateLinearRDDWrapper(
      +      sc: JavaSparkContext,
      +      nexamples: Int,
      +      nfeatures: Int,
      +      eps: Double,
      +      nparts: Int,
      +      intercept: Double): JavaRDD[LabeledPoint] = {
      +    LinearDataGenerator.generateLinearRDD(
      +      sc, nexamples, nfeatures, eps, nparts, intercept)
      +  }
      +
      +  /**
      +   * Java stub for Statistics.kolmogorovSmirnovTest()
      +   */
      +  def kolmogorovSmirnovTest(
      +      data: JavaRDD[Double],
      +      distName: String,
      +      params: JList[Double]): KolmogorovSmirnovTestResult = {
      +    val paramsSeq = params.asScala.toSeq
      +    Statistics.kolmogorovSmirnovTest(data, distName, paramsSeq: _*)
      +  }
      +
      +  /**
      +   * Wrapper around RowMatrix constructor.
      +   */
      +  def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = {
      +    new RowMatrix(rows.rdd, numRows, numCols)
      +  }
      +
      +  /**
      +   * Wrapper around IndexedRowMatrix constructor.
      +   */
      +  def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = {
      +    // We use DataFrames for serialization of IndexedRows from Python,
      +    // so map each Row in the DataFrame back to an IndexedRow.
      +    val indexedRows = rows.map {
      +      case Row(index: Long, vector: Vector) => IndexedRow(index, vector)
      +    }
      +    new IndexedRowMatrix(indexedRows, numRows, numCols)
      +  }
      +
      +  /**
      +   * Wrapper around CoordinateMatrix constructor.
      +   */
      +  def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = {
      +    // We use DataFrames for serialization of MatrixEntry entries from
      +    // Python, so map each Row in the DataFrame back to a MatrixEntry.
      +    val entries = rows.map {
      +      case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value)
      +    }
      +    new CoordinateMatrix(entries, numRows, numCols)
      +  }
      +
      +  /**
      +   * Wrapper around BlockMatrix constructor.
      +   */
      +  def createBlockMatrix(blocks: DataFrame, rowsPerBlock: Int, colsPerBlock: Int,
      +                        numRows: Long, numCols: Long): BlockMatrix = {
      +    // We use DataFrames for serialization of sub-matrix blocks from
      +    // Python, so map each Row in the DataFrame back to a
      +    // ((blockRowIndex, blockColIndex), sub-matrix) tuple.
      +    val blockTuples = blocks.map {
      +      case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) =>
      +        ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix)
      +    }
      +    new BlockMatrix(blockTuples, rowsPerBlock, colsPerBlock, numRows, numCols)
      +  }
      +
      +  /**
      +   * Return the rows of an IndexedRowMatrix.
      +   */
      +  def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = {
      +    // We use DataFrames for serialization of IndexedRows to Python,
      +    // so return a DataFrame.
      +    val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext)
      +    sqlContext.createDataFrame(indexedRowMatrix.rows)
      +  }
      +
      +  /**
      +   * Return the entries of a CoordinateMatrix.
      +   */
      +  def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = {
      +    // We use DataFrames for serialization of MatrixEntry entries to
      +    // Python, so return a DataFrame.
      +    val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext)
      +    sqlContext.createDataFrame(coordinateMatrix.entries)
      +  }
      +
      +  /**
      +   * Return the sub-matrix blocks of a BlockMatrix.
      +   */
      +  def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = {
      +    // We use DataFrames for serialization of sub-matrix blocks to
      +    // Python, so return a DataFrame.
      +    val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext)
      +    sqlContext.createDataFrame(blockMatrix.blocks)
      +  }
       }
       
       /**
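The new distributed-matrix helpers all rely on the same pattern: rows, entries, or blocks are shuttled between Python and Scala as DataFrame rows and pattern-matched back into their MLlib types. A standalone sketch of that round trip for `IndexedRowMatrix`, assuming an existing `sc: SparkContext`, an existing `sqlContext: SQLContext`, and made-up data:

```scala
// Sketch of the DataFrame round trip used by the matrix helpers above.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
import org.apache.spark.sql.Row

val indexedRows = sc.parallelize(Seq(
  IndexedRow(0L, Vectors.dense(1.0, 2.0)),
  IndexedRow(1L, Vectors.dense(3.0, 4.0))))
val matrix = new IndexedRowMatrix(indexedRows, 2L, 2)

// Scala -> Python: wrap the rows in a DataFrame (as getIndexedRows does).
val df = sqlContext.createDataFrame(matrix.rows)

// Python -> Scala: map each Row back to an IndexedRow (as createIndexedRowMatrix does).
val recovered = df.map {
  case Row(index: Long, vector: org.apache.spark.mllib.linalg.Vector) => IndexedRow(index, vector)
}
```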
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
      index 35a0db76f3a8..85a413243b04 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.classification
       
       import org.json4s.{DefaultFormats, JValue}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.rdd.RDD
      @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
        * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc.
        */
       @Experimental
      +@Since("0.8.0")
       trait ClassificationModel extends Serializable {
         /**
          * Predict values for the given data set using the model trained.
      @@ -37,6 +38,7 @@ trait ClassificationModel extends Serializable {
          * @param testData RDD representing data points to be predicted
          * @return an RDD[Double] where each entry contains the corresponding prediction
          */
      +  @Since("1.0.0")
         def predict(testData: RDD[Vector]): RDD[Double]
       
         /**
      @@ -45,6 +47,7 @@ trait ClassificationModel extends Serializable {
          * @param testData array representing a single data point
          * @return predicted category from the trained model
          */
      +  @Since("1.0.0")
         def predict(testData: Vector): Double
       
         /**
      @@ -52,6 +55,7 @@ trait ClassificationModel extends Serializable {
          * @param testData JavaRDD representing data points to be predicted
          * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
          */
      +  @Since("1.0.0")
         def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
           predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
       }
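To make the trait's contract concrete, here is a hedged sketch (not from the patch) of a trivial implementer; only the two abstract `predict` methods need to be supplied, while the `JavaRDD` overload is inherited:

```scala
// Hypothetical constant predictor, for illustration only.
import org.apache.spark.mllib.classification.ClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

class ConstantClassifier(label: Double) extends ClassificationModel {
  override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(_ => label)
  override def predict(testData: Vector): Double = label
  // predict(testData: JavaRDD[Vector]) comes from the trait's default implementation
}
```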
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
      index 2df4d21e8cd5..5ceff5b2259e 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
      @@ -18,7 +18,7 @@
       package org.apache.spark.mllib.classification
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.classification.impl.GLMClassificationModel
       import org.apache.spark.mllib.linalg.BLAS.dot
       import org.apache.spark.mllib.linalg.{DenseVector, Vector}
      @@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
        *                   Multinomial Logistic Regression. By default, it is binary logistic regression
        *                   so numClasses will be set to 2.
        */
      -class LogisticRegressionModel (
      -    override val weights: Vector,
      -    override val intercept: Double,
      -    val numFeatures: Int,
      -    val numClasses: Int)
      +@Since("0.8.0")
      +class LogisticRegressionModel @Since("1.3.0") (
      +    @Since("1.0.0") override val weights: Vector,
      +    @Since("1.0.0") override val intercept: Double,
      +    @Since("1.3.0") val numFeatures: Int,
      +    @Since("1.3.0") val numClasses: Int)
         extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
         with Saveable with PMMLExportable {
       
      @@ -75,6 +76,7 @@ class LogisticRegressionModel (
         /**
          * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification.
          */
      +  @Since("1.0.0")
         def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)
       
         private var threshold: Option[Double] = Some(0.5)
      @@ -86,6 +88,7 @@ class LogisticRegressionModel (
          * this threshold is identified as an positive, and negative otherwise. The default value is 0.5.
          * It is only used for binary classification.
          */
      +  @Since("1.0.0")
         @Experimental
         def setThreshold(threshold: Double): this.type = {
           this.threshold = Some(threshold)
      @@ -97,6 +100,7 @@ class LogisticRegressionModel (
          * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
          * It is only used for binary classification.
          */
      +  @Since("1.3.0")
         @Experimental
         def getThreshold: Option[Double] = threshold
       
      @@ -105,6 +109,7 @@ class LogisticRegressionModel (
          * Clears the threshold so that `predict` will output raw prediction scores.
          * It is only used for binary classification.
          */
      +  @Since("1.0.0")
         @Experimental
         def clearThreshold(): this.type = {
           threshold = None
      @@ -155,6 +160,7 @@ class LogisticRegressionModel (
           }
         }
       
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
             numFeatures, numClasses, weights, intercept, threshold)
      @@ -167,8 +173,10 @@ class LogisticRegressionModel (
         }
       }
       
      +@Since("1.3.0")
       object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
       
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
           val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
           // Hard-code class name string in case it changes in the future
      @@ -201,6 +209,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
        * for k classes multi-label classification problem.
        * Using [[LogisticRegressionWithLBFGS]] is recommended over this.
        */
      +@Since("0.8.0")
       class LogisticRegressionWithSGD private[mllib] (
           private var stepSize: Double,
           private var numIterations: Int,
      @@ -210,6 +219,7 @@ class LogisticRegressionWithSGD private[mllib] (
       
         private val gradient = new LogisticGradient()
         private val updater = new SquaredL2Updater()
      +  @Since("0.8.0")
         override val optimizer = new GradientDescent(gradient, updater)
           .setStepSize(stepSize)
           .setNumIterations(numIterations)
      @@ -221,6 +231,7 @@ class LogisticRegressionWithSGD private[mllib] (
          * Construct a LogisticRegression object with default parameters: {stepSize: 1.0,
          * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}.
          */
      +  @Since("0.8.0")
         def this() = this(1.0, 100, 0.01, 1.0)
       
         override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
      @@ -232,6 +243,7 @@ class LogisticRegressionWithSGD private[mllib] (
        * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent.
        * NOTE: Labels used in Logistic Regression should be {0, 1}
        */
      +@Since("0.8.0")
       object LogisticRegressionWithSGD {
         // NOTE(shivaram): We use multiple train methods instead of default arguments to support
         // Java programs.
      @@ -250,6 +262,7 @@ object LogisticRegressionWithSGD {
          * @param initialWeights Initial set of weights to be used. Array should be equal in size to
          *        the number of features in the data.
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -272,6 +285,7 @@ object LogisticRegressionWithSGD {
       
          * @param miniBatchFraction Fraction of data to be used per iteration.
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -293,6 +307,7 @@ object LogisticRegressionWithSGD {
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a LogisticRegressionModel which has the weights and offset from training.
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -310,6 +325,7 @@ object LogisticRegressionWithSGD {
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a LogisticRegressionModel which has the weights and offset from training.
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int): LogisticRegressionModel = {
      @@ -323,11 +339,13 @@ object LogisticRegressionWithSGD {
        * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
        * for k classes multi-label classification problem.
        */
      +@Since("1.1.0")
       class LogisticRegressionWithLBFGS
         extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
       
         this.setFeatureScaling(true)
       
      +  @Since("1.1.0")
         override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
       
         override protected val validators = List(multiLabelValidator)
      @@ -346,6 +364,7 @@ class LogisticRegressionWithLBFGS
          * Multinomial Logistic Regression.
          * By default, it is binary logistic regression so k will be set to 2.
          */
      +  @Since("1.3.0")
         @Experimental
         def setNumClasses(numClasses: Int): this.type = {
           require(numClasses > 1)
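A usage sketch of the `setNumClasses` path annotated above; `training: RDD[LabeledPoint]` and `features: Vector` are assumed to exist:

```scala
// Sketch: multinomial logistic regression via LBFGS, as documented above.
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS

val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(10)          // multinomial mode; the default is binary (2 classes)
  .run(training)

val predictedClass = model.predict(features)   // one of 0.0, 1.0, ..., 9.0
```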
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
      index f51ee36d0dfc..a956084ae06e 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
      @@ -25,6 +25,7 @@ import org.json4s.JsonDSL._
       import org.json4s.jackson.JsonMethods._
       
       import org.apache.spark.{Logging, SparkContext, SparkException}
      +import org.apache.spark.annotation.Since
       import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.util.{Loader, Saveable}
      @@ -40,11 +41,12 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
        *              where D is number of features
        * @param modelType The type of NB model to fit  can be "multinomial" or "bernoulli"
        */
      -class NaiveBayesModel private[mllib] (
      -    val labels: Array[Double],
      -    val pi: Array[Double],
      -    val theta: Array[Array[Double]],
      -    val modelType: String)
      +@Since("0.9.0")
      +class NaiveBayesModel private[spark] (
      +    @Since("1.0.0") val labels: Array[Double],
      +    @Since("0.9.0") val pi: Array[Double],
      +    @Since("0.9.0") val theta: Array[Array[Double]],
      +    @Since("1.4.0") val modelType: String)
         extends ClassificationModel with Serializable with Saveable {
       
         import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
      @@ -82,6 +84,7 @@ class NaiveBayesModel private[mllib] (
             throw new UnknownError(s"Invalid modelType: $modelType.")
         }
       
      +  @Since("1.0.0")
         override def predict(testData: RDD[Vector]): RDD[Double] = {
           val bcModel = testData.context.broadcast(this)
           testData.mapPartitions { iter =>
      @@ -90,29 +93,77 @@ class NaiveBayesModel private[mllib] (
           }
         }
       
      +  @Since("1.0.0")
         override def predict(testData: Vector): Double = {
           modelType match {
             case Multinomial =>
      -        val prob = thetaMatrix.multiply(testData)
      -        BLAS.axpy(1.0, piVector, prob)
      -        labels(prob.argmax)
      +        labels(multinomialCalculation(testData).argmax)
             case Bernoulli =>
      -        testData.foreachActive { (index, value) =>
      -          if (value != 0.0 && value != 1.0) {
      -            throw new SparkException(
      -              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
      -          }
      -        }
      -        val prob = thetaMinusNegTheta.get.multiply(testData)
      -        BLAS.axpy(1.0, piVector, prob)
      -        BLAS.axpy(1.0, negThetaSum.get, prob)
      -        labels(prob.argmax)
      -      case _ =>
      -        // This should never happen.
      -        throw new UnknownError(s"Invalid modelType: $modelType.")
      +        labels(bernoulliCalculation(testData).argmax)
      +    }
      +  }
      +
      +  /**
      +   * Predict values for the given data set using the model trained.
      +   *
      +   * @param testData RDD representing data points to be predicted
      +   * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
      +   *         in the same order as class labels
      +   */
      +  @Since("1.5.0")
      +  def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
      +    val bcModel = testData.context.broadcast(this)
      +    testData.mapPartitions { iter =>
      +      val model = bcModel.value
      +      iter.map(model.predictProbabilities)
      +    }
      +  }
      +
      +  /**
      +   * Predict posterior class probabilities for a single data point using the model trained.
      +   *
      +   * @param testData array representing a single data point
      +   * @return predicted posterior class probabilities from the trained model,
      +   *         in the same order as class labels
      +   */
      +  @Since("1.5.0")
      +  def predictProbabilities(testData: Vector): Vector = {
      +    modelType match {
      +      case Multinomial =>
      +        posteriorProbabilities(multinomialCalculation(testData))
      +      case Bernoulli =>
      +        posteriorProbabilities(bernoulliCalculation(testData))
           }
         }
       
      +  private def multinomialCalculation(testData: Vector) = {
      +    val prob = thetaMatrix.multiply(testData)
      +    BLAS.axpy(1.0, piVector, prob)
      +    prob
      +  }
      +
      +  private def bernoulliCalculation(testData: Vector) = {
      +    testData.foreachActive((_, value) =>
      +      if (value != 0.0 && value != 1.0) {
      +        throw new SparkException(
      +          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
      +      }
      +    )
      +    val prob = thetaMinusNegTheta.get.multiply(testData)
      +    BLAS.axpy(1.0, piVector, prob)
      +    BLAS.axpy(1.0, negThetaSum.get, prob)
      +    prob
      +  }
      +
      +  private def posteriorProbabilities(logProb: DenseVector) = {
      +    val logProbArray = logProb.toArray
      +    val maxLog = logProbArray.max
      +    val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog))
      +    val probSum = scaledProbs.sum
      +    new DenseVector(scaledProbs.map(_ / probSum))
      +  }
      +
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
           NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
      @@ -121,6 +172,7 @@ class NaiveBayesModel private[mllib] (
         override protected def formatVersion: String = "2.0"
       }
       
      +@Since("1.3.0")
       object NaiveBayesModel extends Loader[NaiveBayesModel] {
       
         import org.apache.spark.mllib.util.Loader._
      @@ -154,6 +206,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
             dataRDD.write.parquet(dataPath(path))
           }
       
      +    @Since("1.3.0")
           def load(sc: SparkContext, path: String): NaiveBayesModel = {
             val sqlContext = new SQLContext(sc)
             // Load Parquet data.
      @@ -256,30 +309,35 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
        * document classification.  By making every vector a 0-1 vector, it can also be used as
        * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
        */
      -
      +@Since("0.9.0")
       class NaiveBayes private (
           private var lambda: Double,
           private var modelType: String) extends Serializable with Logging {
       
         import NaiveBayes.{Bernoulli, Multinomial}
       
      +  @Since("1.4.0")
         def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
       
      +  @Since("0.9.0")
         def this() = this(1.0, NaiveBayes.Multinomial)
       
         /** Set the smoothing parameter. Default: 1.0. */
      +  @Since("0.9.0")
         def setLambda(lambda: Double): NaiveBayes = {
           this.lambda = lambda
           this
         }
       
         /** Get the smoothing parameter. */
      +  @Since("1.4.0")
         def getLambda: Double = lambda
       
         /**
          * Set the model type using a string (case-sensitive).
          * Supported options: "multinomial" (default) and "bernoulli".
          */
      +  @Since("1.4.0")
         def setModelType(modelType: String): NaiveBayes = {
           require(NaiveBayes.supportedModelTypes.contains(modelType),
             s"NaiveBayes was created with an unknown modelType: $modelType.")
      @@ -288,6 +346,7 @@ class NaiveBayes private (
         }
       
         /** Get the model type. */
      +  @Since("1.4.0")
         def getModelType: String = this.modelType
       
         /**
      @@ -295,6 +354,7 @@ class NaiveBayes private (
          *
          * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
          */
      +  @Since("0.9.0")
         def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
           val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
             val values = v match {
      @@ -338,7 +398,7 @@ class NaiveBayes private (
               BLAS.axpy(1.0, c2._2, c1._2)
               (c1._1 + c2._1, c1._2)
             }
      -    ).collect()
      +    ).collect().sortBy(_._1)
       
           val numLabels = aggregated.length
           var numDocuments = 0L
      @@ -378,16 +438,17 @@ class NaiveBayes private (
       /**
        * Top-level methods for calling naive Bayes.
        */
      +@Since("0.9.0")
       object NaiveBayes {
       
         /** String name for multinomial model type. */
      -  private[classification] val Multinomial: String = "multinomial"
      +  private[spark] val Multinomial: String = "multinomial"
       
         /** String name for Bernoulli model type. */
      -  private[classification] val Bernoulli: String = "bernoulli"
      +  private[spark] val Bernoulli: String = "bernoulli"
       
         /* Set of modelTypes that NaiveBayes supports */
      -  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
      +  private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
       
         /**
          * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
      @@ -401,6 +462,7 @@ object NaiveBayes {
          * @param input RDD of `(label, array of features)` pairs.  Every vector should be a frequency
          *              vector or a count vector.
          */
      +  @Since("0.9.0")
         def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
           new NaiveBayes().run(input)
         }
      @@ -416,6 +478,7 @@ object NaiveBayes {
          *              vector or a count vector.
          * @param lambda The smoothing parameter
          */
      +  @Since("0.9.0")
         def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
           new NaiveBayes(lambda, Multinomial).run(input)
         }
      @@ -438,6 +501,7 @@ object NaiveBayes {
          * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
          *              multinomial or bernoulli
          */
      +  @Since("1.4.0")
         def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
           require(supportedModelTypes.contains(modelType),
             s"NaiveBayes was created with an unknown modelType: $modelType.")
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
      index 348485560713..896565cd90e8 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
      @@ -18,7 +18,7 @@
       package org.apache.spark.mllib.classification
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.classification.impl.GLMClassificationModel
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.optimization._
      @@ -33,9 +33,10 @@ import org.apache.spark.rdd.RDD
        * @param weights Weights computed for every feature.
        * @param intercept Intercept computed for this model.
        */
      -class SVMModel (
      -    override val weights: Vector,
      -    override val intercept: Double)
      +@Since("0.8.0")
      +class SVMModel @Since("1.1.0") (
      +    @Since("1.0.0") override val weights: Vector,
      +    @Since("0.8.0") override val intercept: Double)
         extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
         with Saveable with PMMLExportable {
       
      @@ -47,6 +48,7 @@ class SVMModel (
          * with prediction score greater than or equal to this threshold is identified as an positive,
          * and negative otherwise. The default value is 0.0.
          */
      +  @Since("1.0.0")
         @Experimental
         def setThreshold(threshold: Double): this.type = {
           this.threshold = Some(threshold)
      @@ -57,6 +59,7 @@ class SVMModel (
          * :: Experimental ::
          * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
          */
      +  @Since("1.3.0")
         @Experimental
         def getThreshold: Option[Double] = threshold
       
      @@ -64,6 +67,7 @@ class SVMModel (
          * :: Experimental ::
          * Clears the threshold so that `predict` will output raw prediction scores.
          */
      +  @Since("1.0.0")
         @Experimental
         def clearThreshold(): this.type = {
           threshold = None
      @@ -81,6 +85,7 @@ class SVMModel (
           }
         }
       
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
             numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
      @@ -93,8 +98,10 @@ class SVMModel (
         }
       }
       
      +@Since("1.3.0")
       object SVMModel extends Loader[SVMModel] {
       
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): SVMModel = {
           val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
           // Hard-code class name string in case it changes in the future
      @@ -126,6 +133,7 @@ object SVMModel extends Loader[SVMModel] {
        * regularization is used, which can be changed via [[SVMWithSGD.optimizer]].
        * NOTE: Labels used in SVM should be {0, 1}.
        */
      +@Since("0.8.0")
       class SVMWithSGD private (
           private var stepSize: Double,
           private var numIterations: Int,
      @@ -135,6 +143,7 @@ class SVMWithSGD private (
       
         private val gradient = new HingeGradient()
         private val updater = new SquaredL2Updater()
      +  @Since("0.8.0")
         override val optimizer = new GradientDescent(gradient, updater)
           .setStepSize(stepSize)
           .setNumIterations(numIterations)
      @@ -146,6 +155,7 @@ class SVMWithSGD private (
          * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100,
          * regParm: 0.01, miniBatchFraction: 1.0}.
          */
      +  @Since("0.8.0")
         def this() = this(1.0, 100, 0.01, 1.0)
       
         override protected def createModel(weights: Vector, intercept: Double) = {
      @@ -156,6 +166,7 @@ class SVMWithSGD private (
       /**
        * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}.
        */
      +@Since("0.8.0")
       object SVMWithSGD {
       
         /**
      @@ -174,6 +185,7 @@ object SVMWithSGD {
          * @param initialWeights Initial set of weights to be used. Array should be equal in size to
          *        the number of features in the data.
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -197,6 +209,7 @@ object SVMWithSGD {
          * @param regParam Regularization parameter.
          * @param miniBatchFraction Fraction of data to be used per iteration.
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -218,6 +231,7 @@ object SVMWithSGD {
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a SVMModel which has the weights and offset from training.
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -236,6 +250,7 @@ object SVMWithSGD {
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a SVMModel which has the weights and offset from training.
          */
      +  @Since("0.8.0")
         def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
           train(input, numIterations, 1.0, 0.01, 1.0)
         }
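A usage sketch of the threshold behaviour documented above; `training: RDD[LabeledPoint]` and `features: Vector` are assumed to exist:

```scala
// Sketch: clearThreshold switches predict() from 0/1 labels to raw margins.
import org.apache.spark.mllib.classification.SVMWithSGD

val svm = SVMWithSGD.train(training, 100)   // numIterations = 100, other params at defaults
svm.clearThreshold()                        // predict() now returns the raw score
val margin = svm.predict(features)          // a margin, not a 0/1 label
```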
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
      index 7d33df3221fb..75630054d136 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.classification
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.regression.StreamingLinearAlgorithm
       
      @@ -44,6 +44,7 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm
        * }}}
        */
       @Experimental
      +@Since("1.3.0")
       class StreamingLogisticRegressionWithSGD private[mllib] (
           private var stepSize: Double,
           private var numIterations: Int,
      @@ -58,6 +59,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
          * Initial weights must be set before using trainOn or predictOn
          * (see `StreamingLinearAlgorithm`)
          */
      +  @Since("1.3.0")
         def this() = this(0.1, 50, 1.0, 0.0)
       
         protected val algorithm = new LogisticRegressionWithSGD(
      @@ -66,30 +68,35 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
         protected var model: Option[LogisticRegressionModel] = None
       
         /** Set the step size for gradient descent. Default: 0.1. */
      +  @Since("1.3.0")
         def setStepSize(stepSize: Double): this.type = {
           this.algorithm.optimizer.setStepSize(stepSize)
           this
         }
       
         /** Set the number of iterations of gradient descent to run per update. Default: 50. */
      +  @Since("1.3.0")
         def setNumIterations(numIterations: Int): this.type = {
           this.algorithm.optimizer.setNumIterations(numIterations)
           this
         }
       
         /** Set the fraction of each batch to use for updates. Default: 1.0. */
      +  @Since("1.3.0")
         def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
           this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction)
           this
         }
       
         /** Set the regularization parameter. Default: 0.0. */
      +  @Since("1.3.0")
         def setRegParam(regParam: Double): this.type = {
           this.algorithm.optimizer.setRegParam(regParam)
           this
         }
       
         /** Set the initial weights. Default: [0.0, 0.0]. */
      +  @Since("1.3.0")
         def setInitialWeights(initialWeights: Vector): this.type = {
           this.model = Some(algorithm.createModel(initialWeights, 0.0))
           this
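A configuration sketch (editorial, not part of the patch) for the builder-style setters annotated above; `labeledStream` and `featureStream` are assumed DStreams and are not defined here:

```scala
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors

// Builder-style configuration through the @Since("1.3.0") setters shown above.
val model = new StreamingLogisticRegressionWithSGD()
  .setStepSize(0.5)
  .setNumIterations(10)
  .setMiniBatchFraction(1.0)
  .setRegParam(0.0)
  .setInitialWeights(Vectors.zeros(3)) // must be set before trainOn/predictOn

// model.trainOn(labeledStream)   // labeledStream: DStream[LabeledPoint] (assumed)
// model.predictOn(featureStream) // featureStream: DStream[Vector] (assumed)
```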
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
      index fc509d2ba147..f82bd82c2037 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
      @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq
       
       import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
       import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
      @@ -53,6 +53,7 @@ import org.apache.spark.util.Utils
        * @param maxIterations The maximum number of iterations to perform
        */
       @Experimental
      +@Since("1.3.0")
       class GaussianMixture private (
           private var k: Int,
           private var convergenceTol: Double,
      @@ -63,6 +64,7 @@ class GaussianMixture private (
          * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01,
          * maxIterations: 100, seed: random}.
          */
      +  @Since("1.3.0")
         def this() = this(2, 0.01, 100, Utils.random.nextLong())
       
         // number of samples per cluster to use when initializing Gaussians
      @@ -72,10 +74,12 @@ class GaussianMixture private (
         // default random starting point
         private var initialModel: Option[GaussianMixtureModel] = None
       
      -  /** Set the initial GMM starting point, bypassing the random initialization.
      -   *  You must call setK() prior to calling this method, and the condition
      -   *  (model.k == this.k) must be met; failure will result in an IllegalArgumentException
      +  /**
      +   * Set the initial GMM starting point, bypassing the random initialization.
      +   * You must call setK() prior to calling this method, and the condition
      +   * (model.k == this.k) must be met; failure will result in an IllegalArgumentException
          */
      +  @Since("1.3.0")
         def setInitialModel(model: GaussianMixtureModel): this.type = {
           if (model.k == k) {
             initialModel = Some(model)
      @@ -85,31 +89,47 @@ class GaussianMixture private (
           this
         }
       
      -  /** Return the user supplied initial GMM, if supplied */
      +  /**
      +   * Return the user supplied initial GMM, if supplied
      +   */
      +  @Since("1.3.0")
         def getInitialModel: Option[GaussianMixtureModel] = initialModel
       
      -  /** Set the number of Gaussians in the mixture model.  Default: 2 */
      +  /**
      +   * Set the number of Gaussians in the mixture model.  Default: 2
      +   */
      +  @Since("1.3.0")
         def setK(k: Int): this.type = {
           this.k = k
           this
         }
       
      -  /** Return the number of Gaussians in the mixture model */
      +  /**
      +   * Return the number of Gaussians in the mixture model
      +   */
      +  @Since("1.3.0")
         def getK: Int = k
       
      -  /** Set the maximum number of iterations to run. Default: 100 */
      +  /**
      +   * Set the maximum number of iterations to run. Default: 100
      +   */
      +  @Since("1.3.0")
         def setMaxIterations(maxIterations: Int): this.type = {
           this.maxIterations = maxIterations
           this
         }
       
      -  /** Return the maximum number of iterations to run */
      +  /**
      +   * Return the maximum number of iterations to run
      +   */
      +  @Since("1.3.0")
         def getMaxIterations: Int = maxIterations
       
         /**
          * Set the largest change in log-likelihood at which convergence is
          * considered to have occurred.
          */
      +  @Since("1.3.0")
         def setConvergenceTol(convergenceTol: Double): this.type = {
           this.convergenceTol = convergenceTol
           this
      @@ -119,18 +139,28 @@ class GaussianMixture private (
          * Return the largest change in log-likelihood at which convergence is
          * considered to have occurred.
          */
      +  @Since("1.3.0")
         def getConvergenceTol: Double = convergenceTol
       
      -  /** Set the random seed */
      +  /**
      +   * Set the random seed
      +   */
      +  @Since("1.3.0")
         def setSeed(seed: Long): this.type = {
           this.seed = seed
           this
         }
       
      -  /** Return the random seed */
      +  /**
      +   * Return the random seed
      +   */
      +  @Since("1.3.0")
         def getSeed: Long = seed
       
      -  /** Perform expectation maximization */
      +  /**
      +   * Perform expectation maximization
      +   */
      +  @Since("1.3.0")
         def run(data: RDD[Vector]): GaussianMixtureModel = {
           val sc = data.sparkContext
       
      @@ -140,6 +170,8 @@ class GaussianMixture private (
           // Get length of the input vectors
           val d = breezeData.first().length
       
      +    val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(k, d)
      +
           // Determine initial weights and corresponding Gaussians.
           // If the user supplied an initial GMM, we use those values, otherwise
           // we start with uniform weights, a random mean from the data, and
      @@ -171,14 +203,25 @@ class GaussianMixture private (
             // Create new distributions based on the partial assignments
             // (often referred to as the "M" step in literature)
             val sumWeights = sums.weights.sum
      -      var i = 0
      -      while (i < k) {
      -        val mu = sums.means(i) / sums.weights(i)
      -        BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
      -          Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
      -        weights(i) = sums.weights(i) / sumWeights
      -        gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
      -        i = i + 1
      +
      +      if (shouldDistributeGaussians) {
      +        val numPartitions = math.min(k, 1024)
      +        val tuples =
      +          Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i)))
      +        val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) =>
      +          updateWeightsAndGaussians(mean, sigma, weight, sumWeights)
      +        }.collect().unzip
      +        Array.copy(ws.toArray, 0, weights, 0, ws.length)
      +        Array.copy(gs.toArray, 0, gaussians, 0, gs.length)
      +      } else {
      +        var i = 0
      +        while (i < k) {
      +          val (weight, gaussian) =
      +            updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights)
      +          weights(i) = weight
      +          gaussians(i) = gaussian
      +          i = i + 1
      +        }
             }
       
             llhp = llh // current becomes previous
      @@ -189,9 +232,25 @@ class GaussianMixture private (
           new GaussianMixtureModel(weights, gaussians)
         }
       
      -  /** Java-friendly version of [[run()]] */
      +  /**
      +   * Java-friendly version of [[run()]]
      +   */
      +  @Since("1.3.0")
         def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
       
      +  private def updateWeightsAndGaussians(
      +      mean: BDV[Double],
      +      sigma: BreezeMatrix[Double],
      +      weight: Double,
      +      sumWeights: Double): (Double, MultivariateGaussian) = {
      +    val mu = (mean /= weight)
      +    BLAS.syr(-weight, Vectors.fromBreeze(mu),
      +      Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix])
      +    val newWeight = weight / sumWeights
      +    val newGaussian = new MultivariateGaussian(mu, sigma / weight)
      +    (newWeight, newGaussian)
      +  }
      +
         /** Average of dense breeze vectors */
         private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
           val v = BDV.zeros[Double](x(0).length)
      @@ -211,6 +270,16 @@ class GaussianMixture private (
         }
       }
       
      +private[clustering] object GaussianMixture {
      +  /**
      +   * Heuristic to decide whether to distribute the computation of the [[MultivariateGaussian]]s:
      +   * roughly when d > 25, except when k is very small (small k requires a larger d).
      +   * @param k  Number of Gaussians in the mixture
      +   * @param d  Number of features
      +   */
      +  def shouldDistributeGaussians(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25
      +}
      +
       // companion class to provide zero constructor for ExpectationSum
       private object ExpectationSum {
         def zero(k: Int, d: Int): ExpectationSum = {
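A small worked illustration (editorial, not part of the patch) of the `shouldDistributeGaussians` heuristic added above:

```scala
// Distribute the per-component M-step updates iff ((k - 1.0) / k) * d > 25.
def shouldDistributeGaussians(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25

shouldDistributeGaussians(2, 50)   // 0.5  * 50 = 25.0  -> false
shouldDistributeGaussians(2, 52)   // 0.5  * 52 = 26.0  -> true
shouldDistributeGaussians(10, 30)  // 0.9  * 30 = 27.0  -> true
shouldDistributeGaussians(100, 26) // 0.99 * 26 = 25.74 -> true
```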
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
      index cb807c803810..a5902190d463 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
      @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
       import org.json4s.jackson.JsonMethods._
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
       import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
      @@ -44,29 +44,49 @@ import org.apache.spark.sql.{SQLContext, Row}
        * @param gaussians Array of MultivariateGaussian where gaussians(i) represents
        *                  the Multivariate Gaussian (Normal) Distribution for Gaussian i
        */
      +@Since("1.3.0")
       @Experimental
      -class GaussianMixtureModel(
      -  val weights: Array[Double],
      -  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
      +class GaussianMixtureModel @Since("1.3.0") (
      +  @Since("1.3.0") val weights: Array[Double],
      +  @Since("1.3.0") val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
       
         require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
       
         override protected def formatVersion = "1.0"
       
      +  @Since("1.4.0")
         override def save(sc: SparkContext, path: String): Unit = {
           GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians)
         }
       
      -  /** Number of gaussians in mixture */
      +  /**
      +   * Number of gaussians in mixture
      +   */
      +  @Since("1.3.0")
         def k: Int = weights.length
       
      -  /** Maps given points to their cluster indices. */
      +  /**
      +   * Maps given points to their cluster indices.
      +   */
      +  @Since("1.3.0")
         def predict(points: RDD[Vector]): RDD[Int] = {
           val responsibilityMatrix = predictSoft(points)
           responsibilityMatrix.map(r => r.indexOf(r.max))
         }
       
      -  /** Java-friendly version of [[predict()]] */
      +  /**
      +   * Maps given point to its cluster index.
      +   */
      +  @Since("1.5.0")
      +  def predict(point: Vector): Int = {
      +    val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
      +    r.indexOf(r.max)
      +  }
      +
      +  /**
      +   * Java-friendly version of [[predict()]]
      +   */
      +  @Since("1.4.0")
         def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
           predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
       
      @@ -74,6 +94,7 @@ class GaussianMixtureModel(
          * Given the input vectors, return the membership value of each vector
          * to all mixture components.
          */
      +  @Since("1.3.0")
         def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
           val sc = points.sparkContext
           val bcDists = sc.broadcast(gaussians)
      @@ -83,6 +104,14 @@ class GaussianMixtureModel(
           }
         }
       
      +  /**
      +   * Given the input vector, return the membership values to all mixture components.
      +   */
      +  @Since("1.4.0")
      +  def predictSoft(point: Vector): Array[Double] = {
      +    computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
      +  }
      +
         /**
          * Compute the partial assignments for each vector
          */
      @@ -102,6 +131,7 @@ class GaussianMixtureModel(
         }
       }
       
      +@Since("1.4.0")
       @Experimental
       object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
       
      @@ -138,20 +168,20 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
             val dataPath = Loader.dataPath(path)
             val sqlContext = new SQLContext(sc)
             val dataFrame = sqlContext.read.parquet(dataPath)
      -      val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
      -
             // Check schema explicitly since erasure makes it hard to use match-case for checking.
             Loader.checkSchema[Data](dataFrame.schema)
      +      val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
       
             val (weights, gaussians) = dataArray.map {
               case Row(weight: Double, mu: Vector, sigma: Matrix) =>
                 (weight, new MultivariateGaussian(mu, sigma))
             }.unzip
       
      -      return new GaussianMixtureModel(weights.toArray, gaussians.toArray)
      +      new GaussianMixtureModel(weights.toArray, gaussians.toArray)
           }
         }
       
      +  @Since("1.4.0")
         override def load(sc: SparkContext, path: String) : GaussianMixtureModel = {
           val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
           implicit val formats = DefaultFormats
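A usage sketch (editorial, not part of the patch) for the new single-vector `predict`/`predictSoft` methods added above; `data` is an assumed, pre-existing `RDD[Vector]`:

```scala
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

// `data` is an assumed RDD[Vector]; it is not defined in this patch.
val gmm = new GaussianMixture().setK(3).setSeed(42L).run(data)

val point = Vectors.dense(1.0, 2.0)
val cluster: Int = gmm.predict(point)                  // single-point variant, added in 1.5.0
val membership: Array[Double] = gmm.predictSoft(point) // per-component responsibilities
```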
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
      index 0f8d6a399682..7168aac32c99 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
      @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
       import scala.collection.mutable.ArrayBuffer
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
       import org.apache.spark.mllib.util.MLUtils
      @@ -37,6 +37,7 @@ import org.apache.spark.util.random.XORShiftRandom
        * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given
        * to it should be cached by the user.
        */
      +@Since("0.8.0")
       class KMeans private (
           private var k: Int,
           private var maxIterations: Int,
      @@ -50,14 +51,19 @@ class KMeans private (
          * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
          * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
          */
      +  @Since("0.8.0")
         def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
       
         /**
          * Number of clusters to create (k).
          */
      +  @Since("1.4.0")
         def getK: Int = k
       
      -  /** Set the number of clusters to create (k). Default: 2. */
      +  /**
      +   * Set the number of clusters to create (k). Default: 2.
      +   */
      +  @Since("0.8.0")
         def setK(k: Int): this.type = {
           this.k = k
           this
      @@ -66,9 +72,13 @@ class KMeans private (
         /**
          * Maximum number of iterations to run.
          */
      +  @Since("1.4.0")
         def getMaxIterations: Int = maxIterations
       
      -  /** Set maximum number of iterations to run. Default: 20. */
      +  /**
      +   * Set maximum number of iterations to run. Default: 20.
      +   */
      +  @Since("0.8.0")
         def setMaxIterations(maxIterations: Int): this.type = {
           this.maxIterations = maxIterations
           this
      @@ -77,6 +87,7 @@ class KMeans private (
         /**
          * The initialization algorithm. This can be either "random" or "k-means||".
          */
      +  @Since("1.4.0")
         def getInitializationMode: String = initializationMode
       
         /**
      @@ -84,10 +95,9 @@ class KMeans private (
          * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
          * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
          */
      +  @Since("0.8.0")
         def setInitializationMode(initializationMode: String): this.type = {
      -    if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
      -      throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
      -    }
      +    require(KMeans.validateInitMode(initializationMode),
      +      s"Invalid initialization mode: $initializationMode")
           this.initializationMode = initializationMode
           this
         }
      @@ -96,6 +106,7 @@ class KMeans private (
          * :: Experimental ::
          * Number of runs of the algorithm to execute in parallel.
          */
      +  @Since("1.4.0")
         @Experimental
         def getRuns: Int = runs
       
      @@ -105,6 +116,7 @@ class KMeans private (
          * this many times with random starting conditions (configured by the initialization mode), then
          * return the best clustering found over any run. Default: 1.
          */
      +  @Since("0.8.0")
         @Experimental
         def setRuns(runs: Int): this.type = {
           if (runs <= 0) {
      @@ -117,12 +129,14 @@ class KMeans private (
         /**
          * Number of steps for the k-means|| initialization mode
          */
      +  @Since("1.4.0")
         def getInitializationSteps: Int = initializationSteps
       
         /**
          * Set the number of steps for the k-means|| initialization mode. This is an advanced
          * setting -- the default of 5 is almost always enough. Default: 5.
          */
      +  @Since("0.8.0")
         def setInitializationSteps(initializationSteps: Int): this.type = {
           if (initializationSteps <= 0) {
             throw new IllegalArgumentException("Number of initialization steps must be positive")
      @@ -134,12 +148,14 @@ class KMeans private (
         /**
          * The distance threshold within which we consider centers to have converged.
          */
      +  @Since("1.4.0")
         def getEpsilon: Double = epsilon
       
         /**
          * Set the distance threshold within which we consider centers to have converged.
          * If all centers move less than this Euclidean distance, we stop iterating one run.
          */
      +  @Since("0.8.0")
         def setEpsilon(epsilon: Double): this.type = {
           this.epsilon = epsilon
           this
      @@ -148,18 +164,39 @@ class KMeans private (
         /**
          * The random seed for cluster initialization.
          */
      +  @Since("1.4.0")
         def getSeed: Long = seed
       
      -  /** Set the random seed for cluster initialization. */
      +  /**
      +   * Set the random seed for cluster initialization.
      +   */
      +  @Since("1.4.0")
         def setSeed(seed: Long): this.type = {
           this.seed = seed
           this
         }
       
      +  // Initial cluster centers can be provided as a KMeansModel object rather than using the
      +  // random or k-means|| initializationMode
      +  private var initialModel: Option[KMeansModel] = None
      +
      +  /**
      +   * Set the initial starting point, bypassing the random initialization or k-means||.
      +   * The condition model.k == this.k must be met; failure results
      +   * in an IllegalArgumentException.
      +   */
      +  @Since("1.4.0")
      +  def setInitialModel(model: KMeansModel): this.type = {
      +    require(model.k == k, "mismatched cluster count")
      +    initialModel = Some(model)
      +    this
      +  }
      +
         /**
          * Train a K-means model on the given set of points; `data` should be cached for high
          * performance, because this is an iterative algorithm.
          */
      +  @Since("0.8.0")
         def run(data: RDD[Vector]): KMeansModel = {
       
           if (data.getStorageLevel == StorageLevel.NONE) {
      @@ -193,20 +230,34 @@ class KMeans private (
       
           val initStartTime = System.nanoTime()
       
      -    val centers = if (initializationMode == KMeans.RANDOM) {
      -      initRandom(data)
      +    // Only one run is allowed when initialModel is given
      +    val numRuns = if (initialModel.nonEmpty) {
      +      if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
      +      1
           } else {
      -      initKMeansParallel(data)
      +      runs
           }
       
      +    val centers = initialModel match {
      +      case Some(kMeansCenters) => {
      +        Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
      +      }
      +      case None => {
      +        if (initializationMode == KMeans.RANDOM) {
      +          initRandom(data)
      +        } else {
      +          initKMeansParallel(data)
      +        }
      +      }
      +    }
           val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
           logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
             " seconds.")
       
      -    val active = Array.fill(runs)(true)
      -    val costs = Array.fill(runs)(0.0)
      +    val active = Array.fill(numRuns)(true)
      +    val costs = Array.fill(numRuns)(0.0)
       
      -    var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
      +    var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
           var iteration = 0
       
           val iterationStartTime = System.nanoTime()
      @@ -318,7 +369,7 @@ class KMeans private (
         : Array[Array[VectorWithNorm]] = {
           // Initialize empty centers and point costs.
           val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
      -    var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
      +    var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
       
           // Initialize each run's first center to a random point.
           val seed = new XORShiftRandom(this.seed).nextInt()
      @@ -343,21 +394,28 @@ class KMeans private (
             val bcNewCenters = data.context.broadcast(newCenters)
             val preCosts = costs
             costs = data.zip(preCosts).map { case (point, cost) =>
      -        Vectors.dense(
                 Array.tabulate(runs) { r =>
                   math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
      -          })
      -      }.cache()
      +          }
      +        }.persist(StorageLevel.MEMORY_AND_DISK)
             val sumCosts = costs
      -        .aggregate(Vectors.zeros(runs))(
      +        .aggregate(new Array[Double](runs))(
                 seqOp = (s, v) => {
                   // s += v
      -            axpy(1.0, v, s)
      +            var r = 0
      +            while (r < runs) {
      +              s(r) += v(r)
      +              r += 1
      +            }
                   s
                 },
                 combOp = (s0, s1) => {
                   // s0 += s1
      -            axpy(1.0, s1, s0)
      +            var r = 0
      +            while (r < runs) {
      +              s0(r) += s1(r)
      +              r += 1
      +            }
                   s0
                 }
               )
      @@ -404,10 +462,13 @@ class KMeans private (
       /**
        * Top-level methods for calling K-means clustering.
        */
      +@Since("0.8.0")
       object KMeans {
       
         // Initialization mode names
      +  @Since("0.8.0")
         val RANDOM = "random"
      +  @Since("0.8.0")
         val K_MEANS_PARALLEL = "k-means||"
       
         /**
      @@ -420,6 +481,7 @@ object KMeans {
          * @param initializationMode initialization model, either "random" or "k-means||" (default).
          * @param seed random seed value for cluster initialization
          */
      +  @Since("1.3.0")
         def train(
             data: RDD[Vector],
             k: Int,
      @@ -444,6 +506,7 @@ object KMeans {
          * @param runs number of parallel runs, defaults to 1. The best model is returned.
          * @param initializationMode initialization model, either "random" or "k-means||" (default).
          */
      +  @Since("0.8.0")
         def train(
             data: RDD[Vector],
             k: Int,
      @@ -460,6 +523,7 @@ object KMeans {
         /**
          * Trains a k-means model using specified parameters and the default values for unspecified.
          */
      +  @Since("0.8.0")
         def train(
             data: RDD[Vector],
             k: Int,
      @@ -470,6 +534,7 @@ object KMeans {
         /**
          * Trains a k-means model using specified parameters and the default values for unspecified.
          */
      +  @Since("0.8.0")
         def train(
             data: RDD[Vector],
             k: Int,
      @@ -521,6 +586,14 @@ object KMeans {
             v2: VectorWithNorm): Double = {
           MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
         }
      +
      +  private[spark] def validateInitMode(initMode: String): Boolean = {
      +    initMode match {
      +      case KMeans.RANDOM => true
      +      case KMeans.K_MEANS_PARALLEL => true
      +      case _ => false
      +    }
      +  }
       }
       
       /**
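A warm-start sketch (editorial, not part of the patch) for the new `setInitialModel` API; `data` is an assumed, cached `RDD[Vector]`:

```scala
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

// Seed centers from a previous run or a heuristic; the model's k must match setK below.
val seedModel = new KMeansModel(Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0)))

val model = new KMeans()
  .setK(2)
  .setMaxIterations(20)
  .setInitialModel(seedModel) // runs is forced to 1 when an initial model is given
  .run(data)                  // `data` is an assumed RDD[Vector]
```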
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
      index 8ecb3df11d95..a74158498272 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
      @@ -23,6 +23,7 @@ import org.json4s._
       import org.json4s.JsonDSL._
       import org.json4s.jackson.JsonMethods._
       
      +import org.apache.spark.annotation.Since
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.pmml.PMMLExportable
      @@ -35,28 +36,44 @@ import org.apache.spark.sql.Row
       /**
        * A clustering model for K-means. Each point belongs to the cluster with the closest center.
        */
      -class KMeansModel (
      -    val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable {
      +@Since("0.8.0")
      +class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
      +  extends Saveable with Serializable with PMMLExportable {
       
      -  /** A Java-friendly constructor that takes an Iterable of Vectors. */
      +  /**
      +   * A Java-friendly constructor that takes an Iterable of Vectors.
      +   */
      +  @Since("1.4.0")
         def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
       
      -  /** Total number of clusters. */
      +  /**
      +   * Total number of clusters.
      +   */
      +  @Since("0.8.0")
         def k: Int = clusterCenters.length
       
      -  /** Returns the cluster index that a given point belongs to. */
      +  /**
      +   * Returns the cluster index that a given point belongs to.
      +   */
      +  @Since("0.8.0")
         def predict(point: Vector): Int = {
           KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
         }
       
      -  /** Maps given points to their cluster indices. */
      +  /**
      +   * Maps given points to their cluster indices.
      +   */
      +  @Since("1.0.0")
         def predict(points: RDD[Vector]): RDD[Int] = {
           val centersWithNorm = clusterCentersWithNorm
           val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
           points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
         }
       
      -  /** Maps given points to their cluster indices. */
      +  /**
      +   * Maps given points to their cluster indices.
      +   */
      +  @Since("1.0.0")
         def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
           predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
       
      @@ -64,6 +81,7 @@ class KMeansModel (
          * Return the K-means cost (sum of squared distances of points to their nearest center) for this
          * model on the given data.
          */
      +  @Since("0.8.0")
         def computeCost(data: RDD[Vector]): Double = {
           val centersWithNorm = clusterCentersWithNorm
           val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
      @@ -73,6 +91,7 @@ class KMeansModel (
         private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
           clusterCenters.map(new VectorWithNorm(_))
       
      +  @Since("1.4.0")
         override def save(sc: SparkContext, path: String): Unit = {
           KMeansModel.SaveLoadV1_0.save(sc, this, path)
         }
      @@ -80,7 +99,10 @@ class KMeansModel (
         override protected def formatVersion: String = "1.0"
       }
       
      +@Since("1.4.0")
       object KMeansModel extends Loader[KMeansModel] {
      +
      +  @Since("1.4.0")
         override def load(sc: SparkContext, path: String): KMeansModel = {
           KMeansModel.SaveLoadV1_0.load(sc, path)
         }
      @@ -120,11 +142,11 @@ object KMeansModel extends Loader[KMeansModel] {
             assert(className == thisClassName)
             assert(formatVersion == thisFormatVersion)
             val k = (metadata \ "k").extract[Int]
      -      val centriods = sqlContext.read.parquet(Loader.dataPath(path))
      -      Loader.checkSchema[Cluster](centriods.schema)
      -      val localCentriods = centriods.map(Cluster.apply).collect()
      -      assert(k == localCentriods.size)
      -      new KMeansModel(localCentriods.sortBy(_.id).map(_.point))
      +      val centroids = sqlContext.read.parquet(Loader.dataPath(path))
      +      Loader.checkSchema[Cluster](centroids.schema)
      +      val localCentroids = centroids.map(Cluster.apply).collect()
      +      assert(k == localCentroids.size)
      +      new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
           }
         }
       }
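A sketch (editorial, not part of the patch) of the KMeansModel methods annotated above; `sc`, `data`, and `model` are assumed to already exist, and the save path is only an example:

```scala
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

val cost = model.computeCost(data)                      // sum of squared distances to nearest center
val assignment = model.predict(Vectors.dense(1.0, 2.0)) // cluster index for a single point

model.save(sc, "/tmp/kmeans-model")                     // example path
val reloaded = KMeansModel.load(sc, "/tmp/kmeans-model")
```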
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
      index a410547a72fd..92a321afb0ca 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
      @@ -20,14 +20,13 @@ package org.apache.spark.mllib.clustering
       import breeze.linalg.{DenseVector => BDV}
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       import org.apache.spark.api.java.JavaPairRDD
       import org.apache.spark.graphx._
      -import org.apache.spark.mllib.linalg.Vector
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.util.Utils
       
      -
       /**
        * :: Experimental ::
        *
      @@ -45,28 +44,37 @@ import org.apache.spark.util.Utils
        * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
        *       (Wikipedia)]]
        */
      +@Since("1.3.0")
       @Experimental
       class LDA private (
           private var k: Int,
           private var maxIterations: Int,
      -    private var docConcentration: Double,
      +    private var docConcentration: Vector,
           private var topicConcentration: Double,
           private var seed: Long,
           private var checkpointInterval: Int,
           private var ldaOptimizer: LDAOptimizer) extends Logging {
       
      -  def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
      -    seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
      +  /**
      +   * Constructs an LDA instance with default parameters.
      +   */
      +  @Since("1.3.0")
      +  def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1),
      +    topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10,
      +    ldaOptimizer = new EMLDAOptimizer)
       
         /**
          * Number of topics to infer.  I.e., the number of soft cluster centers.
      +   *
          */
      +  @Since("1.3.0")
         def getK: Int = k
       
         /**
          * Number of topics to infer.  I.e., the number of soft cluster centers.
          * (default = 10)
          */
      +  @Since("1.3.0")
         def setK(k: Int): this.type = {
           require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k")
           this.k = k
      @@ -77,39 +85,91 @@ class LDA private (
          * Concentration parameter (commonly named "alpha") for the prior placed on documents'
          * distributions over topics ("theta").
          *
      -   * This is the parameter to a symmetric Dirichlet distribution.
      +   * This is the parameter to a Dirichlet distribution.
          */
      -  def getDocConcentration: Double = this.docConcentration
      +  @Since("1.5.0")
      +  def getAsymmetricDocConcentration: Vector = this.docConcentration
       
         /**
          * Concentration parameter (commonly named "alpha") for the prior placed on documents'
          * distributions over topics ("theta").
          *
      -   * This is the parameter to a symmetric Dirichlet distribution, where larger values
      -   * mean more smoothing (more regularization).
      +   * This method assumes the Dirichlet distribution is symmetric and can be described by a single
      +   * [[Double]] parameter. It throws an exception if docConcentration is asymmetric.
      +   */
      +  @Since("1.3.0")
      +  def getDocConcentration: Double = {
      +    val parameter = docConcentration(0)
      +    if (docConcentration.size == 1) {
      +      parameter
      +    } else {
      +      require(docConcentration.toArray.forall(_ == parameter))
      +      parameter
      +    }
      +  }
      +
      +  /**
      +   * Concentration parameter (commonly named "alpha") for the prior placed on documents'
      +   * distributions over topics ("theta").
          *
      -   * If set to -1, then docConcentration is set automatically.
      -   *  (default = -1 = automatic)
      +   * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing
      +   * (more regularization).
      +   *
      +   * If set to a singleton vector Vector(-1), then docConcentration is set automatically. If set to
      +   * a singleton vector Vector(t) where t != -1, then t is replicated to a vector of length k
      +   * during [[LDAOptimizer.initialize()]]. Otherwise, the [[docConcentration]] vector must be of
      +   * length k. (default = Vector(-1) = automatic)
          *
          * Optimizer-specific parameter settings:
          *  - EM
      -   *     - Value should be > 1.0
      -   *     - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
      -   *       Asuncion et al. (2009), who recommend a +1 adjustment for EM.
      +   *     - Currently only supports symmetric distributions, so all values in the vector should be
      +   *       the same.
      +   *     - Values should be > 1.0
      +   *     - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
      +   *       from Asuncion et al. (2009), who recommend a +1 adjustment for EM.
          *  - Online
      -   *     - Value should be >= 0
      -   *     - default = (1.0 / k), following the implementation from
      +   *     - Values should be >= 0
      +   *     - default = uniformly (1.0 / k), following the implementation from
          *       [[https://github.com/Blei-Lab/onlineldavb]].
          */
      -  def setDocConcentration(docConcentration: Double): this.type = {
      +  @Since("1.5.0")
      +  def setDocConcentration(docConcentration: Vector): this.type = {
      +    require(docConcentration.size > 0, "docConcentration must have > 0 elements")
           this.docConcentration = docConcentration
           this
         }
       
      -  /** Alias for [[getDocConcentration]] */
      +  /**
      +   * Replicates a [[Double]] docConcentration to create a symmetric prior.
      +   */
      +  @Since("1.3.0")
      +  def setDocConcentration(docConcentration: Double): this.type = {
      +    this.docConcentration = Vectors.dense(docConcentration)
      +    this
      +  }
      +
      +  /**
      +   * Alias for [[getAsymmetricDocConcentration]]
      +   */
      +  @Since("1.5.0")
      +  def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration
      +
      +  /**
      +   * Alias for [[getDocConcentration]]
      +   */
      +  @Since("1.3.0")
         def getAlpha: Double = getDocConcentration
       
      -  /** Alias for [[setDocConcentration()]] */
      +  /**
      +   * Alias for [[setDocConcentration()]]
      +   */
      +  @Since("1.5.0")
      +  def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)
      +
      +  /**
      +   * Alias for [[setDocConcentration()]]
      +   */
      +  @Since("1.3.0")
         def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)
       
         /**
      @@ -121,6 +181,7 @@ class LDA private (
          * Note: The topics' distributions over terms are called "beta" in the original LDA paper
          * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
          */
      +  @Since("1.3.0")
         def getTopicConcentration: Double = this.topicConcentration
       
         /**
      @@ -145,35 +206,50 @@ class LDA private (
          *     - default = (1.0 / k), following the implementation from
          *       [[https://github.com/Blei-Lab/onlineldavb]].
          */
      +  @Since("1.3.0")
         def setTopicConcentration(topicConcentration: Double): this.type = {
           this.topicConcentration = topicConcentration
           this
         }
       
      -  /** Alias for [[getTopicConcentration]] */
      +  /**
      +   * Alias for [[getTopicConcentration]]
      +   */
      +  @Since("1.3.0")
         def getBeta: Double = getTopicConcentration
       
      -  /** Alias for [[setTopicConcentration()]] */
      +  /**
      +   * Alias for [[setTopicConcentration()]]
      +   */
      +  @Since("1.3.0")
         def setBeta(beta: Double): this.type = setTopicConcentration(beta)
       
         /**
          * Maximum number of iterations for learning.
          */
      +  @Since("1.3.0")
         def getMaxIterations: Int = maxIterations
       
         /**
          * Maximum number of iterations for learning.
          * (default = 20)
          */
      +  @Since("1.3.0")
         def setMaxIterations(maxIterations: Int): this.type = {
           this.maxIterations = maxIterations
           this
         }
       
      -  /** Random seed */
      +  /**
      +   * Random seed
      +   */
      +  @Since("1.3.0")
         def getSeed: Long = seed
       
      -  /** Random seed */
      +  /**
      +   * Random seed
      +   */
      +  @Since("1.3.0")
         def setSeed(seed: Long): this.type = {
           this.seed = seed
           this
      @@ -182,6 +258,7 @@ class LDA private (
         /**
          * Period (in iterations) between checkpoints.
          */
      +  @Since("1.3.0")
         def getCheckpointInterval: Int = checkpointInterval
       
         /**
      @@ -192,6 +269,7 @@ class LDA private (
          *
          * @see [[org.apache.spark.SparkContext#setCheckpointDir]]
          */
      +  @Since("1.3.0")
         def setCheckpointInterval(checkpointInterval: Int): this.type = {
           this.checkpointInterval = checkpointInterval
           this
      @@ -203,6 +281,7 @@ class LDA private (
          *
          * LDAOptimizer used to perform the actual calculation
          */
      +  @Since("1.4.0")
         @DeveloperApi
         def getOptimizer: LDAOptimizer = ldaOptimizer
       
      @@ -211,6 +290,7 @@ class LDA private (
          *
          * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer)
          */
      +  @Since("1.4.0")
         @DeveloperApi
         def setOptimizer(optimizer: LDAOptimizer): this.type = {
           this.ldaOptimizer = optimizer
      @@ -221,6 +301,7 @@ class LDA private (
          * Set the LDAOptimizer used to perform the actual calculation by algorithm name.
          * Currently "em", "online" are supported.
          */
      +  @Since("1.4.0")
         def setOptimizer(optimizerName: String): this.type = {
           this.ldaOptimizer =
             optimizerName.toLowerCase match {
      @@ -241,6 +322,7 @@ class LDA private (
          *                   Document IDs must be unique and >= 0.
          * @return  Inferred LDA model
          */
      +  @Since("1.3.0")
         def run(documents: RDD[(Long, Vector)]): LDAModel = {
           val state = ldaOptimizer.initialize(documents, this)
           var iter = 0
      @@ -255,7 +337,10 @@ class LDA private (
           state.getLDAModel(iterationTimes)
         }
       
      -  /** Java-friendly version of [[run()]] */
      +  /**
      +   * Java-friendly version of [[run()]]
      +   */
      +  @Since("1.3.0")
         def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
           run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
         }
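A configuration sketch (editorial, not part of the patch) showing the new Vector-valued `docConcentration` alongside the original Double setter; `corpus` is an assumed `RDD[(Long, Vector)]` of (docId, termCounts):

```scala
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

val lda = new LDA()
  .setK(10)
  .setMaxIterations(50)
  .setOptimizer("online")
  // Asymmetric prior: one concentration value per topic (length must equal k).
  .setDocConcentration(Vectors.dense(Array.fill(10)(0.1)))
  // Or keep the original symmetric form, replicated to length k internally:
  // .setDocConcentration(0.1)

val ldaModel = lda.run(corpus) // `corpus` is an assumed RDD[(Long, Vector)]
```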
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
      index 974b26924dfb..15129e0dd5a9 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
      @@ -17,13 +17,21 @@
       
       package org.apache.spark.mllib.clustering
       
      -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
      -
      -import org.apache.spark.annotation.Experimental
      -import org.apache.spark.api.java.JavaPairRDD
      -import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
      -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
      +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum}
      +import breeze.numerics.{exp, lgamma}
      +import org.apache.hadoop.fs.Path
      +import org.json4s.DefaultFormats
      +import org.json4s.JsonDSL._
      +import org.json4s.jackson.JsonMethods._
      +
      +import org.apache.spark.SparkContext
      +import org.apache.spark.annotation.{Experimental, Since}
      +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
      +import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
      +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
      +import org.apache.spark.mllib.util.{Loader, Saveable}
       import org.apache.spark.rdd.RDD
      +import org.apache.spark.sql.{Row, SQLContext}
       import org.apache.spark.util.BoundedPriorityQueue
       
       /**
      @@ -35,33 +43,61 @@ import org.apache.spark.util.BoundedPriorityQueue
        * including local and distributed data structures.
        */
       @Experimental
      -abstract class LDAModel private[clustering] {
      +@Since("1.3.0")
      +abstract class LDAModel private[clustering] extends Saveable {
       
         /** Number of topics */
      +  @Since("1.3.0")
         def k: Int
       
         /** Vocabulary size (number of terms in the vocabulary) */
      +  @Since("1.3.0")
         def vocabSize: Int
       
      +  /**
      +   * Concentration parameter (commonly named "alpha") for the prior placed on documents'
      +   * distributions over topics ("theta").
      +   *
      +   * This is the parameter to a Dirichlet distribution.
      +   */
      +  @Since("1.5.0")
      +  def docConcentration: Vector
      +
      +  /**
      +   * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
      +   * distributions over terms.
      +   *
      +   * This is the parameter to a symmetric Dirichlet distribution.
      +   *
      +   * Note: The topics' distributions over terms are called "beta" in the original LDA paper
      +   * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
      +   */
      +  @Since("1.5.0")
      +  def topicConcentration: Double
      +
      +  /**
      +   * Shape parameter for random initialization of variational parameter gamma.
      +   * Used for variational inference for perplexity and other test-time computations.
      +   */
      +  protected def gammaShape: Double
      +
         /**
          * Inferred topics, where each topic is represented by a distribution over terms.
          * This is a matrix of size vocabSize x k, where each column is a topic.
          * No guarantees are given about the ordering of the topics.
          */
      +  @Since("1.3.0")
         def topicsMatrix: Matrix
       
         /**
          * Return the topics described by weighted terms.
          *
      -   * This limits the number of terms per topic.
      -   * This is approximate; it may not return exactly the top-weighted terms for each topic.
      -   * To get a more precise set of top terms, increase maxTermsPerTopic.
      -   *
          * @param maxTermsPerTopic  Maximum number of terms to collect for each topic.
          * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
          *          (term indices, term weights in topic).
          *          Each topic's terms are sorted in order of decreasing weight.
          */
      +  @Since("1.3.0")
         def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])]
       
         /**
      @@ -73,6 +109,7 @@ abstract class LDAModel private[clustering] {
          *          (term indices, term weights in topic).
          *          Each topic's terms are sorted in order of decreasing weight.
          */
      +  @Since("1.3.0")
         def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize)
       
         /* TODO (once LDA can be trained with Strings or given a dictionary)
      @@ -153,19 +190,27 @@ abstract class LDAModel private[clustering] {
        * This model stores only the inferred topics.
        * It may be used for computing topics for new documents, but it may give less accurate answers
        * than the [[DistributedLDAModel]].
      - *
        * @param topics Inferred topics (vocabSize x k matrix).
        */
       @Experimental
      +@Since("1.3.0")
       class LocalLDAModel private[clustering] (
      -    private val topics: Matrix) extends LDAModel with Serializable {
      +    @Since("1.3.0") val topics: Matrix,
      +    @Since("1.5.0") override val docConcentration: Vector,
      +    @Since("1.5.0") override val topicConcentration: Double,
      +    override protected[clustering] val gammaShape: Double = 100)
      +  extends LDAModel with Serializable {
       
      +  @Since("1.3.0")
         override def k: Int = topics.numCols
       
      +  @Since("1.3.0")
         override def vocabSize: Int = topics.numRows
       
      +  @Since("1.3.0")
         override def topicsMatrix: Matrix = topics
       
      +  @Since("1.3.0")
         override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
           val brzTopics = topics.toBreeze.toDenseMatrix
           Range(0, k).map { topicIndex =>
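The hunk below adds likelihood-based evaluation and per-document topic inference to LocalLDAModel. A usage sketch (editorial, not part of the patch), assuming a trained `ldaModel: LocalLDAModel` and a held-out `testDocs: RDD[(Long, Vector)]`:

```scala
val bound = ldaModel.logLikelihood(testDocs)      // variational lower bound on log p(testDocs)
val perplexity = ldaModel.logPerplexity(testDocs) // -bound / total token count; lower is better
val theta = ldaModel.topicDistributions(testDocs) // RDD[(docId, topic mixture "gamma")]
```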
      @@ -176,14 +221,265 @@ class LocalLDAModel private[clustering] (
           }.toArray
         }
       
      -  // TODO
      -  // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
      +  override protected def formatVersion = "1.0"
       
      -  // TODO:
      -  // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
      +  @Since("1.5.0")
      +  override def save(sc: SparkContext, path: String): Unit = {
      +    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
      +      gammaShape)
      +  }
      +
      +  // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
      +  /**
      +   * Calculates a lower bound on the log likelihood of the entire corpus.
      +   *
      +   * See Equation (16) in the original Online LDA paper.
      +   *
      +   * @param documents test corpus to use for calculating log likelihood
      +   * @return variational lower bound on the log likelihood of the entire corpus
      +   */
      +  @Since("1.5.0")
      +  def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents,
      +    docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k,
      +    vocabSize)
      +
      +  /**
      +   * Java-friendly version of [[logLikelihood]]
      +   */
      +  @Since("1.5.0")
      +  def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = {
      +    logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
      +  }
      +
      +  /**
      +   * Calculate an upper bound on perplexity.  (Lower is better.)
      +   * See Equation (16) in the original Online LDA paper.
      +   *
      +   * @param documents test corpus to use for calculating perplexity
      +   * @return Variational upper bound on log perplexity per token.
      +   */
      +  @Since("1.5.0")
      +  def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
      +    val corpusTokenCount = documents
      +      .map { case (_, termCounts) => termCounts.toArray.sum }
      +      .sum()
      +    -logLikelihood(documents) / corpusTokenCount
      +  }
      +
      +  /** Java-friendly version of [[logPerplexity]] */
      +  @Since("1.5.0")
      +  def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = {
      +    logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
      +  }
      +
      +  /**
      +   * Estimate the variational likelihood bound from `documents`:
      +   *    log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)]
      +   * This bound is derived by decomposing the LDA model to:
      +   *    log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p)
      +   * and noting that the KL-divergence D(q|p) >= 0.
      +   *
      +   * See Equation (16) in the original Online LDA paper, as well as Appendix A.3 in the JMLR
      +   * version of the original LDA paper.
      +   * @param documents a subset of the test corpus
      +   * @param alpha document-topic Dirichlet prior parameters
      +   * @param eta topic-word Dirichlet prior parameter
      +   * @param lambda parameters for variational q(beta | lambda) topic-word distributions
      +   * @param gammaShape shape parameter for random initialization of variational q(theta | gamma)
      +   *                   topic mixture distributions
      +   * @param k number of topics
      +   * @param vocabSize number of unique terms in the entire test corpus
      +   */
      +  private def logLikelihoodBound(
      +      documents: RDD[(Long, Vector)],
      +      alpha: Vector,
      +      eta: Double,
      +      lambda: BDM[Double],
      +      gammaShape: Double,
      +      k: Int,
      +      vocabSize: Long): Double = {
      +    val brzAlpha = alpha.toBreeze.toDenseVector
      +    // transpose because dirichletExpectation normalizes by row and we need to normalize
      +    // by topic (columns of lambda)
      +    val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
      +    val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)
      +
      +    // Sum bound components for each document:
      +    //  component for prob(tokens) + component for prob(document-topic distribution)
      +    val corpusPart =
      +      documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
      +        val localElogbeta = ElogbetaBc.value
      +        var docBound = 0.0D
      +        val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
      +          termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
      +        val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
      +
      +        // E[log p(doc | theta, beta)]
      +        termCounts.foreachActive { case (idx, count) =>
      +          docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t)
      +        }
      +        // E[log p(theta | alpha) - log q(theta | gamma)]
      +        docBound += sum((brzAlpha - gammad) :* Elogthetad)
      +        docBound += sum(lgamma(gammad) - lgamma(brzAlpha))
      +        docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
      +
      +        docBound
      +      }.sum()
      +
      +    // Bound component for prob(topic-term distributions):
      +    //   E[log p(beta | eta) - log q(beta | lambda)]
      +    val sumEta = eta * vocabSize
      +    val topicsPart = sum((eta - lambda) :* Elogbeta) +
      +      sum(lgamma(lambda) - lgamma(eta)) +
      +      sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
      +
      +    corpusPart + topicsPart
      +  }
      +
      +  /**
      +   * Predicts the topic mixture distribution for each document (often called "theta" in the
      +   * literature).  Returns a vector of zeros for an empty document.
      +   *
      +   * This uses a variational approximation following Hoffman et al. (2010), where the approximate
      +   * distribution is called "gamma."  Technically, this method returns this approximation "gamma"
      +   * for each document.
      +   * @param documents documents to predict topic mixture distributions for
      +   * @return An RDD of (document ID, topic mixture distribution for document)
      +   */
      +  @Since("1.3.0")
      +  // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
      +  def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
      +    // Double transpose because dirichletExpectation normalizes by row and we need to normalize
      +    // by topic (columns of lambda)
      +    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
      +    val expElogbetaBc = documents.sparkContext.broadcast(expElogbeta)
      +    val docConcentrationBrz = this.docConcentration.toBreeze
      +    val gammaShape = this.gammaShape
      +    val k = this.k
      +
      +    documents.map { case (id: Long, termCounts: Vector) =>
      +      if (termCounts.numNonzeros == 0) {
      +        (id, Vectors.zeros(k))
      +      } else {
      +        val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
      +          termCounts,
      +          expElogbetaBc.value,
      +          docConcentrationBrz,
      +          gammaShape,
      +          k)
      +        (id, Vectors.dense(normalize(gamma, 1.0).toArray))
      +      }
      +    }
      +  }
      +
      +  /**
      +   * Java-friendly version of [[topicDistributions]]
      +   */
      +  @Since("1.4.1")
      +  def topicDistributions(
      +      documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = {
      +    val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
      +    JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
      +  }
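A rough usage sketch of the prediction path added above (not part of the patch; `sc`, the trained `ldaModel`, and `vocabSize` are assumed to exist):

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Hypothetical unseen documents: (doc ID, term-count vector) over the training vocabulary.
val newDocs: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.sparse(vocabSize, Seq((1, 2.0), (5, 1.0)))),
  (1L, Vectors.sparse(vocabSize, Seq((3, 4.0))))))

// Approximate per-document topic mixtures ("theta"); empty documents map to zero vectors.
val theta: RDD[(Long, Vector)] = ldaModel.topicDistributions(newDocs)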
       
       }
       
      +@Experimental
      +@Since("1.5.0")
      +object LocalLDAModel extends Loader[LocalLDAModel] {
      +
      +  private object SaveLoadV1_0 {
      +
      +    val thisFormatVersion = "1.0"
      +
      +    val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
      +
      +    // Store the distribution of terms of each topic and the column index in topicsMatrix
      +    // as a Row in data.
      +    case class Data(topic: Vector, index: Int)
      +
      +    def save(
      +        sc: SparkContext,
      +        path: String,
      +        topicsMatrix: Matrix,
      +        docConcentration: Vector,
      +        topicConcentration: Double,
      +        gammaShape: Double): Unit = {
      +      val sqlContext = SQLContext.getOrCreate(sc)
      +      import sqlContext.implicits._
      +
      +      val k = topicsMatrix.numCols
      +      val metadata = compact(render
      +        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
      +          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
      +          ("docConcentration" -> docConcentration.toArray.toSeq) ~
      +          ("topicConcentration" -> topicConcentration) ~
      +          ("gammaShape" -> gammaShape)))
      +      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
      +
      +      val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
      +      val topics = Range(0, k).map { topicInd =>
      +        Data(Vectors.dense(topicsDenseMatrix(::, topicInd).toArray), topicInd)
      +      }.toSeq
      +      sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
      +    }
      +
      +    def load(
      +        sc: SparkContext,
      +        path: String,
      +        docConcentration: Vector,
      +        topicConcentration: Double,
      +        gammaShape: Double): LocalLDAModel = {
      +      val dataPath = Loader.dataPath(path)
      +      val sqlContext = SQLContext.getOrCreate(sc)
      +      val dataFrame = sqlContext.read.parquet(dataPath)
      +
      +      Loader.checkSchema[Data](dataFrame.schema)
      +      val topics = dataFrame.collect()
      +      val vocabSize = topics(0).getAs[Vector](0).size
      +      val k = topics.length
      +
      +      val brzTopics = BDM.zeros[Double](vocabSize, k)
      +      topics.foreach { case Row(vec: Vector, ind: Int) =>
      +        brzTopics(::, ind) := vec.toBreeze
      +      }
      +      val topicsMat = Matrices.fromBreeze(brzTopics)
      +
      +      new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
      +    }
      +  }
      +
      +  @Since("1.5.0")
      +  override def load(sc: SparkContext, path: String): LocalLDAModel = {
      +    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
      +    implicit val formats = DefaultFormats
      +    val expectedK = (metadata \ "k").extract[Int]
      +    val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
      +    val docConcentration =
      +      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
      +    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
      +    val gammaShape = (metadata \ "gammaShape").extract[Double]
      +    val classNameV1_0 = SaveLoadV1_0.thisClassName
      +
      +    val model = (loadedClassName, loadedVersion) match {
      +      case (className, "1.0") if className == classNameV1_0 =>
      +        SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
      +      case _ => throw new Exception(
      +        s"LocalLDAModel.load did not recognize model with (className, format version): " +
      +          s"($loadedClassName, $loadedVersion).  Supported:\n" +
      +          s"  ($classNameV1_0, 1.0)")
      +    }
      +
      +    val topicsMatrix = model.topicsMatrix
      +    require(expectedK == topicsMatrix.numCols,
      +      s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
      +    require(expectedVocabSize == topicsMatrix.numRows,
      +      s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
      +        s"but got ${topicsMatrix.numRows}")
      +    model
      +  }
      +}
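For reference, a hedged sketch of the save/load round trip this loader enables; the path is a placeholder, `sc` is an existing SparkContext, and `localModel` is a trained LocalLDAModel whose model-side save override is added elsewhere in this patch.

// Persists the topic matrix as Parquet plus JSON metadata (k, vocabSize, hyperparameters).
localModel.save(sc, "/tmp/lda/localModel")
val restored = LocalLDAModel.load(sc, "/tmp/lda/localModel")
assert(restored.k == localModel.k && restored.vocabSize == localModel.vocabSize)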
      +
       /**
        * :: Experimental ::
        *
      @@ -193,28 +489,28 @@ class LocalLDAModel private[clustering] (
        * than the [[LocalLDAModel]].
        */
       @Experimental
      -class DistributedLDAModel private (
      -    private val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
      -    private val globalTopicTotals: LDA.TopicCounts,
      -    val k: Int,
      -    val vocabSize: Int,
      -    private val docConcentration: Double,
      -    private val topicConcentration: Double,
      -    private[spark] val iterationTimes: Array[Double]) extends LDAModel {
      +@Since("1.3.0")
      +class DistributedLDAModel private[clustering] (
      +    private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
      +    private[clustering] val globalTopicTotals: LDA.TopicCounts,
      +    @Since("1.3.0") val k: Int,
      +    @Since("1.3.0") val vocabSize: Int,
      +    @Since("1.5.0") override val docConcentration: Vector,
      +    @Since("1.5.0") override val topicConcentration: Double,
      +    private[spark] val iterationTimes: Array[Double],
      +    override protected[clustering] val gammaShape: Double = 100)
      +  extends LDAModel {
       
         import LDA._
       
      -  private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
      -    this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
      -      state.topicConcentration, iterationTimes)
      -  }
      -
         /**
          * Convert model to a local model.
          * The local model stores the inferred topics but not the topic distributions for training
          * documents.
          */
      -  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix)
      +  @Since("1.3.0")
      +  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
      +    gammaShape)
       
         /**
          * Inferred topics, where each topic is represented by a distribution over terms.
      @@ -223,6 +519,7 @@ class DistributedLDAModel private (
          *
          * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large.
          */
      +  @Since("1.3.0")
         override lazy val topicsMatrix: Matrix = {
           // Collect row-major topics
           val termTopicCounts: Array[(Int, TopicCounts)] =
      @@ -241,6 +538,7 @@ class DistributedLDAModel private (
           Matrices.fromBreeze(brzTopics)
         }
       
      +  @Since("1.3.0")
         override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
           val numTopics = k
           // Note: N_k is not needed to find the top terms, but it is needed to normalize weights
      @@ -272,6 +570,87 @@ class DistributedLDAModel private (
           }
         }
       
      +  /**
      +   * Return the top documents for each topic
      +   *
      +   * @param maxDocumentsPerTopic  Maximum number of documents to collect for each topic.
      +   * @return  Array over topics.  Each element is a pair of matching arrays:
      +   *          (IDs for the documents, weights of the topic in these documents).
      +   *          For each topic, documents are sorted in order of decreasing topic weights.
      +   */
      +  @Since("1.5.0")
      +  def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = {
      +    val numTopics = k
      +    val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] =
      +      topicDistributions.mapPartitions { docVertices =>
      +        // For this partition, collect the most common docs for each topic in queues:
      +        //  queues(topic) = queue of (weight of topic in doc, doc ID).
      +        val queues =
      +          Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic))
      +        for ((docId, docTopics) <- docVertices) {
      +          var topic = 0
      +          while (topic < numTopics) {
      +            queues(topic) += (docTopics(topic) -> docId)
      +            topic += 1
      +          }
      +        }
      +        Iterator(queues)
      +      }.treeReduce { (q1, q2) =>
      +        q1.zip(q2).foreach { case (a, b) => a ++= b }
      +        q1
      +      }
      +    topicsInQueues.map { q =>
      +      val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip
      +      (docs.toArray, docTopics.toArray)
      +    }
      +  }
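A short, hedged example of consuming the result (assumes a trained DistributedLDAModel named `distLdaModel`; 3 documents per topic is an arbitrary choice):

// Parallel arrays of document IDs and topic weights, sorted by decreasing weight within a topic.
val topDocs = distLdaModel.topDocumentsPerTopic(3)
topDocs.zipWithIndex.foreach { case ((docIds, weights), topic) =>
  println(s"topic $topic: " + docIds.zip(weights).mkString(", "))
}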
      +
      +  /**
      +   * Return the top topic for each (doc, term) pair.  I.e., for each document, what is the most
      +   * likely topic generating each term?
      +   *
      +   * @return RDD of (doc ID, assignment of top topic index for each term),
      +   *         where the assignment is specified via a pair of zippable arrays
      +   *         (term indices, topic indices).  Note that terms will be omitted if not present in
      +   *         the document.
      +   */
      +  @Since("1.5.0")
      +  lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = {
      +    // For reference, compare the below code with the core part of EMLDAOptimizer.next().
      +    val eta = topicConcentration
      +    val W = vocabSize
      +    val alpha = docConcentration(0)
      +    val N_k = globalTopicTotals
      +    val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit =
      +      (edgeContext) => {
      +        // E-STEP: Compute gamma_{wjk} (smoothed topic distributions).
      +        val scaledTopicDistribution: TopicCounts =
      +          computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha)
      +        // For this (doc j, term w), send top topic k to doc vertex.
      +        val topTopic: Int = argmax(scaledTopicDistribution)
      +        val term: Int = index2term(edgeContext.dstId)
      +        edgeContext.sendToSrc((Array(term), Array(topTopic)))
      +      }
      +    val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) =
      +      (terms_topics0, terms_topics1) => {
      +        (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2)
      +      }
      +    // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
      +    val perDocAssignments =
      +      graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex)
      +    perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) =>
      +      // TODO: Avoid zip, which is inefficient.
      +      val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip
      +      (docID, sortedTerms.toArray, sortedTopics.toArray)
      +    }
      +  }
      +
      +  /** Java-friendly version of [[topicAssignments]] */
      +  @Since("1.5.0")
      +  lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = {
      +    topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD()
      +  }
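An illustrative read of the new per-token assignments (again assuming a trained `distLdaModel`; only two documents are printed):

// For each training document: zippable arrays of term indices and the most likely topic
// assigned to each of those terms.
distLdaModel.topicAssignments.take(2).foreach { case (docId, terms, topics) =>
  terms.zip(topics).foreach { case (term, topic) =>
    println(s"doc $docId: term $term -> topic $topic")
  }
}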
      +
         // TODO
         // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
       
      @@ -285,9 +664,11 @@ class DistributedLDAModel private (
          *  - Even with [[logPrior]], this is NOT the same as the data log likelihood given the
          *    hyperparameters.
          */
      +  @Since("1.3.0")
         lazy val logLikelihood: Double = {
      -    val eta = topicConcentration
      -    val alpha = docConcentration
      +    // TODO: generalize this for asymmetric (non-scalar) alpha
      +    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
      +    val eta = this.topicConcentration
           assert(eta > 1.0)
           assert(alpha > 1.0)
           val N_k = globalTopicTotals
      @@ -308,11 +689,13 @@ class DistributedLDAModel private (
       
         /**
          * Log probability of the current parameter estimate:
      -   *  log P(topics, topic distributions for docs | alpha, eta)
      +   * log P(topics, topic distributions for docs | alpha, eta)
          */
      +  @Since("1.3.0")
         lazy val logPrior: Double = {
      -    val eta = topicConcentration
      -    val alpha = docConcentration
      +    // TODO: generalize this for asymmetric (non-scalar) alpha
      +    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
      +    val eta = this.topicConcentration
           // Term vertices: Compute phi_{wk}.  Use to compute prior log probability.
           // Doc vertex: Compute theta_{kj}.  Use to compute prior log probability.
           val N_k = globalTopicTotals
      @@ -323,12 +706,12 @@ class DistributedLDAModel private (
                 val N_wk = vertex._2
                 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
                 val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
      -          (eta - 1.0) * brzSum(phi_wk.map(math.log))
      +          (eta - 1.0) * sum(phi_wk.map(math.log))
               } else {
                 val N_kj = vertex._2
                 val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0)
                 val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
      -          (alpha - 1.0) * brzSum(theta_kj.map(math.log))
      +          (alpha - 1.0) * sum(theta_kj.map(math.log))
               }
           }
           graph.vertices.aggregate(0.0)(seqOp, _ + _)
      @@ -340,18 +723,192 @@ class DistributedLDAModel private (
          *
          * @return  RDD of (document ID, topic distribution) pairs
          */
      +  @Since("1.3.0")
         def topicDistributions: RDD[(Long, Vector)] = {
           graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
             (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0)))
           }
         }
       
      -  /** Java-friendly version of [[topicDistributions]] */
      +  /**
      +   * Java-friendly version of [[topicDistributions]]
      +   */
      +  @Since("1.4.1")
         def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
           JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
         }
       
      +  /**
      +   * For each document, return the top k weighted topics for that document and their weights.
      +   * @return RDD of (doc ID, topic indices, topic weights)
      +   */
      +  @Since("1.5.0")
      +  def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = {
      +    graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
      +      val topIndices = argtopk(topicCounts, k)
      +      val sumCounts = sum(topicCounts)
      +      val weights = if (sumCounts != 0) {
      +        topicCounts(topIndices) / sumCounts
      +      } else {
      +        topicCounts(topIndices)
      +      }
      +      (docID.toLong, topIndices.toArray, weights.toArray)
      +    }
      +  }
      +
      +  /**
      +   * Java-friendly version of [[topTopicsPerDocument]]
      +   */
      +  @Since("1.5.0")
      +  def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = {
      +    val topics = topTopicsPerDocument(k)
      +    topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD()
      +  }
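A hedged usage sketch (the trained `distLdaModel` and the choice of k = 3 are placeholders):

// Top 3 topics per training document, as parallel index/weight arrays.
distLdaModel.topTopicsPerDocument(3).take(1).foreach { case (docId, topicIdx, topicWeights) =>
  println(s"doc $docId: " + topicIdx.zip(topicWeights).mkString(", "))
}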
      +
         // TODO:
         // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
       
      +  override protected def formatVersion = "1.0"
      +
      +  /**
      +   * Saves this model to the given path, including the graph, global topic totals, and
      +   * hyperparameters.
      +   */
      +  @Since("1.5.0")
      +  override def save(sc: SparkContext, path: String): Unit = {
      +    DistributedLDAModel.SaveLoadV1_0.save(
      +      sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
      +      iterationTimes, gammaShape)
      +  }
      +}
      +
      +
      +@Experimental
      +@Since("1.5.0")
      +object DistributedLDAModel extends Loader[DistributedLDAModel] {
      +
      +  private object SaveLoadV1_0 {
      +
      +    val thisFormatVersion = "1.0"
      +
      +    val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"
      +
      +    // Store globalTopicTotals as a Vector.
      +    case class Data(globalTopicTotals: Vector)
      +
      +    // Store each term and document vertex with an id and the topicWeights.
      +    case class VertexData(id: Long, topicWeights: Vector)
      +
      +    // Store each edge with the source id, destination id and tokenCounts.
      +    case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
      +
      +    def save(
      +        sc: SparkContext,
      +        path: String,
      +        graph: Graph[LDA.TopicCounts, LDA.TokenCount],
      +        globalTopicTotals: LDA.TopicCounts,
      +        k: Int,
      +        vocabSize: Int,
      +        docConcentration: Vector,
      +        topicConcentration: Double,
      +        iterationTimes: Array[Double],
      +        gammaShape: Double): Unit = {
      +      val sqlContext = SQLContext.getOrCreate(sc)
      +      import sqlContext.implicits._
      +
      +      val metadata = compact(render
      +        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
      +          ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
      +          ("docConcentration" -> docConcentration.toArray.toSeq) ~
      +          ("topicConcentration" -> topicConcentration) ~
      +          ("iterationTimes" -> iterationTimes.toSeq) ~
      +          ("gammaShape" -> gammaShape)))
      +      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
      +
      +      val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
      +      sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
      +        .write.parquet(newPath)
      +
      +      val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
      +      graph.vertices.map { case (ind, vertex) =>
      +        VertexData(ind, Vectors.fromBreeze(vertex))
      +      }.toDF().write.parquet(verticesPath)
      +
      +      val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
      +      graph.edges.map { case Edge(srcId, dstId, prop) =>
      +        EdgeData(srcId, dstId, prop)
      +      }.toDF().write.parquet(edgesPath)
      +    }
      +
      +    def load(
      +        sc: SparkContext,
      +        path: String,
      +        vocabSize: Int,
      +        docConcentration: Vector,
      +        topicConcentration: Double,
      +        iterationTimes: Array[Double],
      +        gammaShape: Double): DistributedLDAModel = {
      +      val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
      +      val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
      +      val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
      +      val sqlContext = SQLContext.getOrCreate(sc)
      +      val dataFrame = sqlContext.read.parquet(dataPath)
      +      val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
      +      val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
      +
      +      Loader.checkSchema[Data](dataFrame.schema)
      +      Loader.checkSchema[VertexData](vertexDataFrame.schema)
      +      Loader.checkSchema[EdgeData](edgeDataFrame.schema)
      +      val globalTopicTotals: LDA.TopicCounts =
      +        dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
      +      val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
      +        case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
      +      }
      +
      +      val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
      +        case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
      +      }
      +      val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
      +
      +      new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
      +        docConcentration, topicConcentration, iterationTimes, gammaShape)
      +    }
      +
      +  }
      +
      +  @Since("1.5.0")
      +  override def load(sc: SparkContext, path: String): DistributedLDAModel = {
      +    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
      +    implicit val formats = DefaultFormats
      +    val expectedK = (metadata \ "k").extract[Int]
      +    val vocabSize = (metadata \ "vocabSize").extract[Int]
      +    val docConcentration =
      +      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
      +    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
      +    val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
      +    val gammaShape = (metadata \ "gammaShape").extract[Double]
      +    val classNameV1_0 = SaveLoadV1_0.thisClassName
      +
      +    val model = (loadedClassName, loadedVersion) match {
      +      case (className, "1.0") if className == classNameV1_0 =>
      +        DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration,
      +          topicConcentration, iterationTimes.toArray, gammaShape)
      +      case _ => throw new Exception(
      +        s"DistributedLDAModel.load did not recognize model with (className, format version): " +
      +          s"($loadedClassName, $loadedVersion).  Supported: ($classNameV1_0, 1.0)")
      +    }
      +
      +    require(model.vocabSize == vocabSize,
      +      s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
      +    require(model.docConcentration == docConcentration,
      +      s"DistributedLDAModel requires $docConcentration docConcentration, " +
      +        s"got ${model.docConcentration} docConcentration")
      +    require(model.topicConcentration == topicConcentration,
      +      s"DistributedLDAModel requires $topicConcentration topicConcentration, " +
      +        s"got ${model.topicConcentration} topicConcentration")
      +    require(expectedK == model.k,
      +      s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
      +    model
      +  }
      +
       }
      +
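As with the local model, a hedged round-trip sketch for the distributed loader (the path and `distLdaModel` are placeholders):

distLdaModel.save(sc, "/tmp/lda/distModel")
val reloaded = DistributedLDAModel.load(sc, "/tmp/lda/distModel")
// k, vocabSize, and the concentration parameters are validated on load (see the requires above).
assert(reloaded.k == distLdaModel.k && reloaded.vocabSize == distLdaModel.vocabSize)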
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
      index 8e5154b902d1..38486e949bbc 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
      @@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering
       
       import java.util.Random
       
      -import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
      -import breeze.numerics.{digamma, exp, abs}
      +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum}
      +import breeze.numerics.{trigamma, abs, exp}
       import breeze.stats.distributions.{Gamma, RandBasis}
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.graphx._
       import org.apache.spark.graphx.impl.GraphImpl
       import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
      -import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
      +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
       import org.apache.spark.rdd.RDD
       
       /**
      @@ -36,6 +36,7 @@ import org.apache.spark.rdd.RDD
        * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can
        * hold optimizer-specific parameters for users to set.
        */
      +@Since("1.4.0")
       @DeveloperApi
       sealed trait LDAOptimizer {
       
      @@ -73,8 +74,8 @@ sealed trait LDAOptimizer {
        *  - Paper which clearly explains several algorithms, including EM:
        *    Asuncion, Welling, Smyth, and Teh.
        *    "On Smoothing and Inference for Topic Models."  UAI, 2009.
      - *
        */
      +@Since("1.4.0")
       @DeveloperApi
       final class EMLDAOptimizer extends LDAOptimizer {
       
      @@ -95,8 +96,9 @@ final class EMLDAOptimizer extends LDAOptimizer {
          * Compute bipartite term/doc graph.
          */
         override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
      -
      +    // EMLDAOptimizer currently only supports symmetric document-topic priors
           val docConcentration = lda.getDocConcentration
      +
           val topicConcentration = lda.getTopicConcentration
           val k = lda.getK
       
      @@ -139,8 +141,9 @@ final class EMLDAOptimizer extends LDAOptimizer {
           this.k = k
           this.vocabSize = docs.take(1).head._2.size
           this.checkpointInterval = lda.getCheckpointInterval
      -    this.graphCheckpointer = new
      -      PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
      +    this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
      +      checkpointInterval, graph.vertices.sparkContext)
      +    this.graphCheckpointer.update(this.graph)
           this.globalTopicTotals = computeGlobalTopicTotals()
           this
         }
      @@ -164,7 +167,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
               edgeContext.sendToDst((false, scaledTopicDistribution))
               edgeContext.sendToSrc((false, scaledTopicDistribution))
             }
      -    // This is a hack to detect whether we could modify the values in-place.
      +    // The Boolean is a hack to detect whether we could modify the values in-place.
           // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
           val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
             (m0, m1) => {
      @@ -185,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
           // Update the vertex descriptors with the new counts.
           val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
           graph = newGraph
      -    graphCheckpointer.updateGraph(newGraph)
      +    graphCheckpointer.update(newGraph)
           globalTopicTotals = computeGlobalTopicTotals()
           this
         }
      @@ -205,7 +208,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
         override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
           require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
           this.graphCheckpointer.deleteAllCheckpoints()
      -    new DistributedLDAModel(this, iterationTimes)
      +    // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
      +    // LDAModel.toLocal conversion
      +    new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
      +      Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
      +      iterationTimes)
         }
       }
       
      @@ -220,6 +227,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
        * Original Online LDA paper:
        *   Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
        */
      +@Since("1.4.0")
       @DeveloperApi
       final class OnlineLDAOptimizer extends LDAOptimizer {
       
      @@ -229,24 +237,28 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         private var vocabSize: Int = 0
       
         /** alias for docConcentration */
      -  private var alpha: Double = 0
      +  private var alpha: Vector = Vectors.dense(0)
       
      -  /** (private[clustering] for debugging)  Get docConcentration */
      -  private[clustering] def getAlpha: Double = alpha
      +  /** (for debugging)  Get docConcentration */
      +  private[clustering] def getAlpha: Vector = alpha
       
         /** alias for topicConcentration */
         private var eta: Double = 0
       
      -  /** (private[clustering] for debugging)  Get topicConcentration */
      +  /** (for debugging)  Get topicConcentration */
         private[clustering] def getEta: Double = eta
       
         private var randomGenerator: java.util.Random = null
       
      +  /** (for debugging) Whether to sample mini-batches with replacement. (default = true) */
      +  private var sampleWithReplacement: Boolean = true
      +
         // Online LDA specific parameters
         // Learning rate is: (tau0 + t)^{-kappa}
         private var tau0: Double = 1024
         private var kappa: Double = 0.51
         private var miniBatchFraction: Double = 0.05
      +  private var optimizeDocConcentration: Boolean = false
       
         // internal data structure
         private var docs: RDD[(Long, Vector)] = null
      @@ -254,7 +266,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         /** Dirichlet parameter for the posterior over topics */
         private var lambda: BDM[Double] = null
       
      -  /** (private[clustering] for debugging) Get parameter for topics */
      +  /** (for debugging) Get parameter for topics */
         private[clustering] def getLambda: BDM[Double] = lambda
       
         /** Current iteration (count of invocations of [[next()]]) */
      @@ -265,6 +277,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
          * A (positive) learning parameter that downweights early iterations. Larger values make early
          * iterations count less.
          */
      +  @Since("1.4.0")
         def getTau0: Double = this.tau0
       
         /**
      @@ -272,6 +285,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
          * iterations count less.
          * Default: 1024, following the original Online LDA paper.
          */
      +  @Since("1.4.0")
         def setTau0(tau0: Double): this.type = {
           require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0")
           this.tau0 = tau0
      @@ -281,6 +295,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         /**
          * Learning rate: exponential decay rate
          */
      +  @Since("1.4.0")
         def getKappa: Double = this.kappa
       
         /**
      @@ -288,6 +303,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
          * (0.5, 1.0] to guarantee asymptotic convergence.
          * Default: 0.51, based on the original Online LDA paper.
          */
      +  @Since("1.4.0")
         def setKappa(kappa: Double): this.type = {
           require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa")
           this.kappa = kappa
      @@ -297,6 +313,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         /**
          * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration
          */
      +  @Since("1.4.0")
         def getMiniBatchFraction: Double = this.miniBatchFraction
       
         /**
      @@ -309,6 +326,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
          *
          * Default: 0.05, i.e., 5% of total documents.
          */
      +  @Since("1.4.0")
         def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
           require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
             s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction")
      @@ -317,7 +335,24 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         }
       
         /**
      -   * (private[clustering])
      +   * Indicates whether docConcentration (the Dirichlet parameter for the document-topic
      +   * distribution) will be optimized during training.
      +   */
      +  @Since("1.5.0")
      +  def getOptimizeDocConcentration: Boolean = this.optimizeDocConcentration
      +
      +  /**
      +   * Sets whether to optimize docConcentration parameter during training.
      +   *
      +   * Default: false
      +   */
      +  @Since("1.5.0")
      +  def setOptimizeDocConcentration(optimizeDocConcentration: Boolean): this.type = {
      +    this.optimizeDocConcentration = optimizeDocConcentration
      +    this
      +  }
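A minimal sketch of wiring the new flag into training (the `corpus` RDD and the parameter values are placeholders):

import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}

val optimizer = new OnlineLDAOptimizer()
  .setMiniBatchFraction(0.05)
  .setOptimizeDocConcentration(true)  // learn the document-topic prior during training
val ldaModel = new LDA()
  .setK(10)
  .setOptimizer(optimizer)
  .run(corpus)  // corpus: RDD[(Long, Vector)], assumed to exist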
      +
      +  /**
          * Set the Dirichlet parameter for the posterior over topics.
          * This is only used for testing now. In the future, it can help support training stop/resume.
          */
      @@ -327,7 +362,6 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         }
       
         /**
      -   * (private[clustering])
          * Used for random initialization of the variational parameters.
          * Larger value produces values closer to 1.0.
          * This is only used for testing currently.
      @@ -337,13 +371,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
           this
         }
       
      +  /**
      +   * Sets whether to sample mini-batches with or without replacement. (default = true)
      +   * This is only used for testing currently.
      +   */
      +  private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = {
      +    this.sampleWithReplacement = replace
      +    this
      +  }
      +
         override private[clustering] def initialize(
             docs: RDD[(Long, Vector)],
             lda: LDA): OnlineLDAOptimizer = {
           this.k = lda.getK
           this.corpusSize = docs.count()
           this.vocabSize = docs.first()._2.size
      -    this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
      +    this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) {
      +      if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
      +      else {
      +        require(lda.getAsymmetricDocConcentration(0) >= 0,
      +          s"all entries in alpha must be >= 0, got: $alpha")
      +        Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0)))
      +      }
      +    } else {
      +      require(lda.getAsymmetricDocConcentration.size == k,
      +        s"alpha must have length k, got: $alpha")
      +      lda.getAsymmetricDocConcentration.foreachActive { case (_, x) =>
      +        require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
      +      }
      +      lda.getAsymmetricDocConcentration
      +    }
           this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
           this.randomGenerator = new Random(lda.getSeed)
       
      @@ -356,7 +413,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         }
       
         override private[clustering] def next(): OnlineLDAOptimizer = {
      -    val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
      +    val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
      +      randomGenerator.nextLong())
           if (batch.isEmpty()) return this
           submitMiniBatch(batch)
         }
      @@ -370,80 +428,85 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
           iteration += 1
           val k = this.k
           val vocabSize = this.vocabSize
      -    val Elogbeta = dirichletExpectation(lambda)
      -    val expElogbeta = exp(Elogbeta)
      -    val alpha = this.alpha
      +    val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t
      +    val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
      +    val alpha = this.alpha.toBreeze
           val gammaShape = this.gammaShape
       
      -    val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
      -      val stat = BDM.zeros[Double](k, vocabSize)
      -      docs.foreach { doc =>
      -        val termCounts = doc._2
      -        val (ids: List[Int], cts: Array[Double]) = termCounts match {
      -          case v: DenseVector => ((0 until v.size).toList, v.values)
      -          case v: SparseVector => (v.indices.toList, v.values)
      -          case v => throw new IllegalArgumentException("Online LDA does not support vector type "
      -            + v.getClass)
      -        }
      -
      -        // Initialize the variational distribution q(theta|gamma) for the mini-batch
      -        var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
      -        var Elogthetad = digamma(gammad) - digamma(sum(gammad))     // 1 * K
      -        var expElogthetad = exp(Elogthetad)                         // 1 * K
      -        val expElogbetad = expElogbeta(::, ids).toDenseMatrix       // K * ids
      -
      -        var phinorm = expElogthetad * expElogbetad + 1e-100         // 1 * ids
      -        var meanchange = 1D
      -        val ctsVector = new BDV[Double](cts).t                      // 1 * ids
      -
      -        // Iterate between gamma and phi until convergence
      -        while (meanchange > 1e-3) {
      -          val lastgamma = gammad
      -          //        1*K                  1 * ids               ids * k
      -          gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
      -          Elogthetad = digamma(gammad) - digamma(sum(gammad))
      -          expElogthetad = exp(Elogthetad)
      -          phinorm = expElogthetad * expElogbetad + 1e-100
      -          meanchange = sum(abs(gammad - lastgamma)) / k
      -        }
      +    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
      +      val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
       
      -        val m1 = expElogthetad.t
      -        val m2 = (ctsVector / phinorm).t.toDenseVector
      -        var i = 0
      -        while (i < ids.size) {
      -          stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
      -          i += 1
      +      val stat = BDM.zeros[Double](k, vocabSize)
      +      var gammaPart = List[BDV[Double]]()
      +      nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) =>
      +        val ids: List[Int] = termCounts match {
      +          case v: DenseVector => (0 until v.size).toList
      +          case v: SparseVector => v.indices.toList
               }
      +        val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
      +          termCounts, expElogbetaBc.value, alpha, gammaShape, k)
      +        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
      +        gammaPart = gammad :: gammaPart
             }
      -      Iterator(stat)
      +      Iterator((stat, gammaPart))
           }
      -
      -    val statsSum: BDM[Double] = stats.reduce(_ += _)
      -    val batchResult = statsSum :* expElogbeta
      +    val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
      +    expElogbetaBc.unpersist()
      +    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
      +      stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
      +    val batchResult = statsSum :* expElogbeta.t
       
           // Note that this is an optimization to avoid batch.count
      -    update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
      +    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
      +    if (optimizeDocConcentration) updateAlpha(gammat)
           this
         }
       
      -  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
      -    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
      -  }
      -
         /**
          * Update lambda based on the batch submitted. batchSize can be different for each iteration.
          */
      -  private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
      +  private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = {
           // weight of the mini-batch.
      -    val weight = math.pow(getTau0 + iter, -getKappa)
      +    val weight = rho()
       
           // Update lambda based on documents.
      -    lambda = lambda * (1 - weight) +
      -      (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
      +    lambda := (1 - weight) * lambda +
      +      weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
         }
       
         /**
      -   * Get a random matrix to initialize lambda
      +   * Update alpha based on `gammat`, the inferred topic distributions for documents in the
      +   * current mini-batch. Uses the Newton-Raphson method.
      +   * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters
      +   *      (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf)
      +   */
      +  private def updateAlpha(gammat: BDM[Double]): Unit = {
      +    val weight = rho()
      +    val N = gammat.rows.toDouble
      +    val alpha = this.alpha.toBreeze.toDenseVector
      +    val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N
      +    val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector)
      +
      +    val c = N * trigamma(sum(alpha))
      +    val q = -N * trigamma(alpha)
      +    val b = sum(gradf / q) / (1D / c + sum(1D / q))
      +
      +    val dalpha = -(gradf - b) / q
      +
      +    if (all((weight * dalpha + alpha) :> 0D)) {
      +      alpha :+= weight * dalpha
      +      this.alpha = Vectors.dense(alpha.toArray)
      +    }
      +  }
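For readers without the Huang reference at hand: the code above is the learning-rate-scaled Newton step for the Dirichlet MLE, exploiting the diagonal-plus-rank-one structure of the Hessian. In the code's symbols (g = gradf, \rho = rho()), roughly:

\alpha \leftarrow \alpha + \rho\,\Delta\alpha, \qquad
\Delta\alpha = -H^{-1} g, \qquad
H = \mathrm{diag}(q) + c\,\mathbf{1}\mathbf{1}^{\top}

(H^{-1} g)_k = \frac{g_k - b}{q_k}, \qquad
b = \frac{\sum_j g_j / q_j}{1/c + \sum_j 1/q_j}

with q_k = -N\,\psi'(\alpha_k) and c = N\,\psi'(\sum_k \alpha_k); the step is applied only when it keeps every \alpha_k positive.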
      +
      +  /** Calculate learning rate rho for the current [[iteration]]. */
      +  private def rho(): Double = {
      +    math.pow(getTau0 + this.iteration, -getKappa)
      +  }
      +
      +  /**
      +   * Get a random matrix to initialize lambda.
          */
         private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
           val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
      @@ -453,15 +516,58 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
           new BDM[Double](col, row, temp).t
         }
       
      +  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
      +    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
      +  }
      +
      +}
      +
      +/**
      + * Serializable companion object containing helper methods and shared code for
      + * [[OnlineLDAOptimizer]] and [[LocalLDAModel]].
      + */
      +private[clustering] object OnlineLDAOptimizer {
         /**
      -   * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
      -   * uses digamma which is accurate but expensive.
      +   * Uses variational inference to infer the topic distribution `gammad` given the term counts
      +   * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will
      +   * throw a BLAS error.
      +   *
      +   * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001)
      +   * avoids explicit computation of variational parameter `phi`.
      +   * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]]
          */
      -  private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
      -    val rowSum = sum(alpha(breeze.linalg.*, ::))
      -    val digAlpha = digamma(alpha)
      -    val digRowSum = digamma(rowSum)
      -    val result = digAlpha(::, breeze.linalg.*) - digRowSum
      -    result
      +  private[clustering] def variationalTopicInference(
      +      termCounts: Vector,
      +      expElogbeta: BDM[Double],
      +      alpha: breeze.linalg.Vector[Double],
      +      gammaShape: Double,
      +      k: Int): (BDV[Double], BDM[Double]) = {
      +    val (ids: List[Int], cts: Array[Double]) = termCounts match {
      +      case v: DenseVector => ((0 until v.size).toList, v.values)
      +      case v: SparseVector => (v.indices.toList, v.values)
      +    }
      +    // Initialize the variational distribution q(theta|gamma) for the mini-batch
      +    val gammad: BDV[Double] =
      +      new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k)                   // K
      +    val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad))  // K
      +    val expElogbetad = expElogbeta(ids, ::).toDenseMatrix                        // ids * K
      +
      +    val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100            // ids
      +    var meanGammaChange = 1D
      +    val ctsVector = new BDV[Double](cts)                                         // ids
      +
      +    // Iterate between gamma and phi until convergence
      +    while (meanGammaChange > 1e-3) {
      +      val lastgamma = gammad.copy
      +      //        K                  K * ids               ids
      +      gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha
      +      expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
      +      // TODO: Keep more values in log space, and only exponentiate when needed.
      +      phiNorm := expElogbetad * expElogthetad :+ 1e-100
      +      meanGammaChange = sum(abs(gammad - lastgamma)) / k
      +    }
      +
      +    val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix
      +    (gammad, sstatsd)
         }
       }
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
      new file mode 100644
      index 000000000000..a9ba7b60bad0
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
      @@ -0,0 +1,55 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +package org.apache.spark.mllib.clustering
      +
      +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum}
      +import breeze.numerics._
      +
      +/**
      + * Utility methods for LDA.
      + */
      +private[clustering] object LDAUtils {
      +  /**
      +   * Log Sum Exp with overflow protection using the identity:
      +   * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\}
      +   */
      +  private[clustering] def logSumExp(x: BDV[Double]): Double = {
      +    val a = max(x)
      +    a + log(sum(exp(x :- a)))
      +  }
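A quick illustrative check of why the shift matters (values are arbitrary; as written this only compiles inside the clustering package because LDAUtils is package-private):

import breeze.linalg.DenseVector

val x = DenseVector(1000.0, 1000.5, 999.0)
// A naive math.log(breeze.linalg.sum(breeze.numerics.exp(x))) overflows to Infinity here,
// while the shifted form stays finite (about 1001.1).
val stable = LDAUtils.logSumExp(x)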
      +
      +  /**
      +   * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
      +   * uses [[breeze.numerics.digamma]] which is accurate but expensive.
      +   */
      +  private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = {
      +    digamma(alpha) - digamma(sum(alpha))
      +  }
      +
      +  /**
      +   * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha is a vector of
      +   * Dirichlet parameters.
      +   */
      +  private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
      +    val rowSum = sum(alpha(breeze.linalg.*, ::))
      +    val digAlpha = digamma(alpha)
      +    val digRowSum = digamma(rowSum)
      +    val result = digAlpha(::, breeze.linalg.*) - digRowSum
      +    result
      +  }
      +
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
      index e7a243f854e3..6c76e26fd162 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
      @@ -21,7 +21,7 @@ import org.json4s.JsonDSL._
       import org.json4s._
       import org.json4s.jackson.JsonMethods._
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.graphx._
       import org.apache.spark.graphx.impl.GraphImpl
      @@ -40,11 +40,14 @@ import org.apache.spark.{Logging, SparkContext, SparkException}
        * @param k number of clusters
        * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s
        */
      +@Since("1.3.0")
       @Experimental
      -class PowerIterationClusteringModel(
      -    val k: Int,
      -    val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable {
      +class PowerIterationClusteringModel @Since("1.3.0") (
      +    @Since("1.3.0") val k: Int,
      +    @Since("1.3.0") val assignments: RDD[PowerIterationClustering.Assignment])
      +  extends Saveable with Serializable {
       
      +  @Since("1.4.0")
         override def save(sc: SparkContext, path: String): Unit = {
           PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
         }
      @@ -52,7 +55,10 @@ class PowerIterationClusteringModel(
         override protected def formatVersion: String = "1.0"
       }
       
      +@Since("1.4.0")
       object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
      +
      +  @Since("1.4.0")
         override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
           PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
         }
      @@ -65,6 +71,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
           private[clustering]
           val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"
       
      +    @Since("1.4.0")
           def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
             val sqlContext = new SQLContext(sc)
             import sqlContext.implicits._
      @@ -77,6 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
             dataRDD.write.parquet(Loader.dataPath(path))
           }
       
      +    @Since("1.4.0")
           def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
             implicit val formats = DefaultFormats
             val sqlContext = new SQLContext(sc)
      @@ -113,6 +121,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
        * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]]
        */
       @Experimental
      +@Since("1.3.0")
       class PowerIterationClustering private[clustering] (
           private var k: Int,
           private var maxIterations: Int,
      @@ -120,14 +129,17 @@ class PowerIterationClustering private[clustering] (
       
         import org.apache.spark.mllib.clustering.PowerIterationClustering._
       
      -  /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100,
      -   *  initMode: "random"}.
      +  /**
      +   * Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100,
      +   * initMode: "random"}.
          */
      +  @Since("1.3.0")
         def this() = this(k = 2, maxIterations = 100, initMode = "random")
       
         /**
          * Set the number of clusters.
          */
      +  @Since("1.3.0")
         def setK(k: Int): this.type = {
           this.k = k
           this
      @@ -136,6 +148,7 @@ class PowerIterationClustering private[clustering] (
         /**
          * Set maximum number of iterations of the power iteration loop
          */
      +  @Since("1.3.0")
         def setMaxIterations(maxIterations: Int): this.type = {
           this.maxIterations = maxIterations
           this
      @@ -145,6 +158,7 @@ class PowerIterationClustering private[clustering] (
          * Set the initialization mode. This can be either "random" to use a random vector
          * as vertex properties, or "degree" to use normalized sum similarities. Default: random.
          */
      +  @Since("1.3.0")
         def setInitializationMode(mode: String): this.type = {
           this.initMode = mode match {
             case "random" | "degree" => mode
      @@ -153,6 +167,28 @@ class PowerIterationClustering private[clustering] (
           this
         }
       
      +  /**
      +   * Run the PIC algorithm on Graph.
      +   *
      +   * @param graph an affinity matrix represented as graph, which is the matrix A in the PIC paper.
      +   *              The similarity s,,ij,, represented as the edge between vertices (i, j) must
      +   *              be nonnegative. This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For
      +   *              any (i, j) with nonzero similarity, there should be either (i, j, s,,ij,,)
      +   *              or (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we
      +   *              assume s,,ij,, = 0.0.
      +   *
      +   * @return a [[PowerIterationClusteringModel]] that contains the clustering result
      +   */
      +  @Since("1.5.0")
      +  def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = {
      +    val w = normalize(graph)
      +    val w0 = initMode match {
      +      case "random" => randomInit(w)
      +      case "degree" => initDegreeVector(w)
      +    }
      +    pic(w0)
      +  }
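A hedged sketch of the new Graph-based entry point (the edge list is illustrative and `sc` is an existing SparkContext):

import org.apache.spark.graphx.{Edge, Graph}

// Affinity graph: nonnegative similarities s_ij on the edges; vertex attributes are unused here.
val edges = sc.parallelize(Seq(Edge(0L, 1L, 1.0), Edge(1L, 2L, 0.3), Edge(0L, 2L, 0.1)))
val affinities: Graph[Double, Double] = Graph.fromEdges(edges, defaultValue = 0.0)

val picModel = new PowerIterationClustering()
  .setK(2)
  .run(affinities)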
      +
         /**
          * Run the PIC algorithm.
          *
      @@ -165,6 +201,7 @@ class PowerIterationClustering private[clustering] (
          *
          * @return a [[PowerIterationClusteringModel]] that contains the clustering result
          */
      +  @Since("1.3.0")
         def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = {
           val w = normalize(similarities)
           val w0 = initMode match {
      @@ -177,6 +214,7 @@ class PowerIterationClustering private[clustering] (
         /**
          * A Java-friendly version of [[PowerIterationClustering.run]].
          */
      +  @Since("1.3.0")
         def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)])
           : PowerIterationClusteringModel = {
           run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]])
      @@ -200,6 +238,7 @@ class PowerIterationClustering private[clustering] (
         }
       }
       
      +@Since("1.3.0")
       @Experimental
       object PowerIterationClustering extends Logging {
       
      @@ -209,9 +248,35 @@ object PowerIterationClustering extends Logging {
          * @param id node id
          * @param cluster assigned cluster id
          */
      +  @Since("1.3.0")
         @Experimental
         case class Assignment(id: Long, cluster: Int)
       
      +  /**
      +   * Normalizes the affinity graph (A) and returns the normalized affinity matrix (W).
      +   */
      +  private[clustering]
      +  def normalize(graph: Graph[Double, Double]): Graph[Double, Double] = {
      +    val vD = graph.aggregateMessages[Double](
      +      sendMsg = ctx => {
      +        val i = ctx.srcId
      +        val j = ctx.dstId
      +        val s = ctx.attr
      +        if (s < 0.0) {
+          throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.")
      +        }
      +        if (s > 0.0) {
      +          ctx.sendToSrc(s)
      +        }
      +      },
      +      mergeMsg = _ + _,
      +      TripletFields.EdgeOnly)
      +    GraphImpl.fromExistingRDDs(vD, graph.edges)
      +      .mapTriplets(
      +        e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON),
      +        TripletFields.Src)
      +  }
      +
         /**
          * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
          */
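
For context, a minimal sketch of how the newly added graph-based `run` overload above might be exercised. The toy edge weights are invented and `sc` is assumed to be an existing SparkContext (e.g. from spark-shell); this is illustrative only, not part of the patch.

```scala
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.clustering.PowerIterationClustering

// Toy affinity graph: two small chains joined by one weak edge.
// Vertex attributes are placeholders; PIC only reads the edge weights s_ij.
val edges = sc.parallelize(Seq(
  Edge(0L, 1L, 1.0), Edge(1L, 2L, 1.0),
  Edge(3L, 4L, 1.0), Edge(4L, 5L, 1.0),
  Edge(2L, 3L, 0.1)))
val affinities: Graph[Double, Double] = Graph.fromEdges(edges, defaultValue = 0.0)

val model = new PowerIterationClustering()
  .setK(2)
  .setMaxIterations(20)
  .run(affinities)

model.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))
```
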
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
      index d9b34cec6489..1d50ffec96fa 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
      @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
       import scala.reflect.ClassTag
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaSparkContext._
       import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
       import org.apache.spark.rdd.RDD
      @@ -63,14 +63,18 @@ import org.apache.spark.util.random.XORShiftRandom
        * such that at time t + h the discount applied to the data from t is 0.5.
        * The definition remains the same whether the time unit is given
        * as batches or points.
      - *
        */
      +@Since("1.2.0")
       @Experimental
      -class StreamingKMeansModel(
      -    override val clusterCenters: Array[Vector],
      -    val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
      +class StreamingKMeansModel @Since("1.2.0") (
      +    @Since("1.2.0") override val clusterCenters: Array[Vector],
      +    @Since("1.2.0") val clusterWeights: Array[Double])
      +  extends KMeansModel(clusterCenters) with Logging {
       
      -  /** Perform a k-means update on a batch of data. */
      +  /**
      +   * Perform a k-means update on a batch of data.
      +   */
      +  @Since("1.2.0")
         def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
       
           // find nearest cluster to each point
      @@ -82,6 +86,7 @@ class StreamingKMeansModel(
             (p1._1, p1._2 + p2._2)
           }
           val dim = clusterCenters(0).size
      +
           val pointStats: Array[(Int, (Vector, Long))] = closest
             .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
             .collect()
      @@ -162,29 +167,40 @@ class StreamingKMeansModel(
        *    .trainOn(DStream)
        * }}}
        */
      +@Since("1.2.0")
       @Experimental
      -class StreamingKMeans(
      -    var k: Int,
      -    var decayFactor: Double,
      -    var timeUnit: String) extends Logging with Serializable {
      +class StreamingKMeans @Since("1.2.0") (
      +    @Since("1.2.0") var k: Int,
      +    @Since("1.2.0") var decayFactor: Double,
      +    @Since("1.2.0") var timeUnit: String) extends Logging with Serializable {
       
      +  @Since("1.2.0")
         def this() = this(2, 1.0, StreamingKMeans.BATCHES)
       
         protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
       
      -  /** Set the number of clusters. */
      +  /**
      +   * Set the number of clusters.
      +   */
      +  @Since("1.2.0")
         def setK(k: Int): this.type = {
           this.k = k
           this
         }
       
      -  /** Set the decay factor directly (for forgetful algorithms). */
      +  /**
      +   * Set the decay factor directly (for forgetful algorithms).
      +   */
      +  @Since("1.2.0")
         def setDecayFactor(a: Double): this.type = {
           this.decayFactor = a
           this
         }
       
      -  /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
      +  /**
      +   * Set the half life and time unit ("batches" or "points") for forgetful algorithms.
      +   */
      +  @Since("1.2.0")
         def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
           if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
             throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
      @@ -195,7 +211,10 @@ class StreamingKMeans(
           this
         }
       
      -  /** Specify initial centers directly. */
      +  /**
      +   * Specify initial centers directly.
      +   */
      +  @Since("1.2.0")
         def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
           model = new StreamingKMeansModel(centers, weights)
           this
      @@ -208,6 +227,7 @@ class StreamingKMeans(
          * @param weight Weight for each center
          * @param seed Random seed
          */
      +  @Since("1.2.0")
         def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
           val random = new XORShiftRandom(seed)
           val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
      @@ -216,7 +236,10 @@ class StreamingKMeans(
           this
         }
       
      -  /** Return the latest model. */
      +  /**
      +   * Return the latest model.
      +   */
      +  @Since("1.2.0")
         def latestModel(): StreamingKMeansModel = {
           model
         }
      @@ -229,6 +252,7 @@ class StreamingKMeans(
          *
          * @param data DStream containing vector data
          */
      +  @Since("1.2.0")
         def trainOn(data: DStream[Vector]) {
           assertInitialized()
           data.foreachRDD { (rdd, time) =>
      @@ -236,7 +260,10 @@ class StreamingKMeans(
           }
         }
       
      -  /** Java-friendly version of `trainOn`. */
      +  /**
      +   * Java-friendly version of `trainOn`.
      +   */
      +  @Since("1.4.0")
         def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)
       
         /**
      @@ -245,12 +272,16 @@ class StreamingKMeans(
          * @param data DStream containing vector data
          * @return DStream containing predictions
          */
      +  @Since("1.2.0")
         def predictOn(data: DStream[Vector]): DStream[Int] = {
           assertInitialized()
           data.map(model.predict)
         }
       
      -  /** Java-friendly version of `predictOn`. */
      +  /**
      +   * Java-friendly version of `predictOn`.
      +   */
      +  @Since("1.4.0")
         def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
           JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
         }
      @@ -262,12 +293,16 @@ class StreamingKMeans(
          * @tparam K key type
          * @return DStream containing the input keys and the predictions as values
          */
      +  @Since("1.2.0")
         def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
           assertInitialized()
           data.mapValues(model.predict)
         }
       
      -  /** Java-friendly version of `predictOnValues`. */
      +  /**
      +   * Java-friendly version of `predictOnValues`.
      +   */
      +  @Since("1.4.0")
         def predictOnValues[K](
             data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
           implicit val tag = fakeClassTag[K]
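
A hedged sketch of the streaming workflow these setters support, mirroring the `.trainOn(DStream)` usage shown in the class docs. The input directories, batch interval, and dimensionality below are made-up placeholders, and `sc` is assumed to be an existing SparkContext.

```scala
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

val ssc = new StreamingContext(sc, Seconds(10))
val parse = (line: String) => Vectors.dense(line.split(',').map(_.toDouble))
val trainingData = ssc.textFileStream("/tmp/kmeans/train").map(parse)
val testData = ssc.textFileStream("/tmp/kmeans/test").map(parse)

val model = new StreamingKMeans()
  .setK(3)
  .setDecayFactor(0.5)                     // forget older batches geometrically
  .setRandomCenters(dim = 2, weight = 0.0) // assumes 2-dimensional input

model.trainOn(trainingData)                // update centers on every training batch
model.predictOn(testData).print()          // cluster index for each test point

ssc.start()
ssc.awaitTermination()
```
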
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
      index c1d1a224817e..508fe532b130 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.evaluation
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.Logging
       import org.apache.spark.SparkContext._
       import org.apache.spark.mllib.evaluation.binary._
      @@ -42,16 +42,18 @@ import org.apache.spark.sql.DataFrame
        *                be smaller as a result, meaning there may be an extra sample at
        *                partition boundaries.
        */
      +@Since("1.0.0")
       @Experimental
      -class BinaryClassificationMetrics(
      -    val scoreAndLabels: RDD[(Double, Double)],
      -    val numBins: Int) extends Logging {
      +class BinaryClassificationMetrics @Since("1.3.0") (
      +    @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)],
      +    @Since("1.3.0") val numBins: Int) extends Logging {
       
         require(numBins >= 0, "numBins must be nonnegative")
       
         /**
          * Defaults `numBins` to 0.
          */
      +  @Since("1.0.0")
         def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)
       
         /**
      @@ -61,12 +63,18 @@ class BinaryClassificationMetrics(
         private[mllib] def this(scoreAndLabels: DataFrame) =
           this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
       
      -  /** Unpersist intermediate RDDs used in the computation. */
      +  /**
      +   * Unpersist intermediate RDDs used in the computation.
      +   */
      +  @Since("1.0.0")
         def unpersist() {
           cumulativeCounts.unpersist()
         }
       
      -  /** Returns thresholds in descending order. */
      +  /**
      +   * Returns thresholds in descending order.
      +   */
      +  @Since("1.0.0")
         def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)
       
         /**
      @@ -75,6 +83,7 @@ class BinaryClassificationMetrics(
          * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
          * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
          */
      +  @Since("1.0.0")
         def roc(): RDD[(Double, Double)] = {
           val rocCurve = createCurve(FalsePositiveRate, Recall)
           val sc = confusions.context
      @@ -86,6 +95,7 @@ class BinaryClassificationMetrics(
         /**
          * Computes the area under the receiver operating characteristic (ROC) curve.
          */
      +  @Since("1.0.0")
         def areaUnderROC(): Double = AreaUnderCurve.of(roc())
       
         /**
      @@ -93,6 +103,7 @@ class BinaryClassificationMetrics(
          * NOT (precision, recall), with (0.0, 1.0) prepended to it.
          * @see http://en.wikipedia.org/wiki/Precision_and_recall
          */
      +  @Since("1.0.0")
         def pr(): RDD[(Double, Double)] = {
           val prCurve = createCurve(Recall, Precision)
           val sc = confusions.context
      @@ -103,6 +114,7 @@ class BinaryClassificationMetrics(
         /**
          * Computes the area under the precision-recall curve.
          */
      +  @Since("1.0.0")
         def areaUnderPR(): Double = AreaUnderCurve.of(pr())
       
         /**
      @@ -111,15 +123,25 @@ class BinaryClassificationMetrics(
          * @return an RDD of (threshold, F-Measure) pairs.
          * @see http://en.wikipedia.org/wiki/F1_score
          */
      +  @Since("1.0.0")
         def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))
       
      -  /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
      +  /**
      +   * Returns the (threshold, F-Measure) curve with beta = 1.0.
      +   */
      +  @Since("1.0.0")
         def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)
       
      -  /** Returns the (threshold, precision) curve. */
      +  /**
      +   * Returns the (threshold, precision) curve.
      +   */
      +  @Since("1.0.0")
         def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)
       
      -  /** Returns the (threshold, recall) curve. */
      +  /**
      +   * Returns the (threshold, recall) curve.
      +   */
      +  @Since("1.0.0")
         def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall)
       
         private lazy val (
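
A small usage sketch of the APIs annotated above, with made-up (score, label) pairs and spark-shell's predefined `sc`; illustrative only.

```scala
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val scoreAndLabels = sc.parallelize(Seq(
  (0.9, 1.0), (0.8, 1.0), (0.6, 0.0), (0.4, 1.0), (0.2, 0.0)))
// The single-argument constructor keeps every distinct score as a threshold (numBins = 0).
val metrics = new BinaryClassificationMetrics(scoreAndLabels)

println(s"AUC-ROC = ${metrics.areaUnderROC()}")
println(s"AUC-PR  = ${metrics.areaUnderPR()}")
metrics.precisionByThreshold().collect().foreach { case (t, p) =>
  println(s"threshold=$t precision=$p")
}
metrics.unpersist()
```
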
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
      index 4628dc569091..00e837661dfc 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
      @@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation
       import scala.collection.Map
       
       import org.apache.spark.SparkContext._
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg.{Matrices, Matrix}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      @@ -31,8 +31,9 @@ import org.apache.spark.sql.DataFrame
        *
        * @param predictionAndLabels an RDD of (prediction, label) pairs.
        */
      +@Since("1.1.0")
       @Experimental
      -class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
      +class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
       
         /**
          * An auxiliary constructor taking a DataFrame.
      @@ -65,6 +66,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * they are ordered by class label ascending,
          * as in "labels"
          */
      +  @Since("1.1.0")
         def confusionMatrix: Matrix = {
           val n = labels.size
           val values = Array.ofDim[Double](n * n)
      @@ -84,12 +86,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * Returns true positive rate for a given label (category)
          * @param label the label.
          */
      +  @Since("1.1.0")
         def truePositiveRate(label: Double): Double = recall(label)
       
         /**
          * Returns false positive rate for a given label (category)
          * @param label the label.
          */
      +  @Since("1.1.0")
         def falsePositiveRate(label: Double): Double = {
           val fp = fpByClass.getOrElse(label, 0)
           fp.toDouble / (labelCount - labelCountByClass(label))
      @@ -99,6 +103,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * Returns precision for a given label (category)
          * @param label the label.
          */
      +  @Since("1.1.0")
         def precision(label: Double): Double = {
           val tp = tpByClass(label)
           val fp = fpByClass.getOrElse(label, 0)
      @@ -109,6 +114,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * Returns recall for a given label (category)
          * @param label the label.
          */
      +  @Since("1.1.0")
         def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label)
       
         /**
      @@ -116,6 +122,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * @param label the label.
          * @param beta the beta parameter.
          */
      +  @Since("1.1.0")
         def fMeasure(label: Double, beta: Double): Double = {
           val p = precision(label)
           val r = recall(label)
      @@ -127,11 +134,13 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * Returns f1-measure for a given label (category)
          * @param label the label.
          */
      +  @Since("1.1.0")
         def fMeasure(label: Double): Double = fMeasure(label, 1.0)
       
         /**
          * Returns precision
          */
      +  @Since("1.1.0")
         lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount
       
         /**
      @@ -140,23 +149,27 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * because sum of all false positives is equal to sum
          * of all false negatives)
          */
      +  @Since("1.1.0")
         lazy val recall: Double = precision
       
         /**
          * Returns f-measure
          * (equals to precision and recall because precision equals recall)
          */
      +  @Since("1.1.0")
         lazy val fMeasure: Double = precision
       
         /**
          * Returns weighted true positive rate
          * (equals to precision, recall and f-measure)
          */
      +  @Since("1.1.0")
         lazy val weightedTruePositiveRate: Double = weightedRecall
       
         /**
          * Returns weighted false positive rate
          */
      +  @Since("1.1.0")
         lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) =>
           falsePositiveRate(category) * count.toDouble / labelCount
         }.sum
      @@ -165,6 +178,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * Returns weighted averaged recall
          * (equals to precision, recall and f-measure)
          */
      +  @Since("1.1.0")
         lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) =>
           recall(category) * count.toDouble / labelCount
         }.sum
      @@ -172,6 +186,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
         /**
          * Returns weighted averaged precision
          */
      +  @Since("1.1.0")
         lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) =>
           precision(category) * count.toDouble / labelCount
         }.sum
      @@ -180,6 +195,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
          * Returns weighted averaged f-measure
          * @param beta the beta parameter.
          */
      +  @Since("1.1.0")
         def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) =>
           fMeasure(category, beta) * count.toDouble / labelCount
         }.sum
      @@ -187,6 +203,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
         /**
          * Returns weighted averaged f1-measure
          */
      +  @Since("1.1.0")
         lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) =>
           fMeasure(category, 1.0) * count.toDouble / labelCount
         }.sum
      @@ -194,5 +211,6 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
         /**
          * Returns the sequence of labels in ascending order
          */
      +  @Since("1.1.0")
         lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted
       }
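
For reference, a toy run through the per-label and weighted measures listed above (invented predictions, `sc` assumed from spark-shell):

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0), (2.0, 1.0)))
val metrics = new MulticlassMetrics(predictionAndLabels)

println(metrics.confusionMatrix)
metrics.labels.foreach { l =>
  println(s"label=$l precision=${metrics.precision(l)} recall=${metrics.recall(l)}")
}
println(s"weighted F1 = ${metrics.weightedFMeasure}")
```
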
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
      index bf6eb1d5bd2a..c100b3c9ec14 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
      @@ -17,6 +17,7 @@
       
       package org.apache.spark.mllib.evaluation
       
      +import org.apache.spark.annotation.Since
       import org.apache.spark.rdd.RDD
       import org.apache.spark.SparkContext._
       import org.apache.spark.sql.DataFrame
      @@ -26,7 +27,8 @@ import org.apache.spark.sql.DataFrame
        * @param predictionAndLabels an RDD of (predictions, labels) pairs,
        * both are non-null Arrays, each with unique elements.
        */
      -class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
      +@Since("1.2.0")
      +class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
       
         /**
          * An auxiliary constructor taking a DataFrame.
      @@ -44,6 +46,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
          * Returns subset accuracy
          * (for equal sets of labels)
          */
      +  @Since("1.2.0")
         lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
           predictions.deep == labels.deep
         }.count().toDouble / numDocs
      @@ -51,6 +54,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
         /**
          * Returns accuracy
          */
      +  @Since("1.2.0")
         lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
           labels.intersect(predictions).size.toDouble /
             (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs
      @@ -59,6 +63,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
         /**
          * Returns Hamming-loss
          */
      +  @Since("1.2.0")
         lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
           labels.size + predictions.size - 2 * labels.intersect(predictions).size
         }.sum / (numDocs * numLabels)
      @@ -66,6 +71,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
         /**
          * Returns document-based precision averaged by the number of documents
          */
      +  @Since("1.2.0")
         lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
           if (predictions.size > 0) {
             predictions.intersect(labels).size.toDouble / predictions.size
      @@ -77,6 +83,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
         /**
          * Returns document-based recall averaged by the number of documents
          */
      +  @Since("1.2.0")
         lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
           labels.intersect(predictions).size.toDouble / labels.size
         }.sum / numDocs
      @@ -84,6 +91,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
         /**
          * Returns document-based f1-measure averaged by the number of documents
          */
      +  @Since("1.2.0")
         lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
           2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)
         }.sum / numDocs
      @@ -104,6 +112,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
          * Returns precision for a given label (category)
          * @param label the label.
          */
      +  @Since("1.2.0")
         def precision(label: Double): Double = {
           val tp = tpPerClass(label)
           val fp = fpPerClass.getOrElse(label, 0L)
      @@ -114,6 +123,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
          * Returns recall for a given label (category)
          * @param label the label.
          */
      +  @Since("1.2.0")
         def recall(label: Double): Double = {
           val tp = tpPerClass(label)
           val fn = fnPerClass.getOrElse(label, 0L)
      @@ -124,6 +134,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
          * Returns f1-measure for a given label (category)
          * @param label the label.
          */
      +  @Since("1.2.0")
         def f1Measure(label: Double): Double = {
           val p = precision(label)
           val r = recall(label)
      @@ -138,6 +149,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
          * Returns micro-averaged label-based precision
          * (equals to micro-averaged document-based precision)
          */
      +  @Since("1.2.0")
         lazy val microPrecision: Double = {
           val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
           sumTp.toDouble / (sumTp + sumFp)
      @@ -147,6 +159,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
          * Returns micro-averaged label-based recall
          * (equals to micro-averaged document-based recall)
          */
      +  @Since("1.2.0")
         lazy val microRecall: Double = {
           val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
           sumTp.toDouble / (sumTp + sumFn)
      @@ -156,10 +169,12 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
          * Returns micro-averaged label-based f1-measure
          * (equals to micro-averaged document-based f1-measure)
          */
      +  @Since("1.2.0")
         lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
       
         /**
          * Returns the sequence of labels in ascending order
          */
      +  @Since("1.2.0")
         lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted
       }
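
A minimal sketch of the document-based and micro-averaged measures above, on invented (predictions, labels) pairs with `sc` assumed from spark-shell:

```scala
import org.apache.spark.mllib.evaluation.MultilabelMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (Array(0.0, 1.0), Array(0.0, 2.0)),
  (Array(0.0, 2.0), Array(0.0, 1.0)),
  (Array.empty[Double], Array(0.0))))
val metrics = new MultilabelMetrics(predictionAndLabels)

println(s"Hamming loss    = ${metrics.hammingLoss}")
println(s"subset accuracy = ${metrics.subsetAccuracy}")
println(s"micro F1        = ${metrics.microF1Measure}")
```
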
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
      index 5b5a2a1450f7..a7f43f0b110f 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
      @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
       import scala.reflect.ClassTag
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
       import org.apache.spark.rdd.RDD
       
      @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD
        *
        * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
        */
      +@Since("1.2.0")
       @Experimental
       class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
         extends Logging with Serializable {
      @@ -56,6 +57,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
          * @param k the position to compute the truncated precision, must be positive
          * @return the average precision at the first k ranking positions
          */
      +  @Since("1.2.0")
         def precisionAt(k: Int): Double = {
           require(k > 0, "ranking position k should be positive")
           predictionAndLabels.map { case (pred, lab) =>
      @@ -125,6 +127,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
          * @param k the position to compute the truncated ndcg, must be positive
          * @return the average ndcg at the first k ranking positions
          */
      +  @Since("1.2.0")
         def ndcgAt(k: Int): Double = {
           require(k > 0, "ranking position k should be positive")
           predictionAndLabels.map { case (pred, lab) =>
      @@ -163,6 +166,7 @@ object RankingMetrics {
          * Creates a [[RankingMetrics]] instance (for Java users).
          * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs
          */
      +  @Since("1.4.0")
         def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = {
           implicit val tag = JavaSparkContext.fakeClassTag[E]
           val rdd = predictionAndLabels.rdd.map { case (predictions, labels) =>
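
Sketch of `precisionAt` and `ndcgAt` on two invented (predicted ranking, ground-truth set) pairs; `sc` is assumed to be spark-shell's SparkContext.

```scala
import org.apache.spark.mllib.evaluation.RankingMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8), Array(1, 2, 3, 4, 5)),
  (Array(4, 1, 5, 6, 2), Array(1, 2, 3))))
val metrics = new RankingMetrics(predictionAndLabels)

println(s"precision@3 = ${metrics.precisionAt(3)}")
println(s"ndcg@3      = ${metrics.ndcgAt(3)}")
```
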
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
      index e577bf87f885..799ebb980ef0 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.evaluation
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.Logging
       import org.apache.spark.mllib.linalg.Vectors
      @@ -30,8 +30,10 @@ import org.apache.spark.sql.DataFrame
        *
        * @param predictionAndObservations an RDD of (prediction, observation) pairs.
        */
      +@Since("1.2.0")
       @Experimental
      -class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
      +class RegressionMetrics @Since("1.2.0") (
      +    predictionAndObservations: RDD[(Double, Double)]) extends Logging {
       
         /**
          * An auxiliary constructor taking a DataFrame.
      @@ -53,20 +55,30 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
             )
           summary
         }
      +  private lazy val SSerr = math.pow(summary.normL2(1), 2)
      +  private lazy val SStot = summary.variance(0) * (summary.count - 1)
      +  private lazy val SSreg = {
      +    val yMean = summary.mean(0)
      +    predictionAndObservations.map {
      +      case (prediction, _) => math.pow(prediction - yMean, 2)
      +    }.sum()
      +  }
       
         /**
      -   * Returns the explained variance regression score.
      -   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
      -   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
      +   * Returns the variance explained by regression.
      +   * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n
      +   * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]]
          */
      +  @Since("1.2.0")
         def explainedVariance: Double = {
      -    1 - summary.variance(1) / summary.variance(0)
      +    SSreg / summary.count
         }
       
         /**
          * Returns the mean absolute error, which is a risk function corresponding to the
          * expected value of the absolute error loss or l1-norm loss.
          */
      +  @Since("1.2.0")
         def meanAbsoluteError: Double = {
           summary.normL1(1) / summary.count
         }
      @@ -75,24 +87,26 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
          * Returns the mean squared error, which is a risk function corresponding to the
          * expected value of the squared error loss or quadratic loss.
          */
      +  @Since("1.2.0")
         def meanSquaredError: Double = {
      -    val rmse = summary.normL2(1) / math.sqrt(summary.count)
      -    rmse * rmse
      +    SSerr / summary.count
         }
       
         /**
          * Returns the root mean squared error, which is defined as the square root of
          * the mean squared error.
          */
      +  @Since("1.2.0")
         def rootMeanSquaredError: Double = {
      -    summary.normL2(1) / math.sqrt(summary.count)
      +    math.sqrt(this.meanSquaredError)
         }
       
         /**
      -   * Returns R^2^, the coefficient of determination.
      -   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
      +   * Returns R^2^, the unadjusted coefficient of determination.
      +   * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
          */
      +  @Since("1.2.0")
         def r2: Double = {
      -    1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
      +    1 - SSerr / SStot
         }
       }
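
To make the reworked definitions concrete (mse = SSerr / n, rmse = sqrt(mse), r2 = 1 - SSerr / SStot, explainedVariance = SSreg / n), a toy example; the data is invented and `sc` is assumed from spark-shell.

```scala
import org.apache.spark.mllib.evaluation.RegressionMetrics

val predictionAndObservations = sc.parallelize(Seq(
  (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))
val metrics = new RegressionMetrics(predictionAndObservations)

println(s"MSE  = ${metrics.meanSquaredError}")
println(s"RMSE = ${metrics.rootMeanSquaredError}")
println(s"R^2  = ${metrics.r2}")
println(s"explained variance = ${metrics.explainedVariance}")
```
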
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
      index 5f8c1dea237b..4743cfd1a2c3 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature
       
       import scala.collection.mutable.ArrayBuilder
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.stat.Statistics
      @@ -31,8 +31,10 @@ import org.apache.spark.rdd.RDD
        *
        * @param selectedFeatures list of indices to select (filter). Must be ordered asc
        */
      +@Since("1.3.0")
       @Experimental
      -class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer {
      +class ChiSqSelectorModel @Since("1.3.0") (
      +  @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer {
       
         require(isSorted(selectedFeatures), "Array has to be sorted asc")
       
      @@ -52,6 +54,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf
          * @param vector vector to be transformed.
          * @return transformed vector.
          */
      +  @Since("1.3.0")
         override def transform(vector: Vector): Vector = {
           compress(vector, selectedFeatures)
         }
      @@ -107,8 +110,10 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf
        * @param numTopFeatures number of features that selector will select
        *                       (ordered by statistic value descending)
        */
      +@Since("1.3.0")
       @Experimental
      -class ChiSqSelector (val numTopFeatures: Int) extends Serializable {
      +class ChiSqSelector @Since("1.3.0") (
      +  @Since("1.3.0") val numTopFeatures: Int) extends Serializable {
       
         /**
          * Returns a ChiSquared feature selector.
      @@ -117,6 +122,7 @@ class ChiSqSelector (val numTopFeatures: Int) extends Serializable {
          *             Real-valued features will be treated as categorical for each distinct value.
          *             Apply feature discretizer before using this function.
          */
      +  @Since("1.3.0")
         def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
           val indices = Statistics.chiSqTest(data)
             .zipWithIndex.sortBy { case (res, _) => -res.statistic }
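
A sketch of `fit` and the resulting model's `transform`, selecting one feature from toy categorical data (`sc` assumed from spark-shell):

```scala
import org.apache.spark.mllib.feature.ChiSqSelector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 2.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 0.0))))

val model = new ChiSqSelector(numTopFeatures = 1).fit(data)
val filtered = data.map(lp => LabeledPoint(lp.label, model.transform(lp.features)))
filtered.collect().foreach(println)
```
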
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
      index d67fe6c3ee4f..d0a6cf61687a 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.feature
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg._
       
       /**
      @@ -27,8 +27,10 @@ import org.apache.spark.mllib.linalg._
        * multiplier.
        * @param scalingVec The values used to scale the reference vector's individual components.
        */
      +@Since("1.4.0")
       @Experimental
      -class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer {
      +class ElementwiseProduct @Since("1.4.0") (
      +    @Since("1.4.0") val scalingVec: Vector) extends VectorTransformer {
       
         /**
          * Does the hadamard product transformation.
      @@ -36,6 +38,7 @@ class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer {
          * @param vector vector to be transformed.
          * @return transformed vector.
          */
      +  @Since("1.4.0")
         override def transform(vector: Vector): Vector = {
           require(vector.size == scalingVec.size,
             s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}")
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
      index c53475818395..e47d524b6162 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
      @@ -22,7 +22,7 @@ import java.lang.{Iterable => JavaIterable}
       import scala.collection.JavaConverters._
       import scala.collection.mutable
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.rdd.RDD
@@ -34,19 +34,26 @@ import org.apache.spark.util.Utils
        *
        * @param numFeatures number of features (default: 2^20^)
        */
      +@Since("1.1.0")
       @Experimental
       class HashingTF(val numFeatures: Int) extends Serializable {
       
+  /**
+   * Uses the default number of features (2^20^).
+   */
      +  @Since("1.1.0")
         def this() = this(1 << 20)
       
         /**
          * Returns the index of the input term.
          */
      +  @Since("1.1.0")
         def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures)
       
         /**
          * Transforms the input document into a sparse term frequency vector.
          */
      +  @Since("1.1.0")
         def transform(document: Iterable[_]): Vector = {
           val termFrequencies = mutable.HashMap.empty[Int, Double]
           document.foreach { term =>
      @@ -59,6 +65,7 @@ class HashingTF(val numFeatures: Int) extends Serializable {
         /**
          * Transforms the input document into a sparse term frequency vector (Java version).
          */
      +  @Since("1.1.0")
         def transform(document: JavaIterable[_]): Vector = {
           transform(document.asScala)
         }
      @@ -66,6 +73,7 @@ class HashingTF(val numFeatures: Int) extends Serializable {
         /**
          * Transforms the input document to term frequency vectors.
          */
      +  @Since("1.1.0")
         def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = {
           dataset.map(this.transform)
         }
      @@ -73,6 +81,7 @@ class HashingTF(val numFeatures: Int) extends Serializable {
         /**
          * Transforms the input document to term frequency vectors (Java version).
          */
      +  @Since("1.1.0")
         def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = {
           dataset.rdd.map(this.transform).toJavaRDD()
         }
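
A short sketch of hashing tokenized documents into term-frequency vectors (toy documents, `sc` from spark-shell):

```scala
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

val documents: RDD[Seq[String]] = sc.parallelize(Seq(
  Seq("spark", "streaming", "spark"),
  Seq("mllib", "feature", "hashing")))

val hashingTF = new HashingTF()              // default: 2^20 features
val tf: RDD[Vector] = hashingTF.transform(documents)
tf.collect().foreach(println)
```
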
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
      index 3fab7ea79bef..68078ccfa3d6 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature
       
       import breeze.linalg.{DenseVector => BDV}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
       import org.apache.spark.rdd.RDD
      @@ -37,9 +37,11 @@ import org.apache.spark.rdd.RDD
        * @param minDocFreq minimum of documents in which a term
        *                   should appear for filtering
        */
      +@Since("1.1.0")
       @Experimental
      -class IDF(val minDocFreq: Int) {
      +class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) {
       
      +  @Since("1.1.0")
         def this() = this(0)
       
         // TODO: Allow different IDF formulations.
      @@ -48,6 +50,7 @@ class IDF(val minDocFreq: Int) {
          * Computes the inverse document frequency.
          * @param dataset an RDD of term frequency vectors
          */
      +  @Since("1.1.0")
         def fit(dataset: RDD[Vector]): IDFModel = {
           val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(
                 minDocFreq = minDocFreq))(
      @@ -61,6 +64,7 @@ class IDF(val minDocFreq: Int) {
          * Computes the inverse document frequency.
          * @param dataset a JavaRDD of term frequency vectors
          */
      +  @Since("1.1.0")
         def fit(dataset: JavaRDD[Vector]): IDFModel = {
           fit(dataset.rdd)
         }
      @@ -159,7 +163,8 @@ private object IDF {
        * Represents an IDF model that can transform term frequency vectors.
        */
       @Experimental
      -class IDFModel private[spark] (val idf: Vector) extends Serializable {
      +@Since("1.1.0")
      +class IDFModel private[spark] (@Since("1.1.0") val idf: Vector) extends Serializable {
       
         /**
          * Transforms term frequency (TF) vectors to TF-IDF vectors.
      @@ -171,6 +176,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable {
          * @param dataset an RDD of term frequency vectors
          * @return an RDD of TF-IDF vectors
          */
      +  @Since("1.1.0")
         def transform(dataset: RDD[Vector]): RDD[Vector] = {
           val bcIdf = dataset.context.broadcast(idf)
           dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v)))
      @@ -182,6 +188,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable {
          * @param v a term frequency vector
          * @return a TF-IDF vector
          */
      +  @Since("1.3.0")
         def transform(v: Vector): Vector = IDFModel.transform(idf, v)
       
         /**
      @@ -189,6 +196,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable {
          * @param dataset a JavaRDD of term frequency vectors
          * @return a JavaRDD of TF-IDF vectors
          */
      +  @Since("1.1.0")
         def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
           transform(dataset.rdd).toJavaRDD()
         }
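
Continuing the term-frequency sketch above, a hedged example of fitting IDF and rescaling to TF-IDF; `tf` is assumed to be an RDD[Vector] of term-frequency vectors such as HashingTF's output.

```scala
import org.apache.spark.mllib.feature.IDF

tf.cache()                                      // fit() and transform() both traverse tf
val idfModel = new IDF(minDocFreq = 1).fit(tf)  // keep terms appearing in >= 1 document
val tfidf = idfModel.transform(tf)
tfidf.collect().foreach(println)
```
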
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
      index 32848e039eb8..8d5a22520d6b 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.feature
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
       
       /**
      @@ -31,9 +31,11 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors
        *
        * @param p Normalization in L^p^ space, p = 2 by default.
        */
      +@Since("1.1.0")
       @Experimental
      -class Normalizer(p: Double) extends VectorTransformer {
      +class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer {
       
      +  @Since("1.1.0")
         def this() = this(2)
       
         require(p >= 1.0)
      @@ -44,6 +46,7 @@ class Normalizer(p: Double) extends VectorTransformer {
          * @param vector vector to be normalized.
          * @return normalized vector. If the norm of the input is zero, it will return the input vector.
          */
      +  @Since("1.1.0")
         override def transform(vector: Vector): Vector = {
           val norm = Vectors.norm(vector, p)
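
A quick illustration of the p-norm scaling described above (values invented):

```scala
import org.apache.spark.mllib.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

val v = Vectors.dense(1.0, 2.0, 2.0)
println(new Normalizer().transform(v))        // L^2 by default: [1/3, 2/3, 2/3]
println(new Normalizer(p = 1.0).transform(v)) // L^1: [0.2, 0.4, 0.4]
```
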
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
      index 4e01e402b428..ecb3c1e6c1c8 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
      @@ -17,6 +17,7 @@
       
       package org.apache.spark.mllib.feature
       
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg._
       import org.apache.spark.mllib.linalg.distributed.RowMatrix
      @@ -27,7 +28,8 @@ import org.apache.spark.rdd.RDD
        *
        * @param k number of principal components
        */
      -class PCA(val k: Int) {
      +@Since("1.4.0")
      +class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
         require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k")
       
         /**
      @@ -35,6 +37,7 @@ class PCA(val k: Int) {
          *
          * @param sources source vectors
          */
      +  @Since("1.4.0")
         def fit(sources: RDD[Vector]): PCAModel = {
           require(k <= sources.first().size,
             s"source vector size is ${sources.first().size} must be greater than k=$k")
      @@ -58,7 +61,10 @@ class PCA(val k: Int) {
           new PCAModel(k, pc)
         }
       
      -  /** Java-friendly version of [[fit()]] */
      +  /**
      +   * Java-friendly version of [[fit()]]
      +   */
      +  @Since("1.4.0")
         def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd)
       }
       
      @@ -68,7 +74,10 @@ class PCA(val k: Int) {
        * @param k number of principal components.
        * @param pc a principal components Matrix. Each column is one principal component.
        */
      -class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
      +@Since("1.4.0")
      +class PCAModel private[spark] (
      +    @Since("1.4.0") val k: Int,
      +    @Since("1.4.0") val pc: DenseMatrix) extends VectorTransformer {
         /**
          * Transform a vector by computed Principal Components.
          *
      @@ -76,6 +85,7 @@ class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTr
          *               Vector must be the same length as the source vectors given to [[PCA.fit()]].
          * @return transformed vector. Vector will be of length k.
          */
      +  @Since("1.4.0")
         override def transform(vector: Vector): Vector = {
           vector match {
             case dv: DenseVector =>
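
A sketch of fitting PCA and projecting vectors onto the top k components (toy 3-dimensional data, `sc` from spark-shell):

```scala
import org.apache.spark.mllib.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 7.0),
  Vectors.dense(2.0, 0.0, 3.0),
  Vectors.dense(4.0, 0.0, 1.0)))

val pcaModel = new PCA(k = 2).fit(data)      // k must not exceed the input dimension
val projected = data.map(pcaModel.transform) // each output vector has length k
projected.collect().foreach(println)
```
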
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
      index c73b8f258060..f018b453bae7 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
      @@ -18,7 +18,7 @@
       package org.apache.spark.mllib.feature
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
       import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
       import org.apache.spark.rdd.RDD
      @@ -32,9 +32,11 @@ import org.apache.spark.rdd.RDD
        *                 dense output, so this does not work on sparse input and will raise an exception.
        * @param withStd True by default. Scales the data to unit standard deviation.
        */
      +@Since("1.1.0")
       @Experimental
      -class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
      +class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) extends Logging {
       
      +  @Since("1.1.0")
         def this() = this(false, true)
       
         if (!(withMean || withStd)) {
      @@ -47,6 +49,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
          * @param data The data used to compute the mean and variance to build the transformation model.
          * @return a StandardScalarModel
          */
      +  @Since("1.1.0")
         def fit(data: RDD[Vector]): StandardScalerModel = {
           // TODO: skip computation if both withMean and withStd are false
           val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
@@ -69,13 +72,18 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
        * @param withStd whether to scale the data to have unit standard deviation
        * @param withMean whether to center the data before scaling
        */
      +@Since("1.1.0")
       @Experimental
      -class StandardScalerModel (
      -    val std: Vector,
      -    val mean: Vector,
      -    var withStd: Boolean,
      -    var withMean: Boolean) extends VectorTransformer {
      +class StandardScalerModel @Since("1.3.0") (
      +    @Since("1.3.0") val std: Vector,
      +    @Since("1.1.0") val mean: Vector,
      +    @Since("1.3.0") var withStd: Boolean,
      +    @Since("1.3.0") var withMean: Boolean) extends VectorTransformer {
       
+  /**
+   * Sets `withStd` and `withMean` based on whether `std` and `mean` are non-null.
+   */
      +  @Since("1.3.0")
         def this(std: Vector, mean: Vector) {
           this(std, mean, withStd = std != null, withMean = mean != null)
           require(this.withStd || this.withMean,
      @@ -86,8 +93,10 @@ class StandardScalerModel (
           }
         }
       
      +  @Since("1.3.0")
         def this(std: Vector) = this(std, null)
       
      +  @Since("1.3.0")
         @DeveloperApi
         def setWithMean(withMean: Boolean): this.type = {
           require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null")
      @@ -95,6 +104,7 @@ class StandardScalerModel (
           this
         }
       
      +  @Since("1.3.0")
         @DeveloperApi
         def setWithStd(withStd: Boolean): this.type = {
           require(!(withStd && this.std == null),
      @@ -115,6 +125,7 @@ class StandardScalerModel (
          * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
          *         for the column with zero std.
          */
      +  @Since("1.1.0")
         override def transform(vector: Vector): Vector = {
           require(mean.size == vector.size)
           if (withMean) {
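
Finally, a sketch of fitting and applying the scaler (toy dense data, `sc` from spark-shell); note that `withMean = true` produces dense output.

```scala
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0),
  Vectors.dense(2.0, 20.0),
  Vectors.dense(3.0, 30.0)))

val scalerModel = new StandardScaler(withMean = true, withStd = true).fit(data)
val scaled = data.map(scalerModel.transform)
scaled.collect().foreach(println)
```
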
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
      index 7358c1c84f79..5778fd1d0925 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.feature
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.rdd.RDD
      @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
        * :: DeveloperApi ::
        * Trait for transformation of a vector
        */
      +@Since("1.1.0")
       @DeveloperApi
       trait VectorTransformer extends Serializable {
       
      @@ -35,6 +36,7 @@ trait VectorTransformer extends Serializable {
          * @param vector vector to be transformed.
          * @return transformed vector.
          */
      +  @Since("1.1.0")
         def transform(vector: Vector): Vector
       
         /**
      @@ -43,6 +45,7 @@ trait VectorTransformer extends Serializable {
          * @param data RDD[Vector] to be transformed.
          * @return transformed RDD[Vector].
          */
      +  @Since("1.1.0")
         def transform(data: RDD[Vector]): RDD[Vector] = {
           // Later in #1498 , all RDD objects are sent via broadcasting instead of akka.
           // So it should be no longer necessary to explicitly broadcast `this` object.
      @@ -55,6 +58,7 @@ trait VectorTransformer extends Serializable {
          * @param data JavaRDD[Vector] to be transformed.
          * @return transformed JavaRDD[Vector].
          */
      +  @Since("1.1.0")
         def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = {
           transform(data.rdd)
         }
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
      index f087d06d2a46..58857c338f54 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
      @@ -32,7 +32,7 @@ import org.json4s.jackson.JsonMethods._
       import org.apache.spark.Logging
       import org.apache.spark.SparkContext
       import org.apache.spark.SparkContext._
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector}
       import org.apache.spark.mllib.util.{Loader, Saveable}
      @@ -70,6 +70,7 @@ private case class VocabWord(
        * and
        * Distributed Representations of Words and Phrases and their Compositionality.
        */
      +@Since("1.1.0")
       @Experimental
       class Word2Vec extends Serializable with Logging {
       
      @@ -83,6 +84,7 @@ class Word2Vec extends Serializable with Logging {
         /**
          * Sets vector size (default: 100).
          */
      +  @Since("1.1.0")
         def setVectorSize(vectorSize: Int): this.type = {
           this.vectorSize = vectorSize
           this
      @@ -91,6 +93,7 @@ class Word2Vec extends Serializable with Logging {
         /**
          * Sets initial learning rate (default: 0.025).
          */
      +  @Since("1.1.0")
         def setLearningRate(learningRate: Double): this.type = {
           this.learningRate = learningRate
           this
      @@ -99,6 +102,7 @@ class Word2Vec extends Serializable with Logging {
         /**
          * Sets number of partitions (default: 1). Use a small number for accuracy.
          */
      +  @Since("1.1.0")
         def setNumPartitions(numPartitions: Int): this.type = {
           require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
           this.numPartitions = numPartitions
      @@ -109,6 +113,7 @@ class Word2Vec extends Serializable with Logging {
          * Sets number of iterations (default: 1), which should be smaller than or equal to number of
          * partitions.
          */
      +  @Since("1.1.0")
         def setNumIterations(numIterations: Int): this.type = {
           this.numIterations = numIterations
           this
      @@ -117,6 +122,7 @@ class Word2Vec extends Serializable with Logging {
         /**
          * Sets random seed (default: a random long integer).
          */
      +  @Since("1.1.0")
         def setSeed(seed: Long): this.type = {
           this.seed = seed
           this
      @@ -126,6 +132,7 @@ class Word2Vec extends Serializable with Logging {
          * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
          * model's vocabulary (default: 5).
          */
      +  @Since("1.3.0")
         def setMinCount(minCount: Int): this.type = {
           this.minCount = minCount
           this
      @@ -263,6 +270,7 @@ class Word2Vec extends Serializable with Logging {
          * @param dataset an RDD of words
          * @return a Word2VecModel
          */
      +  @Since("1.1.0")
         def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
       
           val words = dataset.flatMap(x => x)
      @@ -403,17 +411,8 @@ class Word2Vec extends Serializable with Logging {
           }
           newSentences.unpersist()
       
      -    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
      -    var i = 0
      -    while (i < vocabSize) {
      -      val word = bcVocab.value(i).word
      -      val vector = new Array[Float](vectorSize)
      -      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
      -      word2VecMap += word -> vector
      -      i += 1
      -    }
      -
      -    new Word2VecModel(word2VecMap.toMap)
      +    val wordArray = vocab.map(_.word)
      +    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
         }
       
         /**
      @@ -421,6 +420,7 @@ class Word2Vec extends Serializable with Logging {
          * @param dataset a JavaRDD of words
          * @return a Word2VecModel
          */
      +  @Since("1.1.0")
         def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
           fit(dataset.rdd.map(_.asScala))
         }
      @@ -429,38 +429,44 @@ class Word2Vec extends Serializable with Logging {
       /**
        * :: Experimental ::
        * Word2Vec model
      + * @param wordIndex maps each word to an index, which can retrieve the corresponding
      + *                  vector from wordVectors
      + * @param wordVectors array of length numWords * vectorSize, vector corresponding
      + *                    to the word mapped with index i can be retrieved by the slice
      + *                    (i * vectorSize, i * vectorSize + vectorSize)
        */
       @Experimental
      -class Word2VecModel private[spark] (
      -    model: Map[String, Array[Float]]) extends Serializable with Saveable {
      +@Since("1.1.0")
      +class Word2VecModel private[mllib] (
      +    private val wordIndex: Map[String, Int],
      +    private val wordVectors: Array[Float]) extends Serializable with Saveable {
       
      -  // wordList: Ordered list of words obtained from model.
      -  private val wordList: Array[String] = model.keys.toArray
      -
      -  // wordIndex: Maps each word to an index, which can retrieve the corresponding
      -  //            vector from wordVectors (see below).
      -  private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
      -
      -  // vectorSize: Dimension of each word's vector.
      -  private val vectorSize = model.head._2.size
         private val numWords = wordIndex.size
      +  // vectorSize: Dimension of each word's vector.
      +  private val vectorSize = wordVectors.length / numWords
      +
      +  // wordList: Ordered list of words obtained from wordIndex.
      +  private val wordList: Array[String] = {
      +    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
      +    wl.toArray
      +  }
       
      -  // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
      -  //              mapped with index i can be retrieved by the slice
      -  //              (ind * vectorSize, ind * vectorSize + vectorSize)
         // wordVecNorms: Array of length numWords, each value being the Euclidean norm
         //               of the wordVector.
      -  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
      -    val wordVectors = new Array[Float](vectorSize * numWords)
      +  private val wordVecNorms: Array[Double] = {
           val wordVecNorms = new Array[Double](numWords)
           var i = 0
           while (i < numWords) {
      -      val vec = model.get(wordList(i)).get
      -      Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
      +      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
             wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
             i += 1
           }
      -    (wordVectors, wordVecNorms)
      +    wordVecNorms
      +  }
      +
      +  @Since("1.5.0")
      +  def this(model: Map[String, Array[Float]]) = {
      +    this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
         }
       
         private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
      @@ -474,6 +480,7 @@ class Word2VecModel private[spark] (
       
         override protected def formatVersion = "1.0"
       
      +  @Since("1.4.0")
         def save(sc: SparkContext, path: String): Unit = {
           Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
         }
      @@ -483,9 +490,11 @@ class Word2VecModel private[spark] (
          * @param word a word
          * @return vector representation of word
          */
      +  @Since("1.1.0")
         def transform(word: String): Vector = {
      -    model.get(word) match {
      -      case Some(vec) =>
      +    wordIndex.get(word) match {
      +      case Some(ind) =>
      +        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
               Vectors.dense(vec.map(_.toDouble))
             case None =>
               throw new IllegalStateException(s"$word not in vocabulary")
      @@ -498,6 +507,7 @@ class Word2VecModel private[spark] (
          * @param num number of synonyms to find
          * @return array of (word, cosineSimilarity)
          */
      +  @Since("1.1.0")
         def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
           val vector = transform(word)
           findSynonyms(vector, num)
      @@ -509,9 +519,10 @@ class Word2VecModel private[spark] (
          * @param num number of synonyms to find
          * @return array of (word, cosineSimilarity)
          */
      +  @Since("1.1.0")
         def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
           require(num > 0, "Number of similar words should > 0")
      -
      +    // TODO: optimize top-k
           val fVector = vector.toArray.map(_.toFloat)
           val cosineVec = Array.fill[Float](numWords)(0)
           val alpha: Float = 1
      @@ -521,13 +532,13 @@ class Word2VecModel private[spark] (
             "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
       
           // Need not divide with the norm of the given vector since it is constant.
      -    val updatedCosines = new Array[Double](numWords)
      +    val cosVec = cosineVec.map(_.toDouble)
           var ind = 0
           while (ind < numWords) {
      -      updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
      +      cosVec(ind) /= wordVecNorms(ind)
             ind += 1
           }
      -    wordList.zip(updatedCosines)
      +    wordList.zip(cosVec)
             .toSeq
             .sortBy(- _._2)
             .take(num + 1)
      @@ -538,6 +549,7 @@ class Word2VecModel private[spark] (
         /**
          * Returns a map of words to their vector representations.
          */
      +  @Since("1.2.0")
         def getVectors: Map[String, Array[Float]] = {
           wordIndex.map { case (word, ind) =>
             (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
      @@ -545,9 +557,27 @@ class Word2VecModel private[spark] (
         }
       }
       
      +@Since("1.4.0")
       @Experimental
       object Word2VecModel extends Loader[Word2VecModel] {
       
      +  private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
      +    model.keys.zipWithIndex.toMap
      +  }
      +
      +  private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
      +    require(model.nonEmpty, "Word2VecMap should be non-empty")
      +    val (vectorSize, numWords) = (model.head._2.size, model.size)
      +    val wordList = model.keys.toArray
      +    val wordVectors = new Array[Float](vectorSize * numWords)
      +    var i = 0
      +    while (i < numWords) {
      +      Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
      +      i += 1
      +    }
      +    wordVectors
      +  }
      +
         private object SaveLoadV1_0 {
       
           val formatVersionV1_0 = "1.0"
      @@ -560,12 +590,10 @@ object Word2VecModel extends Loader[Word2VecModel] {
             val dataPath = Loader.dataPath(path)
             val sqlContext = new SQLContext(sc)
             val dataFrame = sqlContext.read.parquet(dataPath)
      -
      -      val dataArray = dataFrame.select("word", "vector").collect()
      -
             // Check schema explicitly since erasure makes it hard to use match-case for checking.
             Loader.checkSchema[Data](dataFrame.schema)
       
      +      val dataArray = dataFrame.select("word", "vector").collect()
             val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
             new Word2VecModel(word2VecMap)
           }
      @@ -587,6 +615,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
           }
         }
       
      +  @Since("1.4.0")
         override def load(sc: SparkContext, path: String): Word2VecModel = {
       
           val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
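
A usage sketch for the refactored Word2VecModel, assuming an existing SparkContext `sc` and a placeholder path to a whitespace-tokenized corpus; the public behaviour (`fit`, `findSynonyms`, `getVectors`) is unchanged, and the new `@Since("1.5.0")` auxiliary constructor still accepts a word-to-vector map:

    import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}

    // "corpus.txt" is a placeholder path to a whitespace-tokenized text file.
    val input = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
    val model = new Word2Vec().setVectorSize(100).setMinCount(5).fit(input)
    // The query word must be in the model's vocabulary, otherwise an exception is thrown.
    model.findSynonyms("spark", 5).foreach { case (word, sim) => println(s"$word: $sim") }
    // Rebuild a model from a word -> vector map via the new auxiliary constructor.
    val rebuilt = new Word2VecModel(model.getVectors)

Internally the refactoring stores all vectors in a single flat `Array[Float]` plus a word-to-index map, which avoids materializing a per-word map during training.
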
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
      new file mode 100644
      index 000000000000..95c688c86a7e
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
      @@ -0,0 +1,146 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +package org.apache.spark.mllib.fpm
      +
      +import scala.collection.JavaConverters._
      +import scala.reflect.ClassTag
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.annotation.{Experimental, Since}
      +import org.apache.spark.api.java.JavaRDD
      +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
      +import org.apache.spark.mllib.fpm.AssociationRules.Rule
      +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
      +import org.apache.spark.rdd.RDD
      +
      +/**
      + * :: Experimental ::
      + *
       + * Generates association rules from an [[RDD[FreqItemset[Item]]]]. This method only generates
      + * association rules which have a single item as the consequent.
      + *
      + */
      +@Since("1.5.0")
      +@Experimental
      +class AssociationRules private[fpm] (
      +    private var minConfidence: Double) extends Logging with Serializable {
      +
      +  /**
      +   * Constructs a default instance with default parameters {minConfidence = 0.8}.
      +   */
      +  @Since("1.5.0")
      +  def this() = this(0.8)
      +
      +  /**
      +   * Sets the minimal confidence (default: `0.8`).
      +   */
      +  @Since("1.5.0")
      +  def setMinConfidence(minConfidence: Double): this.type = {
      +    require(minConfidence >= 0.0 && minConfidence <= 1.0)
      +    this.minConfidence = minConfidence
      +    this
      +  }
      +
      +  /**
      +   * Computes the association rules with confidence above [[minConfidence]].
      +   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
       +   * @return an [[RDD[Rule[Item]]]] containing the association rules.
      +   *
      +   */
      +  @Since("1.5.0")
      +  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
      +    // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
      +    val candidates = freqItemsets.flatMap { itemset =>
      +      val items = itemset.items
      +      items.flatMap { item =>
      +        items.partition(_ == item) match {
      +          case (consequent, antecedent) if !antecedent.isEmpty =>
      +            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
      +          case _ => None
      +        }
      +      }
      +    }
      +
      +    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
      +    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
       +      .map { case (antecedent, ((consequent, freqUnion), freqAntecedent)) =>
       +      new Rule(antecedent.toArray, consequent.toArray, freqUnion, freqAntecedent)
      +    }.filter(_.confidence >= minConfidence)
      +  }
      +
      +  /** Java-friendly version of [[run]]. */
      +  @Since("1.5.0")
      +  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
      +    val tag = fakeClassTag[Item]
      +    run(freqItemsets.rdd)(tag)
      +  }
      +}
      +
      +@Since("1.5.0")
      +object AssociationRules {
      +
      +  /**
      +   * :: Experimental ::
      +   *
      +   * An association rule between sets of items.
      +   * @param antecedent hypotheses of the rule. Java users should call [[Rule#javaAntecedent]]
      +   *                   instead.
      +   * @param consequent conclusion of the rule. Java users should call [[Rule#javaConsequent]]
      +   *                   instead.
      +   * @tparam Item item type
      +   *
      +   */
      +  @Since("1.5.0")
      +  @Experimental
      +  class Rule[Item] private[fpm] (
      +      @Since("1.5.0") val antecedent: Array[Item],
      +      @Since("1.5.0") val consequent: Array[Item],
      +      freqUnion: Double,
      +      freqAntecedent: Double) extends Serializable {
      +
      +    /**
      +     * Returns the confidence of the rule.
      +     *
      +     */
      +    @Since("1.5.0")
      +    def confidence: Double = freqUnion.toDouble / freqAntecedent
      +
      +    require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
      +      val sharedItems = antecedent.toSet.intersect(consequent.toSet)
      +      s"A valid association rule must have disjoint antecedent and " +
      +        s"consequent but ${sharedItems} is present in both."
      +    })
      +
      +    /**
      +     * Returns antecedent in a Java List.
      +     *
      +     */
      +    @Since("1.5.0")
      +    def javaAntecedent: java.util.List[Item] = {
      +      antecedent.toList.asJava
      +    }
      +
      +    /**
      +     * Returns consequent in a Java List.
      +     *
      +     */
      +    @Since("1.5.0")
      +    def javaConsequent: java.util.List[Item] = {
      +      consequent.toList.asJava
      +    }
      +  }
      +}
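
A minimal sketch of the new AssociationRules API, assuming an existing SparkContext `sc`; with the frequencies below, the rule {a} => {b} has confidence 12/15 = 0.8 and therefore passes the chosen threshold:

    import org.apache.spark.mllib.fpm.AssociationRules
    import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

    val freqItemsets = sc.parallelize(Seq(
      new FreqItemset(Array("a"), 15L),
      new FreqItemset(Array("b"), 35L),
      new FreqItemset(Array("a", "b"), 12L)))
    val rules = new AssociationRules().setMinConfidence(0.8).run(freqItemsets)
    rules.collect().foreach { rule =>
      // confidence = freq(antecedent union consequent) / freq(antecedent)
      println(s"${rule.antecedent.mkString(",")} => ${rule.consequent.mkString(",")}: ${rule.confidence}")
    }
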
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
      index efa8459d3cdb..aea5c4f8a8a7 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
      @@ -25,10 +25,10 @@ import scala.collection.JavaConverters._
       import scala.reflect.ClassTag
       
       import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
      -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
      +import org.apache.spark.mllib.fpm.FPGrowth._
       import org.apache.spark.rdd.RDD
       import org.apache.spark.storage.StorageLevel
       
      @@ -38,9 +38,22 @@ import org.apache.spark.storage.StorageLevel
        * Model trained by [[FPGrowth]], which holds frequent itemsets.
        * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]]
        * @tparam Item item type
      + *
        */
      +@Since("1.3.0")
       @Experimental
      -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
      +class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
      +    @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
      +  /**
      +   * Generates association rules for the [[Item]]s in [[freqItemsets]].
      +   * @param confidence minimal confidence of the rules produced
      +   */
      +  @Since("1.5.0")
      +  def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
      +    val associationRules = new AssociationRules(confidence)
      +    associationRules.run(freqItemsets)
      +  }
      +}
       
       /**
        * :: Experimental ::
      @@ -58,7 +71,9 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex
        *
        * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning
        *       (Wikipedia)]]
      + *
        */
      +@Since("1.3.0")
       @Experimental
       class FPGrowth private (
           private var minSupport: Double,
      @@ -67,12 +82,16 @@ class FPGrowth private (
         /**
          * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same
          * as the input data}.
      +   *
          */
      +  @Since("1.3.0")
         def this() = this(0.3, -1)
       
         /**
          * Sets the minimal support level (default: `0.3`).
      +   *
          */
      +  @Since("1.3.0")
         def setMinSupport(minSupport: Double): this.type = {
           this.minSupport = minSupport
           this
      @@ -80,7 +99,9 @@ class FPGrowth private (
       
         /**
          * Sets the number of partitions used by parallel FP-growth (default: same as input data).
      +   *
          */
      +  @Since("1.3.0")
         def setNumPartitions(numPartitions: Int): this.type = {
           this.numPartitions = numPartitions
           this
      @@ -90,7 +111,9 @@ class FPGrowth private (
          * Computes an FP-Growth model that contains frequent itemsets.
          * @param data input data set, each element contains a transaction
          * @return an [[FPGrowthModel]]
      +   *
          */
      +  @Since("1.3.0")
         def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
           if (data.getStorageLevel == StorageLevel.NONE) {
             logWarning("Input data is not cached.")
      @@ -104,6 +127,8 @@ class FPGrowth private (
           new FPGrowthModel(freqItemsets)
         }
       
      +  /** Java-friendly version of [[run]]. */
      +  @Since("1.3.0")
         def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
           implicit val tag = fakeClassTag[Item]
           run(data.rdd.map(_.asScala.toArray))
      @@ -190,7 +215,9 @@ class FPGrowth private (
       
       /**
        * :: Experimental ::
      + *
        */
      +@Since("1.3.0")
       @Experimental
       object FPGrowth {
       
      @@ -199,12 +226,18 @@ object FPGrowth {
          * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
          * @param freq frequency
          * @tparam Item item type
      +   *
          */
      -  class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {
      +  @Since("1.3.0")
      +  class FreqItemset[Item] @Since("1.3.0") (
      +      @Since("1.3.0") val items: Array[Item],
      +      @Since("1.3.0") val freq: Long) extends Serializable {
       
           /**
            * Returns items in a Java List.
      +     *
            */
      +    @Since("1.3.0")
           def javaItems: java.util.List[Item] = {
             items.toList.asJava
           }
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
      new file mode 100644
      index 000000000000..3ea10779a183
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
      @@ -0,0 +1,110 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.fpm
      +
      +import scala.collection.mutable
      +
      +import org.apache.spark.Logging
      +
      +/**
      + * Calculate all patterns of a projected database in local mode.
      + *
      + * @param minCount minimal count for a frequent pattern
      + * @param maxPatternLength max pattern length for a frequent pattern
      + */
      +private[fpm] class LocalPrefixSpan(
      +    val minCount: Long,
      +    val maxPatternLength: Int) extends Logging with Serializable {
      +  import PrefixSpan.Postfix
      +  import LocalPrefixSpan.ReversedPrefix
      +
      +  /**
      +   * Generates frequent patterns on the input array of postfixes.
      +   * @param postfixes an array of postfixes
      +   * @return an iterator of (frequent pattern, count)
      +   */
      +  def run(postfixes: Array[Postfix]): Iterator[(Array[Int], Long)] = {
      +    genFreqPatterns(ReversedPrefix.empty, postfixes).map { case (prefix, count) =>
      +      (prefix.toSequence, count)
      +    }
      +  }
      +
      +  /**
      +   * Recursively generates frequent patterns.
      +   * @param prefix current prefix
      +   * @param postfixes projected postfixes w.r.t. the prefix
      +   * @return an iterator of (prefix, count)
      +   */
      +  private def genFreqPatterns(
      +      prefix: ReversedPrefix,
      +      postfixes: Array[Postfix]): Iterator[(ReversedPrefix, Long)] = {
      +    if (maxPatternLength == prefix.length || postfixes.length < minCount) {
      +      return Iterator.empty
      +    }
      +    // find frequent items
      +    val counts = mutable.Map.empty[Int, Long].withDefaultValue(0)
      +    postfixes.foreach { postfix =>
      +      postfix.genPrefixItems.foreach { case (x, _) =>
      +        counts(x) += 1L
      +      }
      +    }
      +    val freqItems = counts.toSeq.filter { case (_, count) =>
      +      count >= minCount
      +    }.sorted
      +    // project and recursively call genFreqPatterns
      +    freqItems.toIterator.flatMap { case (item, count) =>
      +      val newPrefix = prefix :+ item
      +      Iterator.single((newPrefix, count)) ++ {
      +        val projected = postfixes.map(_.project(item)).filter(_.nonEmpty)
      +        genFreqPatterns(newPrefix, projected)
      +      }
      +    }
      +  }
      +}
      +
      +private object LocalPrefixSpan {
      +
      +  /**
      +   * Represents a prefix stored as a list in reversed order.
      +   * @param items items in the prefix in reversed order
      +   * @param length length of the prefix, not counting delimiters
      +   */
      +  class ReversedPrefix private (val items: List[Int], val length: Int) extends Serializable {
      +    /**
      +     * Expands the prefix by one item.
      +     */
      +    def :+(item: Int): ReversedPrefix = {
      +      require(item != 0)
      +      if (item < 0) {
      +        new ReversedPrefix(-item :: items, length + 1)
      +      } else {
      +        new ReversedPrefix(item :: 0 :: items, length + 1)
      +      }
      +    }
      +
      +    /**
      +     * Converts this prefix to a sequence.
      +     */
      +    def toSequence: Array[Int] = (0 :: items).toArray.reverse
      +  }
      +
      +  object ReversedPrefix {
      +    /** An empty prefix. */
      +    val empty: ReversedPrefix = new ReversedPrefix(List.empty, 0)
      +  }
      +}
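
`ReversedPrefix` is private[fpm], so it is not callable from user code; the standalone sketch below only mimics its encoding to show how a prefix grown with `:+` maps to the 0-delimited internal format (a positive item opens a new itemset, a negative item extends the last one):

    // Illustrative only: mirrors ReversedPrefix.:+ and toSequence.
    var items: List[Int] = Nil
    def append(item: Int): Unit = {
      items = if (item < 0) -item :: items else item :: 0 :: items
    }
    append(1)   // start itemset (1)
    append(-2)  // extend it to (1 2)
    append(3)   // start itemset (3)
    val sequence = (0 :: items).toArray.reverse
    println(sequence.mkString("[", ", ", "]"))  // [0, 1, 2, 0, 3, 0], i.e. the sequence <(12)(3)>
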
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
      new file mode 100644
      index 000000000000..97916daa2e9a
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
      @@ -0,0 +1,569 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.fpm
      +
      +import java.{lang => jl, util => ju}
      +import java.util.concurrent.atomic.AtomicInteger
      +
      +import scala.collection.mutable
      +import scala.collection.JavaConverters._
      +import scala.reflect.ClassTag
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.annotation.{Experimental, Since}
      +import org.apache.spark.api.java.JavaRDD
      +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.storage.StorageLevel
      +
      +/**
      + * :: Experimental ::
      + *
      + * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
      + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
      + * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]).
      + *
       + * @param minSupport the minimal support level of a sequential pattern; any pattern that appears
       + *                   more than (minSupport * size-of-the-dataset) times will be output
       + * @param maxPatternLength the maximal length of a sequential pattern; any pattern longer than
       + *                         maxPatternLength will not be output
      + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal
      + *                           storage format) allowed in a projected database before local
      + *                           processing. If a projected database exceeds this size, another
      + *                           iteration of distributed prefix growth is run.
      + *
      + * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
      + *       (Wikipedia)]]
      + */
      +@Experimental
      +@Since("1.5.0")
      +class PrefixSpan private (
      +    private var minSupport: Double,
      +    private var maxPatternLength: Int,
      +    private var maxLocalProjDBSize: Long) extends Logging with Serializable {
      +  import PrefixSpan._
      +
      +  /**
      +   * Constructs a default instance with default parameters
      +   * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}.
      +   */
      +  @Since("1.5.0")
      +  def this() = this(0.1, 10, 32000000L)
      +
      +  /**
       +   * Gets the minimal support (i.e. the frequency of occurrence before a pattern is considered
      +   * frequent).
      +   */
      +  @Since("1.5.0")
      +  def getMinSupport: Double = minSupport
      +
      +  /**
      +   * Sets the minimal support level (default: `0.1`).
      +   */
      +  @Since("1.5.0")
      +  def setMinSupport(minSupport: Double): this.type = {
      +    require(minSupport >= 0 && minSupport <= 1,
      +      s"The minimum support value must be in [0, 1], but got $minSupport.")
      +    this.minSupport = minSupport
      +    this
      +  }
      +
      +  /**
       +   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
      +   */
      +  @Since("1.5.0")
      +  def getMaxPatternLength: Int = maxPatternLength
      +
      +  /**
      +   * Sets maximal pattern length (default: `10`).
      +   */
      +  @Since("1.5.0")
      +  def setMaxPatternLength(maxPatternLength: Int): this.type = {
      +    // TODO: support unbounded pattern length when maxPatternLength = 0
      +    require(maxPatternLength >= 1,
      +      s"The maximum pattern length value must be greater than 0, but got $maxPatternLength.")
      +    this.maxPatternLength = maxPatternLength
      +    this
      +  }
      +
      +  /**
      +   * Gets the maximum number of items allowed in a projected database before local processing.
      +   */
      +  @Since("1.5.0")
      +  def getMaxLocalProjDBSize: Long = maxLocalProjDBSize
      +
      +  /**
      +   * Sets the maximum number of items (including delimiters used in the internal storage format)
      +   * allowed in a projected database before local processing (default: `32000000L`).
      +   */
      +  @Since("1.5.0")
      +  def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = {
      +    require(maxLocalProjDBSize >= 0L,
      +      s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize")
      +    this.maxLocalProjDBSize = maxLocalProjDBSize
      +    this
      +  }
      +
      +  /**
      +   * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
      +   * @param data sequences of itemsets.
      +   * @return a [[PrefixSpanModel]] that contains the frequent patterns
      +   */
      +  @Since("1.5.0")
      +  def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = {
      +    if (data.getStorageLevel == StorageLevel.NONE) {
      +      logWarning("Input data is not cached.")
      +    }
      +
      +    val totalCount = data.count()
      +    logInfo(s"number of sequences: $totalCount")
      +    val minCount = math.ceil(minSupport * totalCount).toLong
      +    logInfo(s"minimum count for a frequent pattern: $minCount")
      +
      +    // Find frequent items.
      +    val freqItemAndCounts = data.flatMap { itemsets =>
      +        val uniqItems = mutable.Set.empty[Item]
      +        itemsets.foreach { _.foreach { item =>
      +          uniqItems += item
      +        }}
      +        uniqItems.toIterator.map((_, 1L))
      +      }.reduceByKey(_ + _)
      +      .filter { case (_, count) =>
      +        count >= minCount
      +      }.collect()
      +    val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1)
      +    logInfo(s"number of frequent items: ${freqItems.length}")
      +
      +    // Keep only frequent items from input sequences and convert them to internal storage.
      +    val itemToInt = freqItems.zipWithIndex.toMap
      +    val dataInternalRepr = data.flatMap { itemsets =>
      +      val allItems = mutable.ArrayBuilder.make[Int]
      +      var containsFreqItems = false
      +      allItems += 0
       +      itemsets.foreach { itemset =>
       +        val items = mutable.ArrayBuilder.make[Int]
       +        itemset.foreach { item =>
      +          if (itemToInt.contains(item)) {
      +            items += itemToInt(item) + 1 // using 1-indexing in internal format
      +          }
      +        }
      +        val result = items.result()
      +        if (result.nonEmpty) {
      +          containsFreqItems = true
      +          allItems ++= result.sorted
      +        }
      +        allItems += 0
      +      }
      +      if (containsFreqItems) {
      +        Iterator.single(allItems.result())
      +      } else {
      +        Iterator.empty
      +      }
      +    }.persist(StorageLevel.MEMORY_AND_DISK)
      +
      +    val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize)
      +
      +    def toPublicRepr(pattern: Array[Int]): Array[Array[Item]] = {
      +      val sequenceBuilder = mutable.ArrayBuilder.make[Array[Item]]
      +      val itemsetBuilder = mutable.ArrayBuilder.make[Item]
      +      val n = pattern.length
      +      var i = 1
      +      while (i < n) {
      +        val x = pattern(i)
      +        if (x == 0) {
      +          sequenceBuilder += itemsetBuilder.result()
      +          itemsetBuilder.clear()
      +        } else {
      +          itemsetBuilder += freqItems(x - 1) // using 1-indexing in internal format
      +        }
      +        i += 1
      +      }
      +      sequenceBuilder.result()
      +    }
      +
      +    val freqSequences = results.map { case (seq: Array[Int], count: Long) =>
      +      new FreqSequence(toPublicRepr(seq), count)
      +    }
      +    new PrefixSpanModel(freqSequences)
      +  }
      +
      +  /**
      +   * A Java-friendly version of [[run()]] that reads sequences from a [[JavaRDD]] and returns
      +   * frequent sequences in a [[PrefixSpanModel]].
      +   * @param data ordered sequences of itemsets stored as Java Iterable of Iterables
      +   * @tparam Item item type
      +   * @tparam Itemset itemset type, which is an Iterable of Items
      +   * @tparam Sequence sequence type, which is an Iterable of Itemsets
      +   * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns
      +   */
      +  @Since("1.5.0")
      +  def run[Item, Itemset <: jl.Iterable[Item], Sequence <: jl.Iterable[Itemset]](
      +      data: JavaRDD[Sequence]): PrefixSpanModel[Item] = {
      +    implicit val tag = fakeClassTag[Item]
      +    run(data.rdd.map(_.asScala.map(_.asScala.toArray).toArray))
      +  }
      +
      +}
      +
      +@Experimental
      +@Since("1.5.0")
      +object PrefixSpan extends Logging {
      +
      +  /**
      +   * Find the complete set of frequent sequential patterns in the input sequences.
      +   * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int],
      +   *             where each itemset is represented by a contiguous sequence of distinct and ordered
      +   *             positive integers. We use 0 as the delimiter at itemset boundaries, including the
      +   *             first and the last position.
       +   * @return an RDD of (frequent sequential pattern, count) pairs.
      +   * @see [[Postfix]]
      +   */
      +  private[fpm] def genFreqPatterns(
      +      data: RDD[Array[Int]],
      +      minCount: Long,
      +      maxPatternLength: Int,
      +      maxLocalProjDBSize: Long): RDD[(Array[Int], Long)] = {
      +    val sc = data.sparkContext
      +
      +    if (data.getStorageLevel == StorageLevel.NONE) {
      +      logWarning("Input data is not cached.")
      +    }
      +
      +    val postfixes = data.map(items => new Postfix(items))
      +
      +    // Local frequent patterns (prefixes) and their counts.
      +    val localFreqPatterns = mutable.ArrayBuffer.empty[(Array[Int], Long)]
      +    // Prefixes whose projected databases are small.
      +    val smallPrefixes = mutable.Map.empty[Int, Prefix]
      +    val emptyPrefix = Prefix.empty
      +    // Prefixes whose projected databases are large.
      +    var largePrefixes = mutable.Map(emptyPrefix.id -> emptyPrefix)
      +    while (largePrefixes.nonEmpty) {
      +      val numLocalFreqPatterns = localFreqPatterns.length
      +      logInfo(s"number of local frequent patterns: $numLocalFreqPatterns")
      +      if (numLocalFreqPatterns > 1000000) {
      +        logWarning(
      +          s"""
      +             | Collected $numLocalFreqPatterns local frequent patterns. You may want to consider:
      +             |   1. increase minSupport,
      +             |   2. decrease maxPatternLength,
      +             |   3. increase maxLocalProjDBSize.
      +           """.stripMargin)
      +      }
      +      logInfo(s"number of small prefixes: ${smallPrefixes.size}")
      +      logInfo(s"number of large prefixes: ${largePrefixes.size}")
      +      val largePrefixArray = largePrefixes.values.toArray
      +      val freqPrefixes = postfixes.flatMap { postfix =>
      +          largePrefixArray.flatMap { prefix =>
      +            postfix.project(prefix).genPrefixItems.map { case (item, postfixSize) =>
      +              ((prefix.id, item), (1L, postfixSize))
      +            }
      +          }
      +        }.reduceByKey { case ((c0, s0), (c1, s1)) =>
      +          (c0 + c1, s0 + s1)
      +        }.filter { case (_, (c, _)) => c >= minCount }
      +        .collect()
      +      val newLargePrefixes = mutable.Map.empty[Int, Prefix]
      +      freqPrefixes.foreach { case ((id, item), (count, projDBSize)) =>
      +        val newPrefix = largePrefixes(id) :+ item
      +        localFreqPatterns += ((newPrefix.items :+ 0, count))
      +        if (newPrefix.length < maxPatternLength) {
      +          if (projDBSize > maxLocalProjDBSize) {
      +            newLargePrefixes += newPrefix.id -> newPrefix
      +          } else {
      +            smallPrefixes += newPrefix.id -> newPrefix
      +          }
      +        }
      +      }
      +      largePrefixes = newLargePrefixes
      +    }
      +
      +    var freqPatterns = sc.parallelize(localFreqPatterns, 1)
      +
      +    val numSmallPrefixes = smallPrefixes.size
      +    logInfo(s"number of small prefixes for local processing: $numSmallPrefixes")
      +    if (numSmallPrefixes > 0) {
      +      // Switch to local processing.
      +      val bcSmallPrefixes = sc.broadcast(smallPrefixes)
      +      val distributedFreqPattern = postfixes.flatMap { postfix =>
      +        bcSmallPrefixes.value.values.map { prefix =>
      +          (prefix.id, postfix.project(prefix).compressed)
      +        }.filter(_._2.nonEmpty)
      +      }.groupByKey().flatMap { case (id, projPostfixes) =>
      +        val prefix = bcSmallPrefixes.value(id)
      +        val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length)
      +        // TODO: We collect projected postfixes into memory. We should also compare the performance
      +        // TODO: of keeping them on shuffle files.
      +        localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) =>
      +          (prefix.items ++ pattern, count)
      +        }
      +      }
      +      // Union local frequent patterns and distributed ones.
      +      freqPatterns = freqPatterns ++ distributedFreqPattern
      +    }
      +
      +    freqPatterns
      +  }
      +
      +  /**
      +   * Represents a prefix.
      +   * @param items items in this prefix, using the internal format
      +   * @param length length of this prefix, not counting 0
      +   */
      +  private[fpm] class Prefix private (val items: Array[Int], val length: Int) extends Serializable {
      +
      +    /** A unique id for this prefix. */
      +    val id: Int = Prefix.nextId
      +
      +    /** Expands this prefix by the input item. */
      +    def :+(item: Int): Prefix = {
      +      require(item != 0)
      +      if (item < 0) {
      +        new Prefix(items :+ -item, length + 1)
      +      } else {
      +        new Prefix(items ++ Array(0, item), length + 1)
      +      }
      +    }
      +  }
      +
      +  private[fpm] object Prefix {
      +    /** Internal counter to generate unique IDs. */
      +    private val counter: AtomicInteger = new AtomicInteger(-1)
      +
      +    /** Gets the next unique ID. */
      +    private def nextId: Int = counter.incrementAndGet()
      +
      +    /** An empty [[Prefix]] instance. */
      +    val empty: Prefix = new Prefix(Array.empty, 0)
      +  }
      +
      +  /**
      +   * An internal representation of a postfix from some projection.
       +   * We use one int array to store the items, which might also contain other items from the
      +   * original sequence.
      +   * Items are represented by positive integers, and items in each itemset must be distinct and
      +   * ordered.
       +   * We use 0 as the delimiter between itemsets.
      +   * For example, a sequence `<(12)(31)1>` is represented by `[0, 1, 2, 0, 1, 3, 0, 1, 0]`.
       +   * The postfix of this sequence w.r.t. the prefix `<1>` is `<(_2)(13)1>`.
      +   * We may reuse the original items array `[0, 1, 2, 0, 1, 3, 0, 1, 0]` to represent the postfix,
      +   * and mark the start index of the postfix, which is `2` in this example.
      +   * So the active items in this postfix are `[2, 0, 1, 3, 0, 1, 0]`.
      +   * We also remember the start indices of partial projections, the ones that split an itemset.
      +   * For example, another possible partial projection w.r.t. `<1>` is `<(_3)1>`.
       +   * These start indices are `[2, 5]` in this example.
      +   * This data structure makes it easier to do projections.
      +   *
      +   * @param items a sequence stored as `Array[Int]` containing this postfix
      +   * @param start the start index of this postfix in items
      +   * @param partialStarts start indices of possible partial projections, strictly increasing
      +   */
      +  private[fpm] class Postfix(
      +      val items: Array[Int],
      +      val start: Int = 0,
      +      val partialStarts: Array[Int] = Array.empty) extends Serializable {
      +
      +    require(items.last == 0, s"The last item in a postfix must be zero, but got ${items.last}.")
      +    if (partialStarts.nonEmpty) {
      +      require(partialStarts.head >= start,
      +        "The first partial start cannot be smaller than the start index," +
      +          s"but got partialStarts.head = ${partialStarts.head} < start = $start.")
      +    }
      +
      +    /**
      +     * Start index of the first full itemset contained in this postfix.
      +     */
      +    private[this] def fullStart: Int = {
      +      var i = start
      +      while (items(i) != 0) {
      +        i += 1
      +      }
      +      i
      +    }
      +
      +    /**
      +     * Generates length-1 prefix items of this postfix with the corresponding postfix sizes.
      +     * There are two types of prefix items:
      +     *   a) The item can be assembled to the last itemset of the prefix. For example,
      +     *      the postfix of `<(12)(123)>1` w.r.t. `<1>` is `<(_2)(123)1>`. The prefix items of this
      +     *      postfix can be assembled to `<1>` is `_2` and `_3`, resulting new prefixes `<(12)>` and
      +     *      `<(13)>`. We flip the sign in the output to indicate that this is a partial prefix item.
      +     *   b) The item can be appended to the prefix. Taking the same example above, the prefix items
      +     *      can be appended to `<1>` is `1`, `2`, and `3`, resulting new prefixes `<11>`, `<12>`,
      +     *      and `<13>`.
      +     * @return an iterator of (prefix item, corresponding postfix size). If the item is negative, it
      +     *         indicates a partial prefix item, which should be assembled to the last itemset of the
      +     *         current prefix. Otherwise, the item should be appended to the current prefix.
      +     */
      +    def genPrefixItems: Iterator[(Int, Long)] = {
      +      val n1 = items.length - 1
       +      // For each unique item (subject to sign) in this sequence, we output exactly one split.
      +      // TODO: use PrimitiveKeyOpenHashMap
      +      val prefixes = mutable.Map.empty[Int, Long]
      +      // a) items that can be assembled to the last itemset of the prefix
      +      partialStarts.foreach { start =>
      +        var i = start
      +        var x = -items(i)
      +        while (x != 0) {
      +          if (!prefixes.contains(x)) {
      +            prefixes(x) = n1 - i
      +          }
      +          i += 1
      +          x = -items(i)
      +        }
      +      }
      +      // b) items that can be appended to the prefix
      +      var i = fullStart
      +      while (i < n1) {
      +        val x = items(i)
      +        if (x != 0 && !prefixes.contains(x)) {
      +          prefixes(x) = n1 - i
      +        }
      +        i += 1
      +      }
      +      prefixes.toIterator
      +    }
      +
      +    /** Tests whether this postfix is non-empty. */
      +    def nonEmpty: Boolean = items.length > start + 1
      +
      +    /**
      +     * Projects this postfix with respect to the input prefix item.
      +     * @param prefix prefix item. If prefix is positive, we match items in any full itemset; if it
      +     *               is negative, we do partial projections.
      +     * @return the projected postfix
      +     */
      +    def project(prefix: Int): Postfix = {
      +      require(prefix != 0)
      +      val n1 = items.length - 1
      +      var matched = false
      +      var newStart = n1
      +      val newPartialStarts = mutable.ArrayBuilder.make[Int]
      +      if (prefix < 0) {
      +        // Search for partial projections.
      +        val target = -prefix
      +        partialStarts.foreach { start =>
      +          var i = start
      +          var x = items(i)
      +          while (x != target && x != 0) {
      +            i += 1
      +            x = items(i)
      +          }
      +          if (x == target) {
      +            i += 1
      +            if (!matched) {
      +              newStart = i
      +              matched = true
      +            }
      +            if (items(i) != 0) {
      +              newPartialStarts += i
      +            }
      +          }
      +        }
      +      } else {
      +        // Search for items in full itemsets.
       +        // Though the items are ordered within each itemset, itemsets should be small in practice.
      +        // So a sequential scan is sufficient here, compared to bisection search.
      +        val target = prefix
      +        var i = fullStart
      +        while (i < n1) {
      +          val x = items(i)
      +          if (x == target) {
      +            if (!matched) {
      +              newStart = i
      +              matched = true
      +            }
      +            if (items(i + 1) != 0) {
      +              newPartialStarts += i + 1
      +            }
      +          }
      +          i += 1
      +        }
      +      }
      +      new Postfix(items, newStart, newPartialStarts.result())
      +    }
      +
      +    /**
      +     * Projects this postfix with respect to the input prefix.
      +     */
      +    private def project(prefix: Array[Int]): Postfix = {
      +      var partial = true
      +      var cur = this
      +      var i = 0
      +      val np = prefix.length
      +      while (i < np && cur.nonEmpty) {
      +        val x = prefix(i)
      +        if (x == 0) {
      +          partial = false
      +        } else {
      +          if (partial) {
      +            cur = cur.project(-x)
      +          } else {
      +            cur = cur.project(x)
      +            partial = true
      +          }
      +        }
      +        i += 1
      +      }
      +      cur
      +    }
      +
      +    /**
      +     * Projects this postfix with respect to the input prefix.
      +     */
      +    def project(prefix: Prefix): Postfix = project(prefix.items)
      +
      +    /**
      +     * Returns the same sequence with compressed storage if possible.
      +     */
      +    def compressed: Postfix = {
      +      if (start > 0) {
      +        new Postfix(items.slice(start, items.length), 0, partialStarts.map(_ - start))
      +      } else {
      +        this
      +      }
      +    }
      +  }
      +
      +  /**
       +   * Represents a frequent sequence.
      +   * @param sequence a sequence of itemsets stored as an Array of Arrays
      +   * @param freq frequency
      +   * @tparam Item item type
      +   */
      +  @Since("1.5.0")
      +  class FreqSequence[Item] @Since("1.5.0") (
      +      @Since("1.5.0") val sequence: Array[Array[Item]],
      +      @Since("1.5.0") val freq: Long) extends Serializable {
      +    /**
      +     * Returns sequence as a Java List of lists for Java users.
      +     */
      +    @Since("1.5.0")
      +    def javaSequence: ju.List[ju.List[Item]] = sequence.map(_.toList.asJava).toList.asJava
      +  }
      +}
      +
      +/**
      + * Model fitted by [[PrefixSpan]]
      + * @param freqSequences frequent sequences
      + * @tparam Item item type
      + */
      +@Since("1.5.0")
      +class PrefixSpanModel[Item] @Since("1.5.0") (
      +    @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
      +  extends Serializable
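
An end-to-end usage sketch for the new PrefixSpan implementation, assuming an existing SparkContext `sc`; each input element is a sequence of itemsets (`Array[Array[Int]]` here), and caching is advisable because the input is scanned repeatedly during prefix growth:

    import org.apache.spark.mllib.fpm.PrefixSpan

    val sequences = sc.parallelize(Seq(
      Array(Array(1, 2), Array(3)),
      Array(Array(1), Array(3, 2), Array(1, 2)),
      Array(Array(1, 2), Array(5)),
      Array(Array(6))), 2).cache()
    val model = new PrefixSpan()
      .setMinSupport(0.5)
      .setMaxPatternLength(5)
      .run(sequences)
    model.freqSequences.collect().foreach { fs =>
      println(fs.sequence.map(_.mkString("(", " ", ")")).mkString("<", "", ">") + s", freq: ${fs.freq}")
    }
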
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
      new file mode 100644
      index 000000000000..72d3aabc9b1f
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
      @@ -0,0 +1,154 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.impl
      +
      +import scala.collection.mutable
      +
      +import org.apache.hadoop.fs.{Path, FileSystem}
      +
      +import org.apache.spark.{SparkContext, Logging}
      +import org.apache.spark.storage.StorageLevel
      +
      +
      +/**
      + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
      + * (such as Graphs and DataFrames).  In documentation, we use the phrase "Dataset" to refer to
      + * the distributed data type (RDD, Graph, etc.).
      + *
      + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
      + * as well as unpersisting and removing checkpoint files.
      + *
      + * Users should call update() when a new Dataset has been created,
      + * before the Dataset has been materialized.  After updating [[PeriodicCheckpointer]], users are
      + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
      + * occur.
      + *
      + * When update() is called, this does the following:
      + *  - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
      + *  - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
      + *  - If using checkpointing and the checkpoint interval has been reached,
      + *     - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
      + *     - Remove older checkpoints.
      + *
      + * WARNINGS:
      + *  - This class should NOT be copied (since copies may conflict on which Datasets should be
      + *    checkpointed).
      + *  - This class removes checkpoint files once later Datasets have been checkpointed.
      + *    However, references to the older Datasets will still return isCheckpointed = true.
      + *
      + * @param checkpointInterval  Datasets will be checkpointed at this interval
      + * @param sc  SparkContext for the Datasets given to this checkpointer
      + * @tparam T  Dataset type, such as RDD[Double]
      + */
      +private[mllib] abstract class PeriodicCheckpointer[T](
      +    val checkpointInterval: Int,
      +    val sc: SparkContext) extends Logging {
      +
      +  /** FIFO queue of past checkpointed Datasets */
      +  private val checkpointQueue = mutable.Queue[T]()
      +
      +  /** FIFO queue of past persisted Datasets */
      +  private val persistedQueue = mutable.Queue[T]()
      +
      +  /** Number of times [[update()]] has been called */
      +  private var updateCount = 0
      +
      +  /**
      +   * Update with a new Dataset. Handle persistence and checkpointing as needed.
      +   * Since this handles persistence and checkpointing, this should be called before the Dataset
      +   * has been materialized.
      +   *
      +   * @param newData  New Dataset created from previous Datasets in the lineage.
      +   */
      +  def update(newData: T): Unit = {
      +    persist(newData)
      +    persistedQueue.enqueue(newData)
       +    // We try to maintain at most 3 Datasets in persistedQueue to support the semantics of this class:
      +    // Users should call [[update()]] when a new Dataset has been created,
      +    // before the Dataset has been materialized.
      +    while (persistedQueue.size > 3) {
      +      val dataToUnpersist = persistedQueue.dequeue()
      +      unpersist(dataToUnpersist)
      +    }
      +    updateCount += 1
      +
      +    // Handle checkpointing (after persisting)
      +    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
      +      // Add new checkpoint before removing old checkpoints.
      +      checkpoint(newData)
      +      checkpointQueue.enqueue(newData)
      +      // Remove checkpoints before the latest one.
      +      var canDelete = true
      +      while (checkpointQueue.size > 1 && canDelete) {
      +        // Delete the oldest checkpoint only if the next checkpoint exists.
      +        if (isCheckpointed(checkpointQueue.head)) {
      +          removeCheckpointFile()
      +        } else {
      +          canDelete = false
      +        }
      +      }
      +    }
      +  }
      +
      +  /** Checkpoint the Dataset */
      +  protected def checkpoint(data: T): Unit
      +
      +  /** Return true iff the Dataset is checkpointed */
      +  protected def isCheckpointed(data: T): Boolean
      +
      +  /**
      +   * Persist the Dataset.
      +   * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
      +   */
      +  protected def persist(data: T): Unit
      +
      +  /** Unpersist the Dataset */
      +  protected def unpersist(data: T): Unit
      +
      +  /** Get list of checkpoint files for this given Dataset */
      +  protected def getCheckpointFiles(data: T): Iterable[String]
      +
      +  /**
      +   * Call this at the end to delete any remaining checkpoint files.
      +   */
      +  def deleteAllCheckpoints(): Unit = {
      +    while (checkpointQueue.nonEmpty) {
      +      removeCheckpointFile()
      +    }
      +  }
      +
      +  /**
      +   * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
      +   * This prints a warning but does not fail if the files cannot be removed.
      +   */
      +  private def removeCheckpointFile(): Unit = {
      +    val old = checkpointQueue.dequeue()
      +    // Since the old checkpoint is not deleted by Spark, we manually delete it.
      +    val fs = FileSystem.get(sc.hadoopConfiguration)
      +    getCheckpointFiles(old).foreach { checkpointFile =>
      +      try {
      +        fs.delete(new Path(checkpointFile), true)
      +      } catch {
      +        case e: Exception =>
      +          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
      +            checkpointFile)
      +      }
      +    }
      +  }
      +
      +}
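
As a sketch of the contract just defined, the following hypothetical subclass shows what a concrete checkpointer has to provide; `DoubleRDDCheckpointer` and its package placement are illustrative only (the real, generic `PeriodicRDDCheckpointer` added below is the proper counterpart).

```scala
package org.apache.spark.mllib.impl   // must live here: PeriodicCheckpointer is private[mllib]

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Hypothetical minimal subclass, only to show which hooks a concrete checkpointer provides.
private[mllib] class DoubleRDDCheckpointer(interval: Int, sc: SparkContext)
  extends PeriodicCheckpointer[RDD[Double]](interval, sc) {

  override protected def checkpoint(data: RDD[Double]): Unit = data.checkpoint()

  override protected def isCheckpointed(data: RDD[Double]): Boolean = data.isCheckpointed

  override protected def persist(data: RDD[Double]): Unit = {
    // Respect a storage level the caller may already have chosen.
    if (data.getStorageLevel == StorageLevel.NONE) data.persist()
  }

  override protected def unpersist(data: RDD[Double]): Unit = data.unpersist(blocking = false)

  override protected def getCheckpointFiles(data: RDD[Double]): Iterable[String] =
    data.getCheckpointFile.toList
}
```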
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
      index 6e5dd119dd65..11a059536c50 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
      @@ -17,11 +17,7 @@
       
       package org.apache.spark.mllib.impl
       
      -import scala.collection.mutable
      -
      -import org.apache.hadoop.fs.{Path, FileSystem}
      -
      -import org.apache.spark.Logging
      +import org.apache.spark.SparkContext
       import org.apache.spark.graphx.Graph
       import org.apache.spark.storage.StorageLevel
       
      @@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel
        * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
        * unpersisting and removing checkpoint files.
        *
      - * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
      + * Users should call update() when a new graph has been created,
        * before the graph has been materialized.  After updating [[PeriodicGraphCheckpointer]], users are
        * responsible for materializing the graph to ensure that persisting and checkpointing actually
        * occur.
        *
      - * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
      + * When update() is called, this does the following:
        *  - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
        *  - Unpersist graphs from queue until there are at most 3 persisted graphs.
        *  - If using checkpointing and the checkpoint interval has been reached,
      @@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
        * Example usage:
        * {{{
        *  val (graph1, graph2, graph3, ...) = ...
      - *  val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
      + *  val cp = new PeriodicGraphCheckpointer(2, sc)
        *  graph1.vertices.count(); graph1.edges.count()
        *  // persisted: graph1
        *  cp.updateGraph(graph2)
      @@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel
        *  // checkpointed: graph4
        * }}}
        *
      - * @param currentGraph  Initial graph
        * @param checkpointInterval Graphs will be checkpointed at this interval
        * @tparam VD  Vertex descriptor type
        * @tparam ED  Edge descriptor type
        *
      - * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
      + * TODO: Move this out of MLlib?
        */
       private[mllib] class PeriodicGraphCheckpointer[VD, ED](
      -    var currentGraph: Graph[VD, ED],
      -    val checkpointInterval: Int) extends Logging {
      -
      -  /** FIFO queue of past checkpointed RDDs */
      -  private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
      -
      -  /** FIFO queue of past persisted RDDs */
      -  private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
      -
      -  /** Number of times [[updateGraph()]] has been called */
      -  private var updateCount = 0
      -
      -  /**
      -   * Spark Context for the Graphs given to this checkpointer.
      -   * NOTE: This code assumes that only one SparkContext is used for the given graphs.
      -   */
      -  private val sc = currentGraph.vertices.sparkContext
      +    checkpointInterval: Int,
      +    sc: SparkContext)
      +  extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
       
      -  updateGraph(currentGraph)
      +  override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
       
      -  /**
      -   * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
      -   * Since this handles persistence and checkpointing, this should be called before the graph
      -   * has been materialized.
      -   *
      -   * @param newGraph  New graph created from previous graphs in the lineage.
      -   */
      -  def updateGraph(newGraph: Graph[VD, ED]): Unit = {
      -    if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
      -      newGraph.persist()
      -    }
      -    persistedQueue.enqueue(newGraph)
      -    // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
      -    // Users should call [[updateGraph()]] when a new graph has been created,
      -    // before the graph has been materialized.
      -    while (persistedQueue.size > 3) {
      -      val graphToUnpersist = persistedQueue.dequeue()
      -      graphToUnpersist.unpersist(blocking = false)
      -    }
      -    updateCount += 1
      +  override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
       
      -    // Handle checkpointing (after persisting)
      -    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
      -      // Add new checkpoint before removing old checkpoints.
      -      newGraph.checkpoint()
      -      checkpointQueue.enqueue(newGraph)
      -      // Remove checkpoints before the latest one.
      -      var canDelete = true
      -      while (checkpointQueue.size > 1 && canDelete) {
      -        // Delete the oldest checkpoint only if the next checkpoint exists.
      -        if (checkpointQueue.get(1).get.isCheckpointed) {
      -          removeCheckpointFile()
      -        } else {
      -          canDelete = false
      -        }
      -      }
      +  override protected def persist(data: Graph[VD, ED]): Unit = {
      +    if (data.vertices.getStorageLevel == StorageLevel.NONE) {
      +      data.persist()
           }
         }
       
      -  /**
      -   * Call this at the end to delete any remaining checkpoint files.
      -   */
      -  def deleteAllCheckpoints(): Unit = {
      -    while (checkpointQueue.size > 0) {
      -      removeCheckpointFile()
      -    }
      -  }
      +  override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
       
      -  /**
      -   * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
      -   * This prints a warning but does not fail if the files cannot be removed.
      -   */
      -  private def removeCheckpointFile(): Unit = {
      -    val old = checkpointQueue.dequeue()
      -    // Since the old checkpoint is not deleted by Spark, we manually delete it.
      -    val fs = FileSystem.get(sc.hadoopConfiguration)
      -    old.getCheckpointFiles.foreach { checkpointFile =>
      -      try {
      -        fs.delete(new Path(checkpointFile), true)
      -      } catch {
      -        case e: Exception =>
      -          logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
      -            checkpointFile)
      -      }
      -    }
      +  override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = {
      +    data.getCheckpointFiles
         }
      -
       }
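
Note that the usage example kept in the class comment above still shows the pre-refactor flow (initial graph passed to the constructor, `updateGraph()` calls). A hedged sketch of the refactored flow follows; `GraphCheckpointerSketch` and `runWithCheckpointing` are illustrative names, and the code must sit under `org.apache.spark.mllib` because the class is `private[mllib]`.

```scala
package org.apache.spark.mllib.impl   // PeriodicGraphCheckpointer is private[mllib]

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph

object GraphCheckpointerSketch {
  // Feed every graph (including the first) through the inherited update(),
  // then materialize it so the requested persist/checkpoint actually happen.
  def runWithCheckpointing[VD, ED](sc: SparkContext, graphs: Seq[Graph[VD, ED]]): Unit = {
    val checkpointer = new PeriodicGraphCheckpointer[VD, ED](2, sc)
    graphs.foreach { graph =>
      checkpointer.update(graph)
      graph.vertices.count()
      graph.edges.count()
    }
    checkpointer.deleteAllCheckpoints()
  }
}
```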
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
      new file mode 100644
      index 000000000000..f31ed2aa90a6
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
      @@ -0,0 +1,97 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.impl
      +
      +import org.apache.spark.SparkContext
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.storage.StorageLevel
      +
      +
      +/**
      + * This class helps with persisting and checkpointing RDDs.
      + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
      + * unpersisting and removing checkpoint files.
      + *
      + * Users should call update() when a new RDD has been created,
      + * before the RDD has been materialized.  After updating [[PeriodicRDDCheckpointer]], users are
      + * responsible for materializing the RDD to ensure that persisting and checkpointing actually
      + * occur.
      + *
      + * When update() is called, this does the following:
      + *  - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs.
      + *  - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
      + *  - If using checkpointing and the checkpoint interval has been reached,
      + *     - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
      + *     - Remove older checkpoints.
      + *
      + * WARNINGS:
      + *  - This class should NOT be copied (since copies may conflict on which RDDs should be
      + *    checkpointed).
      + *  - This class removes checkpoint files once later RDDs have been checkpointed.
      + *    However, references to the older RDDs will still return isCheckpointed = true.
      + *
      + * Example usage:
      + * {{{
      + *  val (rdd1, rdd2, rdd3, ...) = ...
+ *  val cp = new PeriodicRDDCheckpointer(2, sc)
+ *  cp.update(rdd1)
+ *  rdd1.count();
+ *  // persisted: rdd1
      + *  cp.update(rdd2)
      + *  rdd2.count();
      + *  // persisted: rdd1, rdd2
      + *  // checkpointed: rdd2
      + *  cp.update(rdd3)
      + *  rdd3.count();
      + *  // persisted: rdd1, rdd2, rdd3
      + *  // checkpointed: rdd2
      + *  cp.update(rdd4)
      + *  rdd4.count();
      + *  // persisted: rdd2, rdd3, rdd4
      + *  // checkpointed: rdd4
      + *  cp.update(rdd5)
      + *  rdd5.count();
      + *  // persisted: rdd3, rdd4, rdd5
      + *  // checkpointed: rdd4
      + * }}}
      + *
      + * @param checkpointInterval  RDDs will be checkpointed at this interval
      + * @tparam T  RDD element type
      + *
      + * TODO: Move this out of MLlib?
      + */
      +private[mllib] class PeriodicRDDCheckpointer[T](
      +    checkpointInterval: Int,
      +    sc: SparkContext)
      +  extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
      +
      +  override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()
      +
      +  override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
      +
      +  override protected def persist(data: RDD[T]): Unit = {
      +    if (data.getStorageLevel == StorageLevel.NONE) {
      +      data.persist()
      +    }
      +  }
      +
      +  override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
      +
      +  override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
      +    data.getCheckpointFile.map(x => x)
      +  }
      +}
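
Because `update()` only checkpoints when the context has a checkpoint directory configured, a realistic caller sets one first. A minimal sketch under stated assumptions: `RDDCheckpointerSketch`, `iterate`, the `/tmp/checkpoints` path, and the `map(_ * 0.9)` step are all placeholders, not part of the patch.

```scala
package org.apache.spark.mllib.impl   // PeriodicRDDCheckpointer is private[mllib]

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object RDDCheckpointerSketch {
  def iterate(sc: SparkContext, initial: RDD[Double], numIterations: Int): RDD[Double] = {
    sc.setCheckpointDir("/tmp/checkpoints")   // hypothetical path; checkpointing is skipped without it
    val checkpointer = new PeriodicRDDCheckpointer[Double](10, sc)
    var current = initial
    for (_ <- 0 until numIterations) {
      current = current.map(_ * 0.9)          // placeholder for the real per-iteration update
      checkpointer.update(current)
      current.count()                         // materialize so persist/checkpoint take effect
    }
    checkpointer.deleteAllCheckpoints()
    current
  }
}
```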
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
      index 3523f1804325..df9f4ae145b8 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
      @@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging {
           }
         }
       
      +  /** Y += a * x */
      +  private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
      +    require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
      +      s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
      +    f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1)
      +  }
      +
         /**
          * dot(x, y)
          */
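
The new matrix `axpy` delegates to `daxpy` over the flat column-major `values` arrays. The stand-alone sketch below restates the same `Y := a * X + Y` semantics without BLAS, just to make the contract explicit; `AxpyReference` is an illustrative name.

```scala
object AxpyReference {
  // Y := a * X + Y applied element-wise over the column-major values arrays
  // of two equally sized dense matrices (what f2jBLAS.daxpy does on the flat arrays).
  def axpy(a: Double, xValues: Array[Double], yValues: Array[Double]): Unit = {
    require(xValues.length == yValues.length, "X and Y must have the same shape")
    var i = 0
    while (i < xValues.length) {
      yValues(i) += a * xValues(i)
      i += 1
    }
  }
}
```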
      @@ -229,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging {
           _nativeBLAS
         }
       
      +  /**
+   * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
+   *
+   * @param U the upper triangular part of the matrix in a [[DenseVector]] (column major)
      +   */
      +  def spr(alpha: Double, v: Vector, U: DenseVector): Unit = {
      +    spr(alpha, v, U.values)
      +  }
      +
      +  /**
+   * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
      +   *
      +   * @param U the upper triangular part of the matrix packed in an array (column major)
      +   */
      +  def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
      +    val n = v.size
      +    v match {
      +      case DenseVector(values) =>
      +        NativeBLAS.dspr("U", n, alpha, values, 1, U)
      +      case SparseVector(size, indices, values) =>
      +        val nnz = indices.length
      +        var colStartIdx = 0
      +        var prevCol = 0
      +        var col = 0
      +        var j = 0
      +        var i = 0
      +        var av = 0.0
      +        while (j < nnz) {
+          col = indices(j)
+          // Skip empty columns.
+          colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2
      +          av = alpha * values(j)
      +          i = 0
      +          while (i <= j) {
      +            U(colStartIdx + indices(i)) += av * values(i)
      +            i += 1
      +          }
      +          j += 1
      +          prevCol = col
      +        }
      +    }
      +  }
      +
         /**
          * A := alpha * x * x^T^ + A
          * @param alpha a real scalar that will be multiplied to x * x^T^.
      @@ -303,8 +354,10 @@ private[spark] object BLAS extends Serializable with Logging {
             C: DenseMatrix): Unit = {
           require(!C.isTransposed,
             "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
      -    if (alpha == 0.0) {
      -      logDebug("gemm: alpha is equal to 0. Returning C.")
      +    if (alpha == 0.0 && beta == 1.0) {
      +      logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
      +    } else if (alpha == 0.0) {
      +      f2jBLAS.dscal(C.values.length, beta, C.values, 1)
           } else {
             A match {
               case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
      @@ -408,8 +461,8 @@ private[spark] object BLAS extends Serializable with Logging {
               }
             }
           } else {
      -      // Scale matrix first if `beta` is not equal to 0.0
      -      if (beta != 0.0) {
      +      // Scale matrix first if `beta` is not equal to 1.0
      +      if (beta != 1.0) {
               f2jBLAS.dscal(C.values.length, beta, C.values, 1)
             }
             // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
      @@ -469,9 +522,11 @@ private[spark] object BLAS extends Serializable with Logging {
           require(A.numCols == x.size,
             s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
           require(A.numRows == y.size,
      -      s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}")
      -    if (alpha == 0.0) {
      -      logDebug("gemv: alpha is equal to 0. Returning y.")
      +      s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}")
      +    if (alpha == 0.0 && beta == 1.0) {
      +      logDebug("gemv: alpha is equal to 0 and beta is equal to 1. Returning y.")
      +    } else if (alpha == 0.0) {
      +      scal(beta, y)
           } else {
             (A, x) match {
               case (smA: SparseMatrix, dvx: DenseVector) =>
      @@ -526,11 +581,6 @@ private[spark] object BLAS extends Serializable with Logging {
           val xValues = x.values
           val yValues = y.values
       
      -    if (alpha == 0.0) {
      -      scal(beta, y)
      -      return
      -    }
      -
           if (A.isTransposed) {
             var rowCounterForA = 0
             while (rowCounterForA < mA) {
      @@ -581,11 +631,6 @@ private[spark] object BLAS extends Serializable with Logging {
           val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
           val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
       
      -    if (alpha == 0.0) {
      -      scal(beta, y)
      -      return
      -    }
      -
           if (A.isTransposed) {
             var rowCounter = 0
             while (rowCounter < mA) {
      @@ -604,7 +649,7 @@ private[spark] object BLAS extends Serializable with Logging {
               rowCounter += 1
             }
           } else {
      -      scal(beta, y)
      +      if (beta != 1.0) scal(beta, y)
       
             var colCounterForA = 0
             var k = 0
      @@ -659,7 +704,7 @@ private[spark] object BLAS extends Serializable with Logging {
               rowCounter += 1
             }
           } else {
      -      scal(beta, y)
      +      if (beta != 1.0) scal(beta, y)
             // Perform matrix-vector multiplication and add to y
             var colCounterForA = 0
             while (colCounterForA < nA) {
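
Two of the BLAS changes above are easy to mis-read: `spr` updates a packed column-major upper triangle, and `gemm`/`gemv` now special-case `alpha == 0` against `beta` instead of always scaling. A plain-Scala reference sketch of both behaviors (`SprReference` is an illustrative name; no netlib calls):

```scala
object SprReference {
  // U(packed upper triangle, column major) += alpha * v * v^T.
  // Entry (i, j) with i <= j is stored at index j * (j + 1) / 2 + i.
  def spr(alpha: Double, v: Array[Double], u: Array[Double]): Unit = {
    val n = v.length
    require(u.length == n * (n + 1) / 2, "U must hold exactly the packed upper triangle")
    var j = 0
    while (j < n) {
      val colStart = j * (j + 1) / 2
      var i = 0
      while (i <= j) {
        u(colStart + i) += alpha * v(j) * v(i)
        i += 1
      }
      j += 1
    }
  }
}

// gemm/gemv special cases after this patch, for C := alpha * A * B + beta * C:
//   alpha == 0 && beta == 1  ->  C is left untouched
//   alpha == 0               ->  C is only scaled by beta
//   beta  == 1               ->  the pre-scaling of C is skipped entirely
```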
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
      index 85e63b1382b5..c02ba426fcc3 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
      @@ -23,27 +23,32 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
       
       import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
       
      -import org.apache.spark.annotation.DeveloperApi
      -import org.apache.spark.sql.Row
      -import org.apache.spark.sql.types._
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
      +import org.apache.spark.sql.catalyst.InternalRow
      +import org.apache.spark.sql.types._
       
       /**
        * Trait for a local matrix.
        */
       @SQLUserDefinedType(udt = classOf[MatrixUDT])
      +@Since("1.0.0")
       sealed trait Matrix extends Serializable {
       
         /** Number of rows. */
      +  @Since("1.0.0")
         def numRows: Int
       
         /** Number of columns. */
      +  @Since("1.0.0")
         def numCols: Int
       
         /** Flag that keeps track whether the matrix is transposed or not. False by default. */
      +  @Since("1.3.0")
         val isTransposed: Boolean = false
       
         /** Converts to a dense array in column major. */
      +  @Since("1.0.0")
         def toArray: Array[Double] = {
           val newArray = new Array[Double](numRows * numCols)
           foreachActive { (i, j, v) =>
      @@ -56,6 +61,7 @@ sealed trait Matrix extends Serializable {
         private[mllib] def toBreeze: BM[Double]
       
         /** Gets the (i, j)-th element. */
      +  @Since("1.3.0")
         def apply(i: Int, j: Int): Double
       
         /** Return the index for the (i, j)-th element in the backing array. */
      @@ -65,12 +71,15 @@ sealed trait Matrix extends Serializable {
         private[mllib] def update(i: Int, j: Int, v: Double): Unit
       
         /** Get a deep copy of the matrix. */
      +  @Since("1.2.0")
         def copy: Matrix
       
         /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */
      +  @Since("1.3.0")
         def transpose: Matrix
       
         /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */
      +  @Since("1.2.0")
         def multiply(y: DenseMatrix): DenseMatrix = {
           val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols)
           BLAS.gemm(1.0, this, y, 0.0, C)
      @@ -78,11 +87,13 @@ sealed trait Matrix extends Serializable {
         }
       
         /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */
      +  @Since("1.2.0")
         def multiply(y: DenseVector): DenseVector = {
           multiply(y.asInstanceOf[Vector])
         }
       
         /** Convenience method for `Matrix`-`Vector` multiplication. */
      +  @Since("1.4.0")
         def multiply(y: Vector): DenseVector = {
           val output = new DenseVector(new Array[Double](numRows))
           BLAS.gemv(1.0, this, y, 0.0, output)
      @@ -93,12 +104,13 @@ sealed trait Matrix extends Serializable {
         override def toString: String = toBreeze.toString()
       
         /** A human readable representation of the matrix with maximum lines and width */
      +  @Since("1.4.0")
         def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth)
       
         /** Map the values of this matrix using a function. Generates a new matrix. Performs the
           * function on only the backing array. For example, an operation such as addition or
           * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
      -  private[mllib] def map(f: Double => Double): Matrix
      +  private[spark] def map(f: Double => Double): Matrix
       
         /** Update all the values of this matrix using the function f. Performed in-place on the
           * backing array. For example, an operation such as addition or subtraction will only be
      @@ -114,6 +126,18 @@ sealed trait Matrix extends Serializable {
          *          corresponding value in the matrix with type `Double`.
          */
         private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
      +
      +  /**
      +   * Find the number of non-zero active values.
      +   */
      +  @Since("1.5.0")
      +  def numNonzeros: Int
      +
      +  /**
      +   * Find the number of values stored explicitly. These values can be zero as well.
      +   */
      +  @Since("1.5.0")
      +  def numActives: Int
       }
       
       @DeveloperApi
      @@ -137,16 +161,16 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
             ))
         }
       
      -  override def serialize(obj: Any): Row = {
      +  override def serialize(obj: Any): InternalRow = {
           val row = new GenericMutableRow(7)
           obj match {
             case sm: SparseMatrix =>
               row.setByte(0, 0)
               row.setInt(1, sm.numRows)
               row.setInt(2, sm.numCols)
      -        row.update(3, sm.colPtrs.toSeq)
      -        row.update(4, sm.rowIndices.toSeq)
      -        row.update(5, sm.values.toSeq)
      +        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
      +        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
      +        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
               row.setBoolean(6, sm.isTransposed)
       
             case dm: DenseMatrix =>
      @@ -155,7 +179,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
               row.setInt(2, dm.numCols)
               row.setNullAt(3)
               row.setNullAt(4)
      -        row.update(5, dm.values.toSeq)
      +        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
               row.setBoolean(6, dm.isTransposed)
           }
           row
      @@ -163,20 +187,18 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
       
         override def deserialize(datum: Any): Matrix = {
           datum match {
      -      // TODO: something wrong with UDT serialization, should never happen.
      -      case m: Matrix => m
      -      case row: Row =>
      -        require(row.length == 7,
      -          s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
      +      case row: InternalRow =>
      +        require(row.numFields == 7,
      +          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
               val tpe = row.getByte(0)
               val numRows = row.getInt(1)
               val numCols = row.getInt(2)
      -        val values = row.getAs[Iterable[Double]](5).toArray
      +        val values = row.getArray(5).toDoubleArray()
               val isTransposed = row.getBoolean(6)
               tpe match {
                 case 0 =>
      -            val colPtrs = row.getAs[Iterable[Int]](3).toArray
      -            val rowIndices = row.getAs[Iterable[Int]](4).toArray
      +            val colPtrs = row.getArray(3).toIntArray()
      +            val rowIndices = row.getArray(4).toIntArray()
                   new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
                 case 1 =>
                   new DenseMatrix(numRows, numCols, values, isTransposed)
      @@ -193,7 +215,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
           }
         }
       
      -  override def hashCode(): Int = 1994
+  // See [SPARK-8647]: this yields the required constant hash code without hard-coding a magic number.
      +  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()
       
         override def typeName: String = "matrix"
       
      @@ -219,12 +242,13 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
        * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
        *                     row major.
        */
      +@Since("1.0.0")
       @SQLUserDefinedType(udt = classOf[MatrixUDT])
      -class DenseMatrix(
      -    val numRows: Int,
      -    val numCols: Int,
      -    val values: Array[Double],
      -    override val isTransposed: Boolean) extends Matrix {
      +class DenseMatrix @Since("1.3.0") (
      +    @Since("1.0.0") val numRows: Int,
      +    @Since("1.0.0") val numCols: Int,
      +    @Since("1.0.0") val values: Array[Double],
      +    @Since("1.3.0") override val isTransposed: Boolean) extends Matrix {
       
         require(values.length == numRows * numCols, "The number of values supplied doesn't match the " +
           s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}")
      @@ -244,12 +268,12 @@ class DenseMatrix(
          * @param numCols number of columns
          * @param values matrix entries in column major
          */
      +  @Since("1.0.0")
         def this(numRows: Int, numCols: Int, values: Array[Double]) =
           this(numRows, numCols, values, false)
       
         override def equals(o: Any): Boolean = o match {
      -    case m: DenseMatrix =>
      -      m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray)
      +    case m: Matrix => toBreeze == m.toBreeze
           case _ => false
         }
       
      @@ -268,6 +292,7 @@ class DenseMatrix(
       
         private[mllib] def apply(i: Int): Double = values(i)
       
      +  @Since("1.3.0")
         override def apply(i: Int, j: Int): Double = values(index(i, j))
       
         private[mllib] def index(i: Int, j: Int): Int = {
      @@ -278,9 +303,10 @@ class DenseMatrix(
           values(index(i, j)) = v
         }
       
      +  @Since("1.4.0")
         override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
       
      -  private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
      +  private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
           isTransposed)
       
         private[mllib] def update(f: Double => Double): DenseMatrix = {
      @@ -293,6 +319,7 @@ class DenseMatrix(
           this
         }
       
      +  @Since("1.3.0")
         override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed)
       
         private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = {
      @@ -323,10 +350,17 @@ class DenseMatrix(
           }
         }
       
      +  @Since("1.5.0")
      +  override def numNonzeros: Int = values.count(_ != 0)
      +
      +  @Since("1.5.0")
      +  override def numActives: Int = values.length
      +
         /**
          * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed
          * set to false.
          */
      +  @Since("1.3.0")
         def toSparse: SparseMatrix = {
           val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble
           val colPtrs: Array[Int] = new Array[Int](numCols + 1)
      @@ -354,6 +388,7 @@ class DenseMatrix(
       /**
        * Factory methods for [[org.apache.spark.mllib.linalg.DenseMatrix]].
        */
      +@Since("1.3.0")
       object DenseMatrix {
       
         /**
      @@ -362,6 +397,7 @@ object DenseMatrix {
          * @param numCols number of columns of the matrix
          * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros
          */
      +  @Since("1.3.0")
         def zeros(numRows: Int, numCols: Int): DenseMatrix = {
           require(numRows.toLong * numCols <= Int.MaxValue,
                   s"$numRows x $numCols dense matrix is too large to allocate")
      @@ -374,6 +410,7 @@ object DenseMatrix {
          * @param numCols number of columns of the matrix
          * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones
          */
      +  @Since("1.3.0")
         def ones(numRows: Int, numCols: Int): DenseMatrix = {
           require(numRows.toLong * numCols <= Int.MaxValue,
                   s"$numRows x $numCols dense matrix is too large to allocate")
      @@ -385,6 +422,7 @@ object DenseMatrix {
          * @param n number of rows and columns of the matrix
          * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal
          */
      +  @Since("1.3.0")
         def eye(n: Int): DenseMatrix = {
           val identity = DenseMatrix.zeros(n, n)
           var i = 0
      @@ -402,6 +440,7 @@ object DenseMatrix {
          * @param rng a random number generator
          * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
          */
      +  @Since("1.3.0")
         def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = {
           require(numRows.toLong * numCols <= Int.MaxValue,
                   s"$numRows x $numCols dense matrix is too large to allocate")
      @@ -415,6 +454,7 @@ object DenseMatrix {
          * @param rng a random number generator
          * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
          */
      +  @Since("1.3.0")
         def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = {
           require(numRows.toLong * numCols <= Int.MaxValue,
                   s"$numRows x $numCols dense matrix is too large to allocate")
      @@ -427,6 +467,7 @@ object DenseMatrix {
          * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values`
          *         on the diagonal
          */
      +  @Since("1.3.0")
         def diag(vector: Vector): DenseMatrix = {
           val n = vector.size
           val matrix = DenseMatrix.zeros(n, n)
      @@ -462,14 +503,15 @@ object DenseMatrix {
        *                     Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
        *                     and `rowIndices` behave as colIndices, and `values` are stored in row major.
        */
      +@Since("1.2.0")
       @SQLUserDefinedType(udt = classOf[MatrixUDT])
      -class SparseMatrix(
      -    val numRows: Int,
      -    val numCols: Int,
      -    val colPtrs: Array[Int],
      -    val rowIndices: Array[Int],
      -    val values: Array[Double],
      -    override val isTransposed: Boolean) extends Matrix {
      +class SparseMatrix @Since("1.3.0") (
      +    @Since("1.2.0") val numRows: Int,
      +    @Since("1.2.0") val numCols: Int,
      +    @Since("1.2.0") val colPtrs: Array[Int],
      +    @Since("1.2.0") val rowIndices: Array[Int],
      +    @Since("1.2.0") val values: Array[Double],
      +    @Since("1.3.0") override val isTransposed: Boolean) extends Matrix {
       
         require(values.length == rowIndices.length, "The number of row indices and values don't match! " +
           s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}")
      @@ -499,6 +541,7 @@ class SparseMatrix(
          *                   order for each column
          * @param values non-zero matrix entries in column major
          */
      +  @Since("1.2.0")
         def this(
             numRows: Int,
             numCols: Int,
      @@ -506,6 +549,11 @@ class SparseMatrix(
             rowIndices: Array[Int],
             values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false)
       
      +  override def equals(o: Any): Boolean = o match {
      +    case m: Matrix => toBreeze == m.toBreeze
      +    case _ => false
      +  }
      +
         private[mllib] def toBreeze: BM[Double] = {
            if (!isTransposed) {
              new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
      @@ -515,6 +563,7 @@ class SparseMatrix(
            }
         }
       
      +  @Since("1.3.0")
         override def apply(i: Int, j: Int): Double = {
           val ind = index(i, j)
           if (ind < 0) 0.0 else values(ind)
      @@ -538,11 +587,12 @@ class SparseMatrix(
           }
         }
       
      +  @Since("1.4.0")
         override def copy: SparseMatrix = {
           new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
         }
       
      -  private[mllib] def map(f: Double => Double) =
      +  private[spark] def map(f: Double => Double) =
           new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)
       
         private[mllib] def update(f: Double => Double): SparseMatrix = {
      @@ -555,6 +605,7 @@ class SparseMatrix(
           this
         }
       
      +  @Since("1.3.0")
         override def transpose: SparseMatrix =
           new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed)
       
      @@ -589,14 +640,23 @@ class SparseMatrix(
          * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed
          * set to false.
          */
      +  @Since("1.3.0")
         def toDense: DenseMatrix = {
           new DenseMatrix(numRows, numCols, toArray)
         }
      +
      +  @Since("1.5.0")
      +  override def numNonzeros: Int = values.count(_ != 0)
      +
      +  @Since("1.5.0")
      +  override def numActives: Int = values.length
      +
       }
       
       /**
        * Factory methods for [[org.apache.spark.mllib.linalg.SparseMatrix]].
        */
      +@Since("1.3.0")
       object SparseMatrix {
       
         /**
      @@ -608,6 +668,7 @@ object SparseMatrix {
          * @param entries Array of (i, j, value) tuples
          * @return The corresponding `SparseMatrix`
          */
      +  @Since("1.3.0")
         def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = {
           val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1))
           val numEntries = sortedEntries.size
      @@ -656,6 +717,7 @@ object SparseMatrix {
          * @param n number of rows and columns of the matrix
          * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal
          */
      +  @Since("1.3.0")
         def speye(n: Int): SparseMatrix = {
           new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0))
         }
      @@ -725,6 +787,7 @@ object SparseMatrix {
          * @param rng a random number generator
          * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
          */
      +  @Since("1.3.0")
         def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = {
           val mat = genRandMatrix(numRows, numCols, density, rng)
           mat.update(i => rng.nextDouble())
      @@ -738,6 +801,7 @@ object SparseMatrix {
          * @param rng a random number generator
          * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
          */
      +  @Since("1.3.0")
         def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = {
           val mat = genRandMatrix(numRows, numCols, density, rng)
           mat.update(i => rng.nextGaussian())
      @@ -749,6 +813,7 @@ object SparseMatrix {
          * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero
          *         `values` on the diagonal
          */
      +  @Since("1.3.0")
         def spdiag(vector: Vector): SparseMatrix = {
           val n = vector.size
           vector match {
      @@ -765,6 +830,7 @@ object SparseMatrix {
       /**
        * Factory methods for [[org.apache.spark.mllib.linalg.Matrix]].
        */
      +@Since("1.0.0")
       object Matrices {
       
         /**
      @@ -774,6 +840,7 @@ object Matrices {
          * @param numCols number of columns
          * @param values matrix entries in column major
          */
      +  @Since("1.0.0")
         def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = {
           new DenseMatrix(numRows, numCols, values)
         }
      @@ -787,6 +854,7 @@ object Matrices {
          * @param rowIndices the row index of the entry
          * @param values non-zero matrix entries in column major
          */
      +  @Since("1.2.0")
         def sparse(
            numRows: Int,
            numCols: Int,
      @@ -820,6 +888,7 @@ object Matrices {
          * @param numCols number of columns of the matrix
          * @return `Matrix` with size `numRows` x `numCols` and values of zeros
          */
      +  @Since("1.2.0")
         def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols)
       
         /**
      @@ -828,6 +897,7 @@ object Matrices {
          * @param numCols number of columns of the matrix
          * @return `Matrix` with size `numRows` x `numCols` and values of ones
          */
      +  @Since("1.2.0")
         def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols)
       
         /**
      @@ -835,6 +905,7 @@ object Matrices {
          * @param n number of rows and columns of the matrix
          * @return `Matrix` with size `n` x `n` and values of ones on the diagonal
          */
      +  @Since("1.2.0")
         def eye(n: Int): Matrix = DenseMatrix.eye(n)
       
         /**
      @@ -842,6 +913,7 @@ object Matrices {
          * @param n number of rows and columns of the matrix
          * @return `Matrix` with size `n` x `n` and values of ones on the diagonal
          */
      +  @Since("1.3.0")
         def speye(n: Int): Matrix = SparseMatrix.speye(n)
       
         /**
      @@ -851,6 +923,7 @@ object Matrices {
          * @param rng a random number generator
          * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1)
          */
      +  @Since("1.2.0")
         def rand(numRows: Int, numCols: Int, rng: Random): Matrix =
           DenseMatrix.rand(numRows, numCols, rng)
       
      @@ -862,6 +935,7 @@ object Matrices {
          * @param rng a random number generator
          * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1)
          */
      +  @Since("1.3.0")
         def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix =
           SparseMatrix.sprand(numRows, numCols, density, rng)
       
      @@ -872,6 +946,7 @@ object Matrices {
          * @param rng a random number generator
          * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1)
          */
      +  @Since("1.2.0")
         def randn(numRows: Int, numCols: Int, rng: Random): Matrix =
           DenseMatrix.randn(numRows, numCols, rng)
       
      @@ -883,6 +958,7 @@ object Matrices {
          * @param rng a random number generator
          * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1)
          */
      +  @Since("1.3.0")
         def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix =
           SparseMatrix.sprandn(numRows, numCols, density, rng)
       
      @@ -892,6 +968,7 @@ object Matrices {
          * @return Square `Matrix` with size `values.length` x `values.length` and `values`
          *         on the diagonal
          */
      +  @Since("1.2.0")
         def diag(vector: Vector): Matrix = DenseMatrix.diag(vector)
       
         /**
      @@ -901,6 +978,7 @@ object Matrices {
          * @param matrices array of matrices
          * @return a single `Matrix` composed of the matrices that were horizontally concatenated
          */
      +  @Since("1.3.0")
         def horzcat(matrices: Array[Matrix]): Matrix = {
           if (matrices.isEmpty) {
             return new DenseMatrix(0, 0, Array[Double]())
      @@ -959,6 +1037,7 @@ object Matrices {
          * @param matrices array of matrices
          * @return a single `Matrix` composed of the matrices that were vertically concatenated
          */
      +  @Since("1.3.0")
         def vertcat(matrices: Array[Matrix]): Matrix = {
           if (matrices.isEmpty) {
             return new DenseMatrix(0, 0, Array[Double]())
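
The `Matrix` trait now exposes `numNonzeros` and `numActives`, and `equals` compares the Breeze representations rather than raw arrays. A small usage sketch of what the new counters report (values are arbitrary):

```scala
import org.apache.spark.mllib.linalg.DenseMatrix

// 2 x 2 column-major matrix [[1, 0], [0, 2]]: four stored entries, two of them nonzero.
val dm = new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))
assert(dm.numActives == 4)      // dense storage keeps every entry
assert(dm.numNonzeros == 2)     // only the actual nonzeros are counted

val sm = dm.toSparse
assert(sm.numActives == 2)      // sparse storage keeps only explicit entries
assert(sm.numNonzeros == 2)
```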
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
      index 9669c364bad8..4dcf8f28f202 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
      @@ -17,11 +17,21 @@
       
       package org.apache.spark.mllib.linalg
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       
       /**
        * :: Experimental ::
        * Represents singular value decomposition (SVD) factors.
        */
      +@Since("1.0.0")
       @Experimental
       case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType)
      +
      +/**
      + * :: Experimental ::
      + * Represents QR factors.
      + */
      +@Since("1.5.0")
      +@Experimental
      +case class QRDecomposition[QType, RType](Q: QType, R: RType)
      +
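
`QRDecomposition` is only a typed pair of Q and R factors; presumably the distributed QR routines added elsewhere in 1.5 return it, but that is outside this hunk. A toy local instance for illustration:

```scala
import org.apache.spark.mllib.linalg.{Matrices, Matrix, QRDecomposition}

// Construct and destructure a trivial local QR pair (illustrative values only).
val qr: QRDecomposition[Matrix, Matrix] = QRDecomposition(Matrices.eye(2), Matrices.eye(2))
val QRDecomposition(q, r) = qr
assert(q.numRows == 2 && r.numCols == 2)
```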
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
      index 2ffa497a99d9..3642e9286504 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
      @@ -26,9 +26,9 @@ import scala.collection.JavaConverters._
       import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
       
       import org.apache.spark.SparkException
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{AlphaComponent, Since}
       import org.apache.spark.mllib.util.NumericParser
      -import org.apache.spark.sql.Row
      +import org.apache.spark.sql.catalyst.InternalRow
       import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
       import org.apache.spark.sql.types._
       
      @@ -38,16 +38,19 @@ import org.apache.spark.sql.types._
        * Note: Users should not implement this interface.
        */
       @SQLUserDefinedType(udt = classOf[VectorUDT])
      +@Since("1.0.0")
       sealed trait Vector extends Serializable {
       
         /**
          * Size of the vector.
          */
      +  @Since("1.0.0")
         def size: Int
       
         /**
          * Converts the instance to a double array.
          */
      +  @Since("1.0.0")
         def toArray: Array[Double]
       
         override def equals(other: Any): Boolean = {
      @@ -68,20 +71,22 @@ sealed trait Vector extends Serializable {
         }
       
         /**
      -   * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros
      -   * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
      +   * Returns a hash code value for the vector. The hash code is based on its size and its first 128
      +   * nonzero entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
          */
         override def hashCode(): Int = {
           // This is a reference implementation. It calls return in foreachActive, which is slow.
           // Subclasses should override it with optimized implementation.
           var result: Int = 31 + size
      +    var nnz = 0
           this.foreachActive { (index, value) =>
      -      if (index < 16) {
      +      if (nnz < Vectors.MAX_HASH_NNZ) {
               // ignore explicit 0 for comparison between sparse and dense
               if (value != 0) {
                 result = 31 * result + index
                 val bits = java.lang.Double.doubleToLongBits(value)
                 result = 31 * result + (bits ^ (bits >>> 32)).toInt
      +          nnz += 1
               }
             } else {
               return result
      @@ -99,11 +104,13 @@ sealed trait Vector extends Serializable {
          * Gets the value of the ith element.
          * @param i index
          */
      +  @Since("1.1.0")
         def apply(i: Int): Double = toBreeze(i)
       
         /**
          * Makes a deep copy of this vector.
          */
      +  @Since("1.1.0")
         def copy: Vector = {
           throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
         }
      @@ -121,26 +128,31 @@ sealed trait Vector extends Serializable {
          * Number of active entries.  An "active entry" is an element which is explicitly stored,
          * regardless of its value.  Note that inactive entries have value 0.
          */
      +  @Since("1.4.0")
         def numActives: Int
       
         /**
          * Number of nonzero elements. This scans all active values and count nonzeros.
          */
      +  @Since("1.4.0")
         def numNonzeros: Int
       
         /**
          * Converts this vector to a sparse vector with all explicit zeros removed.
          */
      +  @Since("1.4.0")
         def toSparse: SparseVector
       
         /**
          * Converts this vector to a dense vector.
          */
      +  @Since("1.4.0")
         def toDense: DenseVector = new DenseVector(this.toArray)
       
         /**
          * Returns a vector in either dense or sparse format, whichever uses less storage.
          */
      +  @Since("1.4.0")
         def compressed: Vector = {
           val nnz = numNonzeros
           // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes.
      @@ -150,18 +162,23 @@ sealed trait Vector extends Serializable {
             toDense
           }
         }
      +
      +  /**
+   * Find the index of a maximal element.  In case of a tie, returns the index of the first
+   * maximal element.  Returns -1 if the vector has length 0.
      +   */
      +  @Since("1.5.0")
      +  def argmax: Int
       }
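
For the newly public `argmax`, the tie-breaking and empty-vector behavior documented above amounts to the following (illustrative values):

```scala
import org.apache.spark.mllib.linalg.Vectors

assert(Vectors.dense(1.0, 3.0, 3.0, 2.0).argmax == 1)    // first maximal element wins ties
assert(Vectors.dense(Array.empty[Double]).argmax == -1)  // empty vector
```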
       
       /**
      - * :: DeveloperApi ::
      + * :: AlphaComponent ::
        *
        * User-defined type for [[Vector]] which allows easy interaction with SQL
        * via [[org.apache.spark.sql.DataFrame]].
      - *
      - * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
        */
      -@DeveloperApi
      -private[spark] class VectorUDT extends UserDefinedType[Vector] {
      +@AlphaComponent
      +class VectorUDT extends UserDefinedType[Vector] {
       
         override def sqlType: StructType = {
           // type: 0 = sparse, 1 = dense
      @@ -175,51 +192,41 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
             StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
         }
       
      -  override def serialize(obj: Any): Row = {
      +  override def serialize(obj: Any): InternalRow = {
           obj match {
             case SparseVector(size, indices, values) =>
               val row = new GenericMutableRow(4)
               row.setByte(0, 0)
               row.setInt(1, size)
      -        row.update(2, indices.toSeq)
      -        row.update(3, values.toSeq)
      +        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
      +        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
               row
             case DenseVector(values) =>
               val row = new GenericMutableRow(4)
               row.setByte(0, 1)
               row.setNullAt(1)
               row.setNullAt(2)
      -        row.update(3, values.toSeq)
      -        row
      -      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
      -      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
      -      // TODO: deserialize may get called twice. See SPARK-7186.
      -      case row: Row =>
      +        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
               row
           }
         }
       
         override def deserialize(datum: Any): Vector = {
           datum match {
      -      case row: Row =>
      -        require(row.length == 4,
      -          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
      +      case row: InternalRow =>
      +        require(row.numFields == 4,
      +          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
               val tpe = row.getByte(0)
               tpe match {
                 case 0 =>
                   val size = row.getInt(1)
      -            val indices = row.getAs[Iterable[Int]](2).toArray
      -            val values = row.getAs[Iterable[Double]](3).toArray
      +            val indices = row.getArray(2).toIntArray()
      +            val values = row.getArray(3).toDoubleArray()
                   new SparseVector(size, indices, values)
                 case 1 =>
      -            val values = row.getAs[Iterable[Double]](3).toArray
      +            val values = row.getArray(3).toDoubleArray()
                   new DenseVector(values)
               }
      -      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
      -      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
      -      // TODO: deserialize may get called twice. See SPARK-7186.
      -      case v: Vector =>
      -        v
           }
         }
       
      @@ -234,7 +241,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
           }
         }
       
      -  override def hashCode: Int = 7919
+  // See [SPARK-8647]: this yields the required constant hash code without hard-coding a magic number.
      +  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()
       
         override def typeName: String = "vector"
       
      @@ -246,11 +254,13 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
        * We don't use the name `Vector` because Scala imports
        * [[scala.collection.immutable.Vector]] by default.
        */
      +@Since("1.0.0")
       object Vectors {
       
         /**
          * Creates a dense vector from its values.
          */
      +  @Since("1.0.0")
         @varargs
         def dense(firstValue: Double, otherValues: Double*): Vector =
           new DenseVector((firstValue +: otherValues).toArray)
      @@ -259,6 +269,7 @@ object Vectors {
         /**
          * Creates a dense vector from a double array.
          */
      +  @Since("1.0.0")
         def dense(values: Array[Double]): Vector = new DenseVector(values)
       
         /**
      @@ -268,6 +279,7 @@ object Vectors {
          * @param indices index array, must be strictly increasing.
          * @param values value array, must have the same length as indices.
          */
      +  @Since("1.0.0")
         def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector =
           new SparseVector(size, indices, values)
       
      @@ -277,6 +289,7 @@ object Vectors {
          * @param size vector size.
          * @param elements vector elements in (index, value) pairs.
          */
      +  @Since("1.0.0")
         def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
           require(size > 0, "The size of the requested sparse vector must be greater than 0.")
       
      @@ -298,6 +311,7 @@ object Vectors {
          * @param size vector size.
          * @param elements vector elements in (index, value) pairs.
          */
      +  @Since("1.0.0")
         def sparse(size: Int, elements: JavaIterable[(JavaInteger, JavaDouble)]): Vector = {
           sparse(size, elements.asScala.map { case (i, x) =>
             (i.intValue(), x.doubleValue())
      @@ -310,6 +324,7 @@ object Vectors {
          * @param size vector size
          * @return a zero vector
          */
      +  @Since("1.1.0")
         def zeros(size: Int): Vector = {
           new DenseVector(new Array[Double](size))
         }
      @@ -317,6 +332,7 @@ object Vectors {
         /**
          * Parses a string resulted from [[Vector.toString]] into a [[Vector]].
          */
      +  @Since("1.1.0")
         def parse(s: String): Vector = {
           parseNumeric(NumericParser.parse(s))
         }
      @@ -360,6 +376,7 @@ object Vectors {
          * @param p norm.
          * @return norm in L^p^ space.
          */
      +  @Since("1.3.0")
         def norm(vector: Vector, p: Double): Double = {
           require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " +
             s"You specified p=$p.")
      @@ -412,6 +429,7 @@ object Vectors {
          * @param v2 second Vector.
          * @return squared distance between two Vectors.
          */
      +  @Since("1.3.0")
         def sqdist(v1: Vector, v2: Vector): Double = {
           require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" +
             s"=${v2.size}.")
      @@ -520,24 +538,33 @@ object Vectors {
           }
           allEqual
         }
      +
      +  /** Max number of nonzero entries used in computing hash code. */
      +  private[linalg] val MAX_HASH_NNZ = 128
       }
       
       /**
        * A dense vector represented by a value array.
        */
      +@Since("1.0.0")
       @SQLUserDefinedType(udt = classOf[VectorUDT])
      -class DenseVector(val values: Array[Double]) extends Vector {
      +class DenseVector @Since("1.0.0") (
      +    @Since("1.0.0") val values: Array[Double]) extends Vector {
       
      +  @Since("1.0.0")
         override def size: Int = values.length
       
         override def toString: String = values.mkString("[", ",", "]")
       
      +  @Since("1.0.0")
         override def toArray: Array[Double] = values
       
         private[spark] override def toBreeze: BV[Double] = new BDV[Double](values)
       
      +  @Since("1.0.0")
         override def apply(i: Int): Double = values(i)
       
      +  @Since("1.1.0")
         override def copy: DenseVector = {
           new DenseVector(values.clone())
         }
      @@ -556,21 +583,25 @@ class DenseVector(val values: Array[Double]) extends Vector {
         override def hashCode(): Int = {
           var result: Int = 31 + size
           var i = 0
      -    val end = math.min(values.length, 16)
      -    while (i < end) {
      +    val end = values.length
      +    var nnz = 0
      +    while (i < end && nnz < Vectors.MAX_HASH_NNZ) {
             val v = values(i)
             if (v != 0.0) {
               result = 31 * result + i
               val bits = java.lang.Double.doubleToLongBits(values(i))
               result = 31 * result + (bits ^ (bits >>> 32)).toInt
      +        nnz += 1
             }
             i += 1
           }
           result
         }
       
      +  @Since("1.4.0")
         override def numActives: Int = size
       
      +  @Since("1.4.0")
         override def numNonzeros: Int = {
           // same as values.count(_ != 0.0) but faster
           var nnz = 0
      @@ -582,6 +613,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
           nnz
         }
       
      +  @Since("1.4.0")
         override def toSparse: SparseVector = {
           val nnz = numNonzeros
           val ii = new Array[Int](nnz)
      @@ -597,11 +629,8 @@ class DenseVector(val values: Array[Double]) extends Vector {
           new SparseVector(size, ii, vv)
         }
       
      -  /**
      -   * Find the index of a maximal element.  Returns the first maximal element in case of a tie.
      -   * Returns -1 if vector has length 0.
      -   */
      -  private[spark] def argmax: Int = {
      +  @Since("1.5.0")
      +  override def argmax: Int = {
           if (size == 0) {
             -1
           } else {
      @@ -620,8 +649,11 @@ class DenseVector(val values: Array[Double]) extends Vector {
         }
       }
       
      +@Since("1.3.0")
       object DenseVector {
      +
         /** Extracts the value array from a dense vector. */
      +  @Since("1.3.0")
         def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values)
       }
       
      @@ -632,19 +664,23 @@ object DenseVector {
        * @param indices index array, assume to be strictly increasing.
        * @param values value array, must have the same length as the index array.
        */
      +@Since("1.0.0")
       @SQLUserDefinedType(udt = classOf[VectorUDT])
      -class SparseVector(
      -    override val size: Int,
      -    val indices: Array[Int],
      -    val values: Array[Double]) extends Vector {
      +class SparseVector @Since("1.0.0") (
      +    @Since("1.0.0") override val size: Int,
      +    @Since("1.0.0") val indices: Array[Int],
      +    @Since("1.0.0") val values: Array[Double]) extends Vector {
       
         require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
           s" indices match the dimension of the values. You provided ${indices.length} indices and " +
           s" ${values.length} values.")
      +  require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
      +    s"which exceeds the specified vector size ${size}.")
       
         override def toString: String =
           s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
       
      +  @Since("1.0.0")
         override def toArray: Array[Double] = {
           val data = new Array[Double](size)
           var i = 0
      @@ -656,6 +692,7 @@ class SparseVector(
           data
         }
       
      +  @Since("1.1.0")
         override def copy: SparseVector = {
           new SparseVector(size, indices.clone(), values.clone())
         }
      @@ -677,27 +714,26 @@ class SparseVector(
         override def hashCode(): Int = {
           var result: Int = 31 + size
           val end = values.length
      -    var continue = true
           var k = 0
      -    while ((k < end) & continue) {
      -      val i = indices(k)
      -      if (i < 16) {
      -        val v = values(k)
      -        if (v != 0.0) {
      -          result = 31 * result + i
      -          val bits = java.lang.Double.doubleToLongBits(v)
      -          result = 31 * result + (bits ^ (bits >>> 32)).toInt
      -        }
      -      } else {
      -        continue = false
      +    var nnz = 0
      +    while (k < end && nnz < Vectors.MAX_HASH_NNZ) {
      +      val v = values(k)
      +      if (v != 0.0) {
      +        val i = indices(k)
      +        result = 31 * result + i
      +        val bits = java.lang.Double.doubleToLongBits(v)
      +        result = 31 * result + (bits ^ (bits >>> 32)).toInt
      +        nnz += 1
             }
             k += 1
           }
           result
         }
       
      +  @Since("1.4.0")
         override def numActives: Int = values.length
       
      +  @Since("1.4.0")
         override def numNonzeros: Int = {
           var nnz = 0
           values.foreach { v =>
      @@ -708,6 +744,7 @@ class SparseVector(
           nnz
         }
       
      +  @Since("1.4.0")
         override def toSparse: SparseVector = {
           val nnz = numNonzeros
           if (nnz == numActives) {
      @@ -726,9 +763,81 @@ class SparseVector(
             new SparseVector(size, ii, vv)
           }
         }
      +
      +  @Since("1.5.0")
      +  override def argmax: Int = {
      +    if (size == 0) {
      +      -1
      +    } else {
      +      // Find the max active entry.
      +      var maxIdx = indices(0)
      +      var maxValue = values(0)
      +      var maxJ = 0
      +      var j = 1
      +      val na = numActives
      +      while (j < na) {
      +        val v = values(j)
      +        if (v > maxValue) {
      +          maxValue = v
      +          maxIdx = indices(j)
      +          maxJ = j
      +        }
      +        j += 1
      +      }
      +
      +      // If the max active entry is nonpositive and inactive entries exist, find the first zero.
      +      if (maxValue <= 0.0 && na < size) {
      +        if (maxValue == 0.0) {
      +          // If there exists an inactive entry before maxIdx, find it and return its index.
      +          if (maxJ < maxIdx) {
      +            var k = 0
      +            while (k < maxJ && indices(k) == k) {
      +              k += 1
      +            }
      +            maxIdx = k
      +          }
      +        } else {
      +          // If the max active value is negative, find and return the first inactive index.
      +          var k = 0
      +          while (k < na && indices(k) == k) {
      +            k += 1
      +          }
      +          maxIdx = k
      +        }
      +      }
      +
      +      maxIdx
      +    }
      +  }
      +
      +  /**
      +   * Create a slice of this vector based on the given indices.
      +   * @param selectedIndices Unsorted list of indices into the vector.
      +   *                        This does NOT do bound checking.
      +   * @return  New SparseVector with values in the order specified by the given indices.
      +   *
      +   * NOTE: The API needs to be discussed before making this public.
      +   *       Also, if we have a version assuming indices are sorted, we should optimize it.
      +   */
      +  private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
      +    var currentIdx = 0
      +    val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
      +      val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
      +      val i_v = if (iIdx >= 0) {
      +        Iterator((currentIdx, this.values(iIdx)))
      +      } else {
      +        Iterator()
      +      }
      +      currentIdx += 1
      +      i_v
      +    }.unzip
      +    new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
      +  }
       }
       
      +@Since("1.3.0")
       object SparseVector {
      +  @Since("1.3.0")
         def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] =
           Some((sv.size, sv.indices, sv.values))
       }
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
      index 3323ae7b1fba..a33b6137cf9c 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
      @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
       import breeze.linalg.{DenseMatrix => BDM}
       
       import org.apache.spark.{Logging, Partitioner, SparkException}
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.storage.StorageLevel
      @@ -129,11 +129,12 @@ private[mllib] object GridPartitioner {
        * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
        *              zero, the number of columns will be calculated when `numCols` is invoked.
        */
      +@Since("1.3.0")
       @Experimental
      -class BlockMatrix(
      -    val blocks: RDD[((Int, Int), Matrix)],
      -    val rowsPerBlock: Int,
      -    val colsPerBlock: Int,
      +class BlockMatrix @Since("1.3.0") (
      +    @Since("1.3.0") val blocks: RDD[((Int, Int), Matrix)],
      +    @Since("1.3.0") val rowsPerBlock: Int,
      +    @Since("1.3.0") val colsPerBlock: Int,
           private var nRows: Long,
           private var nCols: Long) extends DistributedMatrix with Logging {
       
      @@ -150,6 +151,7 @@ class BlockMatrix(
          * @param colsPerBlock Number of columns that make up each block. The blocks forming the final
          *                     columns are not required to have the given number of columns
          */
      +  @Since("1.3.0")
         def this(
             blocks: RDD[((Int, Int), Matrix)],
             rowsPerBlock: Int,
      @@ -157,17 +159,21 @@ class BlockMatrix(
           this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L)
         }
       
      +  @Since("1.3.0")
         override def numRows(): Long = {
           if (nRows <= 0L) estimateDim()
           nRows
         }
       
      +  @Since("1.3.0")
         override def numCols(): Long = {
           if (nCols <= 0L) estimateDim()
           nCols
         }
       
      +  @Since("1.3.0")
         val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt
      +  @Since("1.3.0")
         val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt
       
         private[mllib] def createPartitioner(): GridPartitioner =
      @@ -193,6 +199,7 @@ class BlockMatrix(
          * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if
          * any error is found.
          */
      +  @Since("1.3.0")
         def validate(): Unit = {
           logDebug("Validating BlockMatrix...")
           // check if the matrix is larger than the claimed dimensions
      @@ -229,18 +236,21 @@ class BlockMatrix(
         }
       
         /** Caches the underlying RDD. */
      +  @Since("1.3.0")
         def cache(): this.type = {
           blocks.cache()
           this
         }
       
         /** Persists the underlying RDD with the specified storage level. */
      +  @Since("1.3.0")
         def persist(storageLevel: StorageLevel): this.type = {
           blocks.persist(storageLevel)
           this
         }
       
         /** Converts to CoordinateMatrix. */
      +  @Since("1.3.0")
         def toCoordinateMatrix(): CoordinateMatrix = {
           val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) =>
             val rowStart = blockRowIndex.toLong * rowsPerBlock
      @@ -255,6 +265,7 @@ class BlockMatrix(
         }
       
         /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */
      +  @Since("1.3.0")
         def toIndexedRowMatrix(): IndexedRowMatrix = {
           require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " +
             s"numCols: ${numCols()}")
      @@ -263,6 +274,7 @@ class BlockMatrix(
         }
       
         /** Collect the distributed matrix on the driver as a `DenseMatrix`. */
      +  @Since("1.3.0")
         def toLocalMatrix(): Matrix = {
           require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " +
             s"Int.MaxValue. Currently numRows: ${numRows()}")
      @@ -287,8 +299,11 @@ class BlockMatrix(
           new DenseMatrix(m, n, values)
         }
       
      -  /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the
      -    * same underlying data. Is a lazy operation. */
      +  /**
      +   * Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the
      +   * same underlying data. Is a lazy operation.
      +   */
      +  @Since("1.3.0")
         def transpose: BlockMatrix = {
           val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
             ((blockColIndex, blockRowIndex), mat.transpose)
      @@ -302,12 +317,14 @@ class BlockMatrix(
           new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray)
         }
       
      -  /** Adds two block matrices together. The matrices must have the same size and matching
      -    * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are
      -    * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even
      -    * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will
      -    * also be a [[DenseMatrix]].
      -    */
      +  /**
      +   * Adds two block matrices together. The matrices must have the same size and matching
      +   * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are
      +   * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even
      +   * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will
      +   * also be a [[DenseMatrix]].
      +   */
      +  @Since("1.3.0")
         def add(other: BlockMatrix): BlockMatrix = {
           require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " +
             s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}")
      @@ -335,12 +352,14 @@ class BlockMatrix(
           }
         }
       
      -  /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock`
      -    * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains
      -    * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output
      -    * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause
      -    * some performance issues until support for multiplying two sparse matrices is added.
      -    */
      +  /**
      +   * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock`
      +   * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains
      +   * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output
      +   * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause
      +   * some performance issues until support for multiplying two sparse matrices is added.
      +   */
      +  @Since("1.3.0")
         def multiply(other: BlockMatrix): BlockMatrix = {
           require(numCols() == other.numRows(), "The number of columns of A and the number of rows " +
             s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " +
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
      index 078d1fac4444..644f293d88a7 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
       
       import breeze.linalg.{DenseMatrix => BDM}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors}
       
      @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors}
        * @param j column index
        * @param value value of the entry
        */
      +@Since("1.0.0")
       @Experimental
       case class MatrixEntry(i: Long, j: Long, value: Double)
       
      @@ -43,16 +44,19 @@ case class MatrixEntry(i: Long, j: Long, value: Double)
        * @param nCols number of columns. A non-positive value means unknown, and then the number of
        *              columns will be determined by the max column index plus one.
        */
      +@Since("1.0.0")
       @Experimental
      -class CoordinateMatrix(
      -    val entries: RDD[MatrixEntry],
      +class CoordinateMatrix @Since("1.0.0") (
      +    @Since("1.0.0") val entries: RDD[MatrixEntry],
           private var nRows: Long,
           private var nCols: Long) extends DistributedMatrix {
       
         /** Alternative constructor leaving matrix dimensions to be determined automatically. */
      +  @Since("1.0.0")
         def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L)
       
         /** Gets or computes the number of columns. */
      +  @Since("1.0.0")
         override def numCols(): Long = {
           if (nCols <= 0L) {
             computeSize()
      @@ -61,6 +65,7 @@ class CoordinateMatrix(
         }
       
         /** Gets or computes the number of rows. */
      +  @Since("1.0.0")
         override def numRows(): Long = {
           if (nRows <= 0L) {
             computeSize()
      @@ -69,11 +74,13 @@ class CoordinateMatrix(
         }
       
         /** Transposes this CoordinateMatrix. */
      +  @Since("1.3.0")
         def transpose(): CoordinateMatrix = {
           new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows())
         }
       
         /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */
      +  @Since("1.0.0")
         def toIndexedRowMatrix(): IndexedRowMatrix = {
           val nl = numCols()
           if (nl > Int.MaxValue) {
      @@ -93,11 +100,13 @@ class CoordinateMatrix(
          * Converts to RowMatrix, dropping row indices after grouping by row index.
          * The number of columns must be within the integer range.
          */
      +  @Since("1.0.0")
         def toRowMatrix(): RowMatrix = {
           toIndexedRowMatrix().toRowMatrix()
         }
       
         /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */
      +  @Since("1.3.0")
         def toBlockMatrix(): BlockMatrix = {
           toBlockMatrix(1024, 1024)
         }
      @@ -110,6 +119,7 @@ class CoordinateMatrix(
          *                     a smaller value. Must be an integer value greater than 0.
          * @return a [[BlockMatrix]]
          */
      +  @Since("1.3.0")
         def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = {
           require(rowsPerBlock > 0,
             s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock")
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
      index a0e26ce3bc46..db3433a5e245 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala
      @@ -19,15 +19,20 @@ package org.apache.spark.mllib.linalg.distributed
       
       import breeze.linalg.{DenseMatrix => BDM}
       
      +import org.apache.spark.annotation.Since
      +
       /**
        * Represents a distributively stored matrix backed by one or more RDDs.
        */
      +@Since("1.0.0")
       trait DistributedMatrix extends Serializable {
       
         /** Gets or computes the number of rows. */
      +  @Since("1.0.0")
         def numRows(): Long
       
         /** Gets or computes the number of columns. */
      +  @Since("1.0.0")
         def numCols(): Long
       
         /** Collects data and assembles a local dense breeze matrix (for test only). */
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
      index 3be530fa0753..b20ea0dc50da 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
       
       import breeze.linalg.{DenseMatrix => BDM}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.rdd.RDD
       import org.apache.spark.mllib.linalg._
       import org.apache.spark.mllib.linalg.SingularValueDecomposition
      @@ -28,6 +28,7 @@ import org.apache.spark.mllib.linalg.SingularValueDecomposition
        * :: Experimental ::
        * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]].
        */
      +@Since("1.0.0")
       @Experimental
       case class IndexedRow(index: Long, vector: Vector)
       
      @@ -42,15 +43,18 @@ case class IndexedRow(index: Long, vector: Vector)
        * @param nCols number of columns. A non-positive value means unknown, and then the number of
        *              columns will be determined by the size of the first row.
        */
      +@Since("1.0.0")
       @Experimental
      -class IndexedRowMatrix(
      -    val rows: RDD[IndexedRow],
      +class IndexedRowMatrix @Since("1.0.0") (
      +    @Since("1.0.0") val rows: RDD[IndexedRow],
           private var nRows: Long,
           private var nCols: Int) extends DistributedMatrix {
       
         /** Alternative constructor leaving matrix dimensions to be determined automatically. */
      +  @Since("1.0.0")
         def this(rows: RDD[IndexedRow]) = this(rows, 0L, 0)
       
      +  @Since("1.0.0")
         override def numCols(): Long = {
           if (nCols <= 0) {
             // Calling `first` will throw an exception if `rows` is empty.
      @@ -59,6 +63,7 @@ class IndexedRowMatrix(
           nCols
         }
       
      +  @Since("1.0.0")
         override def numRows(): Long = {
           if (nRows <= 0L) {
             // Reduce will throw an exception if `rows` is empty.
      @@ -71,11 +76,13 @@ class IndexedRowMatrix(
          * Drops row indices and converts this matrix to a
          * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]].
          */
      +  @Since("1.0.0")
         def toRowMatrix(): RowMatrix = {
           new RowMatrix(rows.map(_.vector), 0L, nCols)
         }
       
         /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */
      +  @Since("1.3.0")
         def toBlockMatrix(): BlockMatrix = {
           toBlockMatrix(1024, 1024)
         }
      @@ -88,6 +95,7 @@ class IndexedRowMatrix(
          *                     a smaller value. Must be an integer value greater than 0.
          * @return a [[BlockMatrix]]
          */
      +  @Since("1.3.0")
         def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = {
           // TODO: This implementation may be optimized
           toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock)
      @@ -97,6 +105,7 @@ class IndexedRowMatrix(
          * Converts this matrix to a
          * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]].
          */
      +  @Since("1.3.0")
         def toCoordinateMatrix(): CoordinateMatrix = {
           val entries = rows.flatMap { row =>
             val rowIndex = row.index
      @@ -133,6 +142,7 @@ class IndexedRowMatrix(
          *              are treated as zero, where sigma(0) is the largest singular value.
          * @return SingularValueDecomposition(U, s, V)
          */
      +  @Since("1.0.0")
         def computeSVD(
             k: Int,
             computeU: Boolean = false,
      @@ -146,7 +156,7 @@ class IndexedRowMatrix(
             val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
               IndexedRow(i, v)
             }
      -      new IndexedRowMatrix(indexedRows, nRows, nCols)
      +      new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
           } else {
             null
           }
      @@ -159,6 +169,7 @@ class IndexedRowMatrix(
          * @param B a local matrix whose number of rows must match the number of columns of this matrix
          * @return an IndexedRowMatrix representing the product, which preserves partitioning
          */
      +  @Since("1.0.0")
         def multiply(B: Matrix): IndexedRowMatrix = {
           val mat = toRowMatrix().multiply(B)
           val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) =>
      @@ -170,6 +181,7 @@ class IndexedRowMatrix(
         /**
          * Computes the Gramian matrix `A^T A`.
          */
      +  @Since("1.0.0")
         def computeGramianMatrix(): Matrix = {
           toRowMatrix().computeGramianMatrix()
         }
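
      Beyond the annotations, the functional change in this file is that U's column count is now taken
      from the computed U rather than from nCols. A hedged usage sketch (not part of the patch; `sc` is
      an assumed SparkContext, and the SingularValueDecomposition fields are used as in the code above):

      ```scala
      // Illustrative sketch, not part of this patch; assumes an existing SparkContext `sc`.
      import org.apache.spark.mllib.linalg.Vectors
      import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

      val rows = sc.parallelize(Seq(
        IndexedRow(0L, Vectors.dense(1.0, 2.0, 3.0)),
        IndexedRow(2L, Vectors.dense(4.0, 5.0, 6.0)),
        IndexedRow(5L, Vectors.dense(7.0, 8.0, 9.0))))

      val mat = new IndexedRowMatrix(rows)

      // Top-2 singular values; with computeU = true, U comes back as an
      // IndexedRowMatrix whose column count is taken from U itself (the fix above).
      val svd = mat.computeSVD(k = 2, computeU = true)
      println(svd.s)            // the singular values
      println(svd.U.numCols())  // 2
      ```
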
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
      index 1626da9c3d2e..e55ef26858ad 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
      @@ -22,13 +22,12 @@ import java.util.Arrays
       import scala.collection.mutable.ListBuffer
       
       import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy,
      -  svd => brzSvd}
      +  svd => brzSvd, MatrixSingularException, inv}
       import breeze.numerics.{sqrt => brzSqrt}
      -import com.github.fommil.netlib.BLAS.{getInstance => blas}
       
       import org.apache.spark.Logging
       import org.apache.spark.SparkContext._
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg._
       import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
       import org.apache.spark.rdd.RDD
      @@ -45,16 +44,19 @@ import org.apache.spark.storage.StorageLevel
        * @param nCols number of columns. A non-positive value means unknown, and then the number of
        *              columns will be determined by the size of the first row.
        */
      +@Since("1.0.0")
       @Experimental
      -class RowMatrix(
      -    val rows: RDD[Vector],
      +class RowMatrix @Since("1.0.0") (
      +    @Since("1.0.0") val rows: RDD[Vector],
           private var nRows: Long,
           private var nCols: Int) extends DistributedMatrix with Logging {
       
         /** Alternative constructor leaving matrix dimensions to be determined automatically. */
      +  @Since("1.0.0")
         def this(rows: RDD[Vector]) = this(rows, 0L, 0)
       
         /** Gets or computes the number of columns. */
      +  @Since("1.0.0")
         override def numCols(): Long = {
           if (nCols <= 0) {
             try {
      @@ -70,6 +72,7 @@ class RowMatrix(
         }
       
         /** Gets or computes the number of rows. */
      +  @Since("1.0.0")
         override def numRows(): Long = {
           if (nRows <= 0L) {
             nRows = rows.count()
      @@ -108,6 +111,7 @@ class RowMatrix(
         /**
          * Computes the Gramian matrix `A^T A`.
          */
      +  @Since("1.0.0")
         def computeGramianMatrix(): Matrix = {
           val n = numCols().toInt
           checkNumColumns(n)
      @@ -118,7 +122,7 @@ class RowMatrix(
           // Compute the upper triangular part of the gram matrix.
           val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
             seqOp = (U, v) => {
      -        RowMatrix.dspr(1.0, v, U.data)
      +        BLAS.spr(1.0, v, U.data)
               U
             }, combOp = (U1, U2) => U1 += U2)
       
      @@ -178,6 +182,7 @@ class RowMatrix(
          *              are treated as zero, where sigma(0) is the largest singular value.
          * @return SingularValueDecomposition(U, s, V). U = null if computeU = false.
          */
      +  @Since("1.0.0")
         def computeSVD(
             k: Int,
             computeU: Boolean = false,
      @@ -318,6 +323,7 @@ class RowMatrix(
          * Computes the covariance matrix, treating each row as an observation.
          * @return a local dense matrix of size n x n
          */
      +  @Since("1.0.0")
         def computeCovariance(): Matrix = {
           val n = numCols().toInt
           checkNumColumns(n)
      @@ -371,6 +377,7 @@ class RowMatrix(
          * @param k number of top principal components.
          * @return a matrix of size n-by-k, whose columns are principal components
          */
      +  @Since("1.0.0")
         def computePrincipalComponents(k: Int): Matrix = {
           val n = numCols().toInt
           require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]")
      @@ -389,6 +396,7 @@ class RowMatrix(
         /**
          * Computes column-wise summary statistics.
          */
      +  @Since("1.0.0")
         def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
           val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
             (aggregator, data) => aggregator.add(data),
      @@ -404,6 +412,7 @@ class RowMatrix(
          * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product,
          *         which preserves partitioning
          */
      +  @Since("1.0.0")
         def multiply(B: Matrix): RowMatrix = {
           val n = numCols().toInt
           val k = B.numCols
      @@ -436,6 +445,7 @@ class RowMatrix(
          * @return An n x n sparse upper-triangular matrix of cosine similarities between
          *         columns of this matrix.
          */
      +  @Since("1.2.0")
         def columnSimilarities(): CoordinateMatrix = {
           columnSimilarities(0.0)
         }
      @@ -479,6 +489,7 @@ class RowMatrix(
          * @return An n x n sparse upper-triangular matrix of cosine similarities
          *         between columns of this matrix.
          */
      +  @Since("1.2.0")
         def columnSimilarities(threshold: Double): CoordinateMatrix = {
           require(threshold >= 0, s"Threshold cannot be negative: $threshold")
       
      @@ -497,6 +508,51 @@ class RowMatrix(
           columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
         }
       
      +  /**
      +   * Compute QR decomposition for [[RowMatrix]]. The implementation is designed to optimize the QR
      +   * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape.
      +   * Reference:
      +   *  Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce
      +   *  architectures"  ([[http://dx.doi.org/10.1145/1996092.1996103]])
      +   *
      +   * @param computeQ whether to compute Q
      +   * @return QRDecomposition(Q, R), Q = null if computeQ = false.
      +   */
      +  @Since("1.5.0")
      +  def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = {
      +    val col = numCols().toInt
      +    // split rows horizontally into smaller matrices, and compute QR for each of them
      +    val blockQRs = rows.glom().map { partRows =>
      +      val bdm = BDM.zeros[Double](partRows.length, col)
      +      var i = 0
      +      partRows.foreach { row =>
      +        bdm(i, ::) := row.toBreeze.t
      +        i += 1
      +      }
      +      breeze.linalg.qr.reduced(bdm).r
      +    }
      +
      +    // combine the R part from previous results vertically into a tall matrix
      +    val combinedR = blockQRs.treeReduce{ (r1, r2) =>
      +      val stackedR = BDM.vertcat(r1, r2)
      +      breeze.linalg.qr.reduced(stackedR).r
      +    }
      +    val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix)
      +    val finalQ = if (computeQ) {
      +      try {
      +        val invR = inv(combinedR)
      +        this.multiply(Matrices.fromBreeze(invR))
      +      } catch {
      +        case err: MatrixSingularException =>
      +          logWarning("R is not invertible and return Q as null")
      +          null
      +      }
      +    } else {
      +      null
      +    }
      +    QRDecomposition(finalQ, finalR)
      +  }
      +
         /**
          * Find all similar columns using the DIMSUM sampling algorithm, described in two papers
          *
      @@ -612,45 +668,10 @@ class RowMatrix(
         }
       }
       
      +@Since("1.0.0")
       @Experimental
       object RowMatrix {
       
      -  /**
      -   * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR.
      -   *
      -   * @param U the upper triangular part of the matrix packed in an array (column major)
      -   */
      -  private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
      -    // TODO: Find a better home (breeze?) for this method.
      -    val n = v.size
      -    v match {
      -      case DenseVector(values) =>
      -        blas.dspr("U", n, alpha, values, 1, U)
      -      case SparseVector(size, indices, values) =>
      -        val nnz = indices.length
      -        var colStartIdx = 0
      -        var prevCol = 0
      -        var col = 0
      -        var j = 0
      -        var i = 0
      -        var av = 0.0
      -        while (j < nnz) {
      -          col = indices(j)
      -          // Skip empty columns.
      -          colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2
      -          col = indices(j)
      -          av = alpha * values(j)
      -          i = 0
      -          while (i <= j) {
      -            U(colStartIdx + indices(i)) += av * values(i)
      -            i += 1
      -          }
      -          j += 1
      -          prevCol = col
      -        }
      -    }
      -  }
      -
         /**
          * Fills a full square matrix from its upper triangular part.
          */
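
      The new tallSkinnyQR is the main API addition in this file. A minimal sketch of calling it (not
      part of the patch; it assumes an existing SparkContext `sc` and that the QRDecomposition result
      exposes Q and R fields, as the constructor call above suggests):

      ```scala
      // Illustrative sketch, not part of this patch; assumes an existing SparkContext `sc`.
      import org.apache.spark.mllib.linalg.Vectors
      import org.apache.spark.mllib.linalg.distributed.RowMatrix

      // Four rows, two columns, spread over two partitions.
      val rows = sc.parallelize(Seq(
        Vectors.dense(3.0, -6.0),
        Vectors.dense(4.0, -8.0),
        Vectors.dense(0.0, 1.0),
        Vectors.dense(1.0, 1.0)), 2)

      val mat = new RowMatrix(rows)

      // Per-partition QR, then the R factors are reduced pairwise; Q is only
      // materialized when computeQ = true and is null if R turns out singular.
      val qr = mat.tallSkinnyQR(computeQ = true)
      println(qr.R)                        // 2 x 2 upper-triangular local Matrix
      qr.Q.rows.collect().foreach(println) // rows of Q as an RDD[Vector]
      ```
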
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
      index 06e45e10c5bf..3b663b5defb0 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
      @@ -19,25 +19,27 @@ package org.apache.spark.mllib.optimization
       
       import scala.collection.mutable.ArrayBuffer
       
      -import breeze.linalg.{DenseVector => BDV}
      +import breeze.linalg.{DenseVector => BDV, norm}
       
       import org.apache.spark.annotation.{Experimental, DeveloperApi}
       import org.apache.spark.Logging
       import org.apache.spark.rdd.RDD
       import org.apache.spark.mllib.linalg.{Vectors, Vector}
       
      +
       /**
        * Class used to solve an optimization problem using Gradient Descent.
        * @param gradient Gradient function to be used.
        * @param updater Updater to be used to update weights after every iteration.
        */
      -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater)
      +class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater)
         extends Optimizer with Logging {
       
         private var stepSize: Double = 1.0
         private var numIterations: Int = 100
         private var regParam: Double = 0.0
         private var miniBatchFraction: Double = 1.0
      +  private var convergenceTol: Double = 0.001
       
         /**
          * Set the initial step size of SGD for the first step. Default 1.0.
      @@ -75,6 +77,23 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
           this
         }
       
      +  /**
      +   * Set the convergence tolerance. Default 0.001.
      +   * convergenceTol is a condition that decides when to terminate the iterations.
      +   * Termination is decided based on the following logic:
      +   * - If the norm of the new solution vector is >1, the diff of the solution vectors
      +   *   is compared to a relative tolerance, i.e. normalized by the norm of
      +   *   the new solution vector.
      +   * - If the norm of the new solution vector is <=1, the diff of the solution vectors
      +   *   is compared to an absolute tolerance, i.e. without normalization.
      +   * Must be between 0.0 and 1.0 inclusive.
      +   */
      +  def setConvergenceTol(tolerance: Double): this.type = {
      +    require(0.0 <= tolerance && tolerance <= 1.0)
      +    this.convergenceTol = tolerance
      +    this
      +  }
      +
         /**
          * Set the gradient function (of the loss function of one single data example)
          * to be used for SGD.
      @@ -112,7 +131,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
             numIterations,
             regParam,
             miniBatchFraction,
      -      initialWeights)
      +      initialWeights,
      +      convergenceTol)
           weights
         }
       
      @@ -131,17 +151,20 @@ object GradientDescent extends Logging {
          * Sampling, and averaging the subgradients over this subset is performed using one standard
          * spark map-reduce in each iteration.
          *
      -   * @param data - Input data for SGD. RDD of the set of data examples, each of
      -   *               the form (label, [feature values]).
      -   * @param gradient - Gradient object (used to compute the gradient of the loss function of
      -   *                   one single data example)
      -   * @param updater - Updater function to actually perform a gradient step in a given direction.
      -   * @param stepSize - initial step size for the first step
      -   * @param numIterations - number of iterations that SGD should be run.
      -   * @param regParam - regularization parameter
      -   * @param miniBatchFraction - fraction of the input data set that should be used for
      -   *                            one iteration of SGD. Default value 1.0.
      -   *
      +   * @param data Input data for SGD. RDD of the set of data examples, each of
      +   *             the form (label, [feature values]).
      +   * @param gradient Gradient object (used to compute the gradient of the loss function of
      +   *                 one single data example)
      +   * @param updater Updater function to actually perform a gradient step in a given direction.
      +   * @param stepSize initial step size for the first step
      +   * @param numIterations number of iterations that SGD should be run.
      +   * @param regParam regularization parameter
      +   * @param miniBatchFraction fraction of the input data set that should be used for
      +   *                          one iteration of SGD. Default value 1.0.
      +   * @param convergenceTol Minibatch iteration will end before numIterations if the relative
      +   *                       difference between the current weights and the previous weights is
      +   *                       less than this value. Convergence is measured with the L2 norm.
      +   *                       Default value 0.001. Must be between 0.0 and 1.0 inclusive.
          * @return A tuple containing two elements. The first element is a column matrix containing
          *         weights for every feature, and the second element is an array containing the
          *         stochastic loss computed for every iteration.
      @@ -154,9 +177,20 @@ object GradientDescent extends Logging {
             numIterations: Int,
             regParam: Double,
             miniBatchFraction: Double,
      -      initialWeights: Vector): (Vector, Array[Double]) = {
      +      initialWeights: Vector,
      +      convergenceTol: Double): (Vector, Array[Double]) = {
      +
      +    // convergenceTol is meant for non-minibatch settings, i.e. miniBatchFraction = 1.0
      +    if (miniBatchFraction < 1.0 && convergenceTol > 0.0) {
      +      logWarning("Testing against a convergenceTol when using miniBatchFraction " +
      +        "< 1.0 can be unstable because of the stochasticity in sampling.")
      +    }
       
           val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
      +    // Record the previous and current weights to compute the solution vector difference
      +
      +    var previousWeights: Option[Vector] = None
      +    var currentWeights: Option[Vector] = None
       
           val numExamples = data.count()
       
      @@ -181,7 +215,9 @@ object GradientDescent extends Logging {
           var regVal = updater.compute(
             weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
       
      -    for (i <- 1 to numIterations) {
      +    var converged = false // indicates whether converged based on convergenceTol
      +    var i = 1
      +    while (!converged && i <= numIterations) {
             val bcWeights = data.context.broadcast(weights)
             // Sample a subset (fraction miniBatchFraction) of the total data
             // compute and sum up the subgradients on this subset (this is one map-reduce)
      @@ -199,17 +235,26 @@ object GradientDescent extends Logging {
       
             if (miniBatchSize > 0) {
               /**
      -         * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
      +         * lossSum is computed using the weights from the previous iteration
                * and regVal is the regularization value computed in the previous iteration as well.
                */
               stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
               val update = updater.compute(
      -          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
      +          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
      +          stepSize, i, regParam)
               weights = update._1
               regVal = update._2
      +
      +        previousWeights = currentWeights
      +        currentWeights = Some(weights)
      +        if (previousWeights != None && currentWeights != None) {
      +          converged = isConverged(previousWeights.get,
      +            currentWeights.get, convergenceTol)
      +        }
             } else {
               logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
             }
      +      i += 1
           }
       
           logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
      @@ -218,4 +263,35 @@ object GradientDescent extends Logging {
           (weights, stochasticLossHistory.toArray)
       
         }
      +
      +  /**
      +   * Alias of [[runMiniBatchSGD]] with convergenceTol set to default value of 0.001.
      +   */
      +  def runMiniBatchSGD(
      +      data: RDD[(Double, Vector)],
      +      gradient: Gradient,
      +      updater: Updater,
      +      stepSize: Double,
      +      numIterations: Int,
      +      regParam: Double,
      +      miniBatchFraction: Double,
      +      initialWeights: Vector): (Vector, Array[Double]) =
      +    GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations,
      +                                    regParam, miniBatchFraction, initialWeights, 0.001)
      +
      +
      +  private def isConverged(
      +      previousWeights: Vector,
      +      currentWeights: Vector,
      +      convergenceTol: Double): Boolean = {
      +    // To compare with convergence tolerance.
      +    val previousBDV = previousWeights.toBreeze.toDenseVector
      +    val currentBDV = currentWeights.toBreeze.toDenseVector
      +
      +    // This represents the difference of updated weights in the iteration.
      +    val solutionVecDiff: Double = norm(previousBDV - currentBDV)
      +
      +    solutionVecDiff < convergenceTol * Math.max(norm(currentBDV), 1.0)
      +  }
      +
       }
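
      To make the termination rule concrete: the loop now stops once the L2 norm of the weight update
      drops below convergenceTol, measured relative to the new weight norm when that norm exceeds 1 and
      absolutely otherwise. A standalone sketch of that test (not part of the patch), using breeze
      directly as isConverged does:

      ```scala
      // Minimal standalone sketch, not part of this patch, mirroring isConverged above.
      import breeze.linalg.{norm, DenseVector => BDV}

      def converged(previous: BDV[Double], current: BDV[Double], tol: Double): Boolean =
        // Relative tolerance when ||current|| > 1, absolute tolerance otherwise.
        norm(previous - current) < tol * math.max(norm(current), 1.0)

      val w0 = BDV(1.0, 2.0)
      val w1 = BDV(1.0005, 2.0005)
      println(converged(w0, w1, 0.001))  // true: the update is well under 0.1% of ||w1||
      ```

      Inside the optimizer the same threshold is configured via setConvergenceTol, which only accepts
      values in [0.0, 1.0], and a warning is logged when it is combined with miniBatchFraction < 1.0.
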
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
      index 5e882d4ebb10..274ac7c99553 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
      @@ -23,7 +23,7 @@ import javax.xml.transform.stream.StreamResult
       import org.jpmml.model.JAXBUtil
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
       
       /**
      @@ -33,6 +33,7 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
        * developed by the Data Mining Group (www.dmg.org).
        */
       @DeveloperApi
      +@Since("1.4.0")
       trait PMMLExportable {
       
         /**
      @@ -48,6 +49,7 @@ trait PMMLExportable {
          * Export the model to a local file in PMML format
          */
         @Experimental
      +  @Since("1.4.0")
         def toPMML(localPath: String): Unit = {
           toPMML(new StreamResult(new File(localPath)))
         }
      @@ -57,6 +59,7 @@ trait PMMLExportable {
          * Export the model to a directory on a distributed file system in PMML format
          */
         @Experimental
      +  @Since("1.4.0")
         def toPMML(sc: SparkContext, path: String): Unit = {
           val pmml = toPMML()
           sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
      @@ -67,6 +70,7 @@ trait PMMLExportable {
          * Export the model to the OutputStream in PMML format
          */
         @Experimental
      +  @Since("1.4.0")
         def toPMML(outputStream: OutputStream): Unit = {
           toPMML(new StreamResult(outputStream))
         }
      @@ -76,6 +80,7 @@ trait PMMLExportable {
          * Export the model to a String in PMML format
          */
         @Experimental
      +  @Since("1.4.0")
         def toPMML(): String = {
           val writer = new StringWriter
           toPMML(new StreamResult(writer))
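
      For reference, the four toPMML overloads annotated above are used along these lines (not part of
      the patch; `model` stands for any MLlib model that mixes in PMMLExportable, and `sc` for an
      existing SparkContext):

      ```scala
      // Illustrative usage sketch, not part of this patch.
      // `model` is assumed to be any model mixing in PMMLExportable; `sc` is a SparkContext.
      val pmmlString: String = model.toPMML()        // in-memory PMML document
      model.toPMML("/tmp/model.pmml")                // local file
      model.toPMML(sc, "hdfs:///models/model-pmml")  // distributed file system (single partition)
      model.toPMML(System.out)                       // any OutputStream
      ```
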
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
      index 9349ecaa13f5..9eab7efc160d 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
      @@ -17,10 +17,9 @@
       
       package org.apache.spark.mllib.random
       
      -import org.apache.commons.math3.distribution.{ExponentialDistribution,
      -  GammaDistribution, LogNormalDistribution, PoissonDistribution}
      +import org.apache.commons.math3.distribution._
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{Since, DeveloperApi}
       import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
       
       /**
      @@ -28,17 +27,20 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
        * Trait for random data generators that generate i.i.d. data.
        */
       @DeveloperApi
      +@Since("1.1.0")
       trait RandomDataGenerator[T] extends Pseudorandom with Serializable {
       
         /**
          * Returns an i.i.d. sample as a generic type from an underlying distribution.
          */
      +  @Since("1.1.0")
         def nextValue(): T
       
         /**
          * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the
          * class when applicable for non-locking concurrent usage.
          */
      +  @Since("1.1.0")
         def copy(): RandomDataGenerator[T]
       }
       
      @@ -47,17 +49,21 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable {
        * Generates i.i.d. samples from U[0.0, 1.0]
        */
       @DeveloperApi
      +@Since("1.1.0")
       class UniformGenerator extends RandomDataGenerator[Double] {
       
         // XORShiftRandom for better performance. Thread safety isn't necessary here.
         private val random = new XORShiftRandom()
       
      +  @Since("1.1.0")
         override def nextValue(): Double = {
           random.nextDouble()
         }
       
      +  @Since("1.1.0")
         override def setSeed(seed: Long): Unit = random.setSeed(seed)
       
      +  @Since("1.1.0")
         override def copy(): UniformGenerator = new UniformGenerator()
       }
       
      @@ -66,17 +72,21 @@ class UniformGenerator extends RandomDataGenerator[Double] {
        * Generates i.i.d. samples from the standard normal distribution.
        */
       @DeveloperApi
      +@Since("1.1.0")
       class StandardNormalGenerator extends RandomDataGenerator[Double] {
       
         // XORShiftRandom for better performance. Thread safety isn't necessary here.
         private val random = new XORShiftRandom()
       
      +  @Since("1.1.0")
         override def nextValue(): Double = {
             random.nextGaussian()
         }
       
      +  @Since("1.1.0")
         override def setSeed(seed: Long): Unit = random.setSeed(seed)
       
      +  @Since("1.1.0")
         override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
       }
       
      @@ -87,16 +97,21 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] {
        * @param mean mean for the Poisson distribution.
        */
       @DeveloperApi
      -class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] {
      +@Since("1.1.0")
      +class PoissonGenerator @Since("1.1.0") (
      +    @Since("1.1.0") val mean: Double) extends RandomDataGenerator[Double] {
       
         private val rng = new PoissonDistribution(mean)
       
      +  @Since("1.1.0")
         override def nextValue(): Double = rng.sample()
       
      +  @Since("1.1.0")
         override def setSeed(seed: Long) {
           rng.reseedRandomGenerator(seed)
         }
       
      +  @Since("1.1.0")
         override def copy(): PoissonGenerator = new PoissonGenerator(mean)
       }
       
      @@ -107,16 +122,21 @@ class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] {
        * @param mean mean for the exponential distribution.
        */
       @DeveloperApi
      -class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] {
      +@Since("1.3.0")
      +class ExponentialGenerator @Since("1.3.0") (
      +    @Since("1.3.0") val mean: Double) extends RandomDataGenerator[Double] {
       
         private val rng = new ExponentialDistribution(mean)
       
      +  @Since("1.3.0")
         override def nextValue(): Double = rng.sample()
       
      +  @Since("1.3.0")
         override def setSeed(seed: Long) {
           rng.reseedRandomGenerator(seed)
         }
       
      +  @Since("1.3.0")
         override def copy(): ExponentialGenerator = new ExponentialGenerator(mean)
       }
       
      @@ -128,16 +148,22 @@ class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double]
        * @param scale scale for the gamma distribution
        */
       @DeveloperApi
      -class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGenerator[Double] {
      +@Since("1.3.0")
      +class GammaGenerator @Since("1.3.0") (
      +    @Since("1.3.0") val shape: Double,
      +    @Since("1.3.0") val scale: Double) extends RandomDataGenerator[Double] {
       
         private val rng = new GammaDistribution(shape, scale)
       
      +  @Since("1.3.0")
         override def nextValue(): Double = rng.sample()
       
      +  @Since("1.3.0")
         override def setSeed(seed: Long) {
           rng.reseedRandomGenerator(seed)
         }
       
      +  @Since("1.3.0")
         override def copy(): GammaGenerator = new GammaGenerator(shape, scale)
       }
       
      @@ -150,15 +176,45 @@ class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGen
        * @param std standard deviation for the log normal distribution
        */
       @DeveloperApi
      -class LogNormalGenerator(val mean: Double, val std: Double) extends RandomDataGenerator[Double] {
      +@Since("1.3.0")
      +class LogNormalGenerator @Since("1.3.0") (
      +    @Since("1.3.0") val mean: Double,
      +    @Since("1.3.0") val std: Double) extends RandomDataGenerator[Double] {
       
         private val rng = new LogNormalDistribution(mean, std)
       
      +  @Since("1.3.0")
         override def nextValue(): Double = rng.sample()
       
      +  @Since("1.3.0")
         override def setSeed(seed: Long) {
           rng.reseedRandomGenerator(seed)
         }
       
      +  @Since("1.3.0")
         override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std)
       }
      +
      +/**
      + * :: DeveloperApi ::
      + * Generates i.i.d. samples from the Weibull distribution with the
      + * given shape and scale parameters.
      + *
      + * @param alpha shape parameter for the Weibull distribution.
      + * @param beta scale parameter for the Weibull distribution.
      + */
      +@DeveloperApi
      +class WeibullGenerator(
      +    val alpha: Double,
      +    val beta: Double) extends RandomDataGenerator[Double] {
      +
      +  private val rng = new WeibullDistribution(alpha, beta)
      +
      +  override def nextValue(): Double = rng.sample()
      +
      +  override def setSeed(seed: Long): Unit = {
      +    rng.reseedRandomGenerator(seed)
      +  }
      +
      +  override def copy(): WeibullGenerator = new WeibullGenerator(alpha, beta)
      +}
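
A minimal sketch of driving the WeibullGenerator added above; the shape/scale values, seed, and wrapper object are illustrative assumptions, not part of the patch.

import org.apache.spark.mllib.random.WeibullGenerator

object WeibullGeneratorSketch {
  def main(args: Array[String]): Unit = {
    // alpha (shape) = 2.0, beta (scale) = 1.5 -- arbitrary example parameters
    val gen = new WeibullGenerator(2.0, 1.5)
    gen.setSeed(42L)                         // reseed the backing RNG for reproducibility
    val samples = Array.fill(5)(gen.nextValue())
    println(samples.mkString(", "))
    // copy() yields a generator with the same parameters but independent RNG state,
    // which is what the RandomRDDs machinery relies on to give each partition its own generator.
    val perPartitionGen = gen.copy()
    perPartitionGen.setSeed(7L)
  }
}
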
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
      index 174d5e0f6c9f..4dd5ea214d67 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
      @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random
       import scala.reflect.ClassTag
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext}
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD}
      @@ -32,6 +32,7 @@ import org.apache.spark.util.Utils
        * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution.
        */
       @Experimental
      +@Since("1.1.0")
       object RandomRDDs {
       
         /**
      @@ -46,6 +47,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`.
          */
      +  @Since("1.1.0")
         def uniformRDD(
             sc: SparkContext,
             size: Long,
      @@ -58,6 +60,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#uniformRDD]].
          */
      +  @Since("1.1.0")
         def uniformJavaRDD(
             jsc: JavaSparkContext,
             size: Long,
      @@ -69,6 +72,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#uniformJavaRDD]] with the default seed.
          */
      +  @Since("1.1.0")
         def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions))
         }
      @@ -76,6 +80,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.1.0")
         def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size))
         }
      @@ -92,6 +97,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0).
          */
      +  @Since("1.1.0")
         def normalRDD(
             sc: SparkContext,
             size: Long,
      @@ -104,6 +110,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#normalRDD]].
          */
      +  @Since("1.1.0")
         def normalJavaRDD(
             jsc: JavaSparkContext,
             size: Long,
      @@ -115,6 +122,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#normalJavaRDD]] with the default seed.
          */
      +  @Since("1.1.0")
         def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions))
         }
      @@ -122,6 +130,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.1.0")
         def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size))
         }
      @@ -137,6 +146,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean).
          */
      +  @Since("1.1.0")
         def poissonRDD(
             sc: SparkContext,
             mean: Double,
      @@ -150,6 +160,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#poissonRDD]].
          */
      +  @Since("1.1.0")
         def poissonJavaRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -162,6 +173,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#poissonJavaRDD]] with the default seed.
          */
      +  @Since("1.1.0")
         def poissonJavaRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -173,6 +185,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.1.0")
         def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size))
         }
      @@ -188,6 +201,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Double] comprised of `i.i.d.` samples ~ Exp(mean).
          */
      +  @Since("1.3.0")
         def exponentialRDD(
             sc: SparkContext,
             mean: Double,
      @@ -201,6 +215,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#exponentialRDD]].
          */
      +  @Since("1.3.0")
         def exponentialJavaRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -213,6 +228,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#exponentialJavaRDD]] with the default seed.
          */
      +  @Since("1.3.0")
         def exponentialJavaRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -224,6 +240,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#exponentialJavaRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.3.0")
         def exponentialJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(exponentialRDD(jsc.sc, mean, size))
         }
      @@ -240,6 +257,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Double] comprised of `i.i.d.` samples ~ Gamma(shape, scale).
          */
      +  @Since("1.3.0")
         def gammaRDD(
             sc: SparkContext,
             shape: Double,
      @@ -254,6 +272,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#gammaRDD]].
          */
      +  @Since("1.3.0")
         def gammaJavaRDD(
             jsc: JavaSparkContext,
             shape: Double,
      @@ -267,6 +286,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#gammaJavaRDD]] with the default seed.
          */
      +  @Since("1.3.0")
         def gammaJavaRDD(
             jsc: JavaSparkContext,
             shape: Double,
      @@ -279,11 +299,12 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#gammaJavaRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.3.0")
         def gammaJavaRDD(
      -    jsc: JavaSparkContext,
      -    shape: Double,
      -    scale: Double,
      -    size: Long): JavaDoubleRDD = {
      +      jsc: JavaSparkContext,
      +      shape: Double,
      +      scale: Double,
      +      size: Long): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(gammaRDD(jsc.sc, shape, scale, size))
         }
       
      @@ -299,6 +320,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Double] comprised of `i.i.d.` samples ~ log N(mean, std).
          */
      +  @Since("1.3.0")
         def logNormalRDD(
             sc: SparkContext,
             mean: Double,
      @@ -313,6 +335,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#logNormalRDD]].
          */
      +  @Since("1.3.0")
         def logNormalJavaRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -326,6 +349,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#logNormalJavaRDD]] with the default seed.
          */
      +  @Since("1.3.0")
         def logNormalJavaRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -338,11 +362,12 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#logNormalJavaRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.3.0")
         def logNormalJavaRDD(
      -    jsc: JavaSparkContext,
      -    mean: Double,
      -    std: Double,
      -    size: Long): JavaDoubleRDD = {
      +      jsc: JavaSparkContext,
      +      mean: Double,
      +      std: Double,
      +      size: Long): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(logNormalRDD(jsc.sc, mean, std, size))
         }
       
      @@ -359,6 +384,7 @@ object RandomRDDs {
          * @return RDD[Double] comprised of `i.i.d.` samples produced by generator.
          */
         @DeveloperApi
      +  @Since("1.1.0")
         def randomRDD[T: ClassTag](
             sc: SparkContext,
             generator: RandomDataGenerator[T],
      @@ -381,6 +407,7 @@ object RandomRDDs {
          * @param seed Seed for the RNG that generates the seed for the generator in each partition.
          * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`.
          */
      +  @Since("1.1.0")
         def uniformVectorRDD(
             sc: SparkContext,
             numRows: Long,
      @@ -394,6 +421,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#uniformVectorRDD]].
          */
      +  @Since("1.1.0")
         def uniformJavaVectorRDD(
             jsc: JavaSparkContext,
             numRows: Long,
      @@ -406,6 +434,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed.
          */
      +  @Since("1.1.0")
         def uniformJavaVectorRDD(
             jsc: JavaSparkContext,
             numRows: Long,
      @@ -417,6 +446,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.1.0")
         def uniformJavaVectorRDD(
             jsc: JavaSparkContext,
             numRows: Long,
      @@ -435,6 +465,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`.
          */
      +  @Since("1.1.0")
         def normalVectorRDD(
             sc: SparkContext,
             numRows: Long,
      @@ -448,6 +479,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#normalVectorRDD]].
          */
      +  @Since("1.1.0")
         def normalJavaVectorRDD(
             jsc: JavaSparkContext,
             numRows: Long,
      @@ -460,6 +492,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#normalJavaVectorRDD]] with the default seed.
          */
      +  @Since("1.1.0")
         def normalJavaVectorRDD(
             jsc: JavaSparkContext,
             numRows: Long,
      @@ -471,6 +504,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.1.0")
         def normalJavaVectorRDD(
             jsc: JavaSparkContext,
             numRows: Long,
      @@ -491,6 +525,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Vector] with vectors containing `i.i.d.` samples.
          */
      +  @Since("1.3.0")
         def logNormalVectorRDD(
             sc: SparkContext,
             mean: Double,
      @@ -507,6 +542,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#logNormalVectorRDD]].
          */
      +  @Since("1.3.0")
         def logNormalJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -521,6 +557,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#logNormalJavaVectorRDD]] with the default seed.
          */
      +  @Since("1.3.0")
         def logNormalJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -535,6 +572,7 @@ object RandomRDDs {
          * [[RandomRDDs#logNormalJavaVectorRDD]] with the default number of partitions and
          * the default seed.
          */
      +  @Since("1.3.0")
         def logNormalJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -556,6 +594,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean).
          */
      +  @Since("1.1.0")
         def poissonVectorRDD(
             sc: SparkContext,
             mean: Double,
      @@ -570,6 +609,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#poissonVectorRDD]].
          */
      +  @Since("1.1.0")
         def poissonJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -583,6 +623,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed.
          */
      +  @Since("1.1.0")
         def poissonJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -595,6 +636,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.1.0")
         def poissonJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -615,6 +657,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean).
          */
      +  @Since("1.3.0")
         def exponentialVectorRDD(
             sc: SparkContext,
             mean: Double,
      @@ -630,6 +673,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#exponentialVectorRDD]].
          */
      +  @Since("1.3.0")
         def exponentialJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -643,6 +687,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#exponentialJavaVectorRDD]] with the default seed.
          */
      +  @Since("1.3.0")
         def exponentialJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -656,6 +701,7 @@ object RandomRDDs {
          * [[RandomRDDs#exponentialJavaVectorRDD]] with the default number of partitions
          * and the default seed.
          */
      +  @Since("1.3.0")
         def exponentialJavaVectorRDD(
             jsc: JavaSparkContext,
             mean: Double,
      @@ -678,6 +724,7 @@ object RandomRDDs {
          * @param seed Random seed (default: a random long integer).
          * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Gamma(shape, scale).
          */
      +  @Since("1.3.0")
         def gammaVectorRDD(
             sc: SparkContext,
             shape: Double,
      @@ -693,6 +740,7 @@ object RandomRDDs {
         /**
          * Java-friendly version of [[RandomRDDs#gammaVectorRDD]].
          */
      +  @Since("1.3.0")
         def gammaJavaVectorRDD(
             jsc: JavaSparkContext,
             shape: Double,
      @@ -707,6 +755,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#gammaJavaVectorRDD]] with the default seed.
          */
      +  @Since("1.3.0")
         def gammaJavaVectorRDD(
             jsc: JavaSparkContext,
             shape: Double,
      @@ -720,6 +769,7 @@ object RandomRDDs {
         /**
          * [[RandomRDDs#gammaJavaVectorRDD]] with the default number of partitions and the default seed.
          */
      +  @Since("1.3.0")
         def gammaJavaVectorRDD(
             jsc: JavaSparkContext,
             shape: Double,
      @@ -744,6 +794,7 @@ object RandomRDDs {
          * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator.
          */
         @DeveloperApi
      +  @Since("1.1.0")
         def randomVectorRDD(sc: SparkContext,
             generator: RandomDataGenerator[Double],
             numRows: Long,
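
A short sketch of the RandomRDDs entry points annotated above, generating RDDs of i.i.d. samples; the local master, sizes, and seeds are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.random.RandomRDDs._

object RandomRDDsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("RandomRDDsSketch").setMaster("local[2]"))

    // 1,000,000 i.i.d. draws ~ N(0, 1) over 4 partitions with a fixed seed.
    val normals = normalRDD(sc, 1000000L, numPartitions = 4, seed = 11L)
    // Shift and scale into N(2, 9) by transforming each draw.
    val shifted = normals.map(x => 2.0 + 3.0 * x)

    // 100,000 draws ~ Pois(5.0) with the default partitioning and a random seed.
    val counts = poissonRDD(sc, mean = 5.0, size = 100000L)

    println(s"mean of shifted normals: ${shifted.mean()}")
    println(s"mean of Poisson draws:   ${counts.mean()}")
    sc.stop()
  }
}
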
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
      index 910eff9540a4..f8cea7ecea6b 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
      @@ -35,11 +35,11 @@ private[mllib] class RandomRDDPartition[T](override val index: Int,
       }
       
       // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue
      -private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext,
      +private[mllib] class RandomRDD[T: ClassTag](sc: SparkContext,
           size: Long,
           numPartitions: Int,
      -    @transient rng: RandomDataGenerator[T],
      -    @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) {
      +    @transient private val rng: RandomDataGenerator[T],
      +    @transient private val seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) {
       
         require(size > 0, "Positive RDD size required.")
         require(numPartitions > 0, "Positive number of partitions required")
      @@ -56,12 +56,12 @@ private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext,
         }
       }
       
      -private[mllib] class RandomVectorRDD(@transient sc: SparkContext,
      +private[mllib] class RandomVectorRDD(sc: SparkContext,
           size: Long,
           vectorSize: Int,
           numPartitions: Int,
      -    @transient rng: RandomDataGenerator[Double],
      -    @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) {
      +    @transient private val rng: RandomDataGenerator[Double],
      +    @transient private val seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) {
       
         require(size > 0, "Positive RDD size required.")
         require(numPartitions > 0, "Positive number of partitions required")
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
      index 35e81fcb3de0..1facf83d806d 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
      @@ -72,7 +72,7 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int
             val w1 = windowSize - 1
             // Get the first w1 items of each partition, starting from the second partition.
             val nextHeads =
      -        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
      +        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n)
             val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
             var i = 0
             var partitionIndex = 0
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
      index 93290e650852..33aaf853e599 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
      @@ -18,7 +18,7 @@
       package org.apache.spark.mllib.recommendation
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.ml.recommendation.{ALS => NewALS}
       import org.apache.spark.rdd.RDD
      @@ -27,7 +27,11 @@ import org.apache.spark.storage.StorageLevel
       /**
        * A more compact class to represent a rating than Tuple3[Int, Int, Double].
        */
      -case class Rating(user: Int, product: Int, rating: Double)
      +@Since("0.8.0")
      +case class Rating @Since("0.8.0") (
      +    @Since("0.8.0") user: Int,
      +    @Since("0.8.0") product: Int,
      +    @Since("0.8.0") rating: Double)
       
       /**
        * Alternating Least Squares matrix factorization.
      @@ -58,6 +62,7 @@ case class Rating(user: Int, product: Int, rating: Double)
        * indicated user
        * preferences rather than explicit ratings given to items.
        */
      +@Since("0.8.0")
       class ALS private (
           private var numUserBlocks: Int,
           private var numProductBlocks: Int,
      @@ -73,6 +78,7 @@ class ALS private (
          * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10,
          * lambda: 0.01, implicitPrefs: false, alpha: 1.0}.
          */
      +  @Since("0.8.0")
         def this() = this(-1, -1, 10, 10, 0.01, false, 1.0)
       
         /** If true, do alternating nonnegative least squares. */
      @@ -89,6 +95,7 @@ class ALS private (
          * Set the number of blocks for both user blocks and product blocks to parallelize the computation
          * into; pass -1 for an auto-configured number of blocks. Default: -1.
          */
      +  @Since("0.8.0")
         def setBlocks(numBlocks: Int): this.type = {
           this.numUserBlocks = numBlocks
           this.numProductBlocks = numBlocks
      @@ -98,6 +105,7 @@ class ALS private (
         /**
          * Set the number of user blocks to parallelize the computation.
          */
      +  @Since("1.1.0")
         def setUserBlocks(numUserBlocks: Int): this.type = {
           this.numUserBlocks = numUserBlocks
           this
      @@ -106,30 +114,35 @@ class ALS private (
         /**
          * Set the number of product blocks to parallelize the computation.
          */
      +  @Since("1.1.0")
         def setProductBlocks(numProductBlocks: Int): this.type = {
           this.numProductBlocks = numProductBlocks
           this
         }
       
         /** Set the rank of the feature matrices computed (number of features). Default: 10. */
      +  @Since("0.8.0")
         def setRank(rank: Int): this.type = {
           this.rank = rank
           this
         }
       
         /** Set the number of iterations to run. Default: 10. */
      +  @Since("0.8.0")
         def setIterations(iterations: Int): this.type = {
           this.iterations = iterations
           this
         }
       
         /** Set the regularization parameter, lambda. Default: 0.01. */
      +  @Since("0.8.0")
         def setLambda(lambda: Double): this.type = {
           this.lambda = lambda
           this
         }
       
         /** Sets whether to use implicit preference. Default: false. */
      +  @Since("0.8.1")
         def setImplicitPrefs(implicitPrefs: Boolean): this.type = {
           this.implicitPrefs = implicitPrefs
           this
      @@ -138,12 +151,14 @@ class ALS private (
         /**
          * Sets the constant used in computing confidence in implicit ALS. Default: 1.0.
          */
      +  @Since("0.8.1")
         def setAlpha(alpha: Double): this.type = {
           this.alpha = alpha
           this
         }
       
         /** Sets a random seed to have deterministic results. */
      +  @Since("1.0.0")
         def setSeed(seed: Long): this.type = {
           this.seed = seed
           this
      @@ -153,6 +168,7 @@ class ALS private (
          * Set whether the least-squares problems solved at each iteration should have
          * nonnegativity constraints.
          */
      +  @Since("1.1.0")
         def setNonnegative(b: Boolean): this.type = {
           this.nonnegative = b
           this
      @@ -165,6 +181,7 @@ class ALS private (
          * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed.
          */
         @DeveloperApi
      +  @Since("1.1.0")
         def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = {
           require(storageLevel != StorageLevel.NONE,
             "ALS is not designed to run without persisting intermediate RDDs.")
      @@ -180,6 +197,7 @@ class ALS private (
          * at the cost of speed.
          */
         @DeveloperApi
      +  @Since("1.3.0")
         def setFinalRDDStorageLevel(storageLevel: StorageLevel): this.type = {
           this.finalRDDStorageLevel = storageLevel
           this
      @@ -193,6 +211,7 @@ class ALS private (
          * this setting is ignored.
          */
         @DeveloperApi
      +  @Since("1.4.0")
         def setCheckpointInterval(checkpointInterval: Int): this.type = {
           this.checkpointInterval = checkpointInterval
           this
      @@ -202,6 +221,7 @@ class ALS private (
          * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
          * Returns a MatrixFactorizationModel with feature vectors for each user and product.
          */
      +  @Since("0.8.0")
         def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
           val sc = ratings.context
       
      @@ -249,12 +269,14 @@ class ALS private (
         /**
          * Java-friendly version of [[ALS.run]].
          */
      +  @Since("1.3.0")
         def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd)
       }
       
       /**
        * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
        */
      +@Since("0.8.0")
       object ALS {
         /**
          * Train a matrix factorization model given an RDD of ratings given by users to some products,
      @@ -270,6 +292,7 @@ object ALS {
          * @param blocks     level of parallelism to split computation into
          * @param seed       random seed
          */
      +  @Since("0.9.1")
         def train(
             ratings: RDD[Rating],
             rank: Int,
      @@ -294,6 +317,7 @@ object ALS {
          * @param lambda     regularization factor (recommended: 0.01)
          * @param blocks     level of parallelism to split computation into
          */
      +  @Since("0.8.0")
         def train(
             ratings: RDD[Rating],
             rank: Int,
      @@ -316,6 +340,7 @@ object ALS {
          * @param iterations number of iterations of ALS (recommended: 10-20)
          * @param lambda     regularization factor (recommended: 0.01)
          */
      +  @Since("0.8.0")
         def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
           : MatrixFactorizationModel = {
           train(ratings, rank, iterations, lambda, -1)
      @@ -332,6 +357,7 @@ object ALS {
          * @param rank       number of features to use
          * @param iterations number of iterations of ALS (recommended: 10-20)
          */
      +  @Since("0.8.0")
         def train(ratings: RDD[Rating], rank: Int, iterations: Int)
           : MatrixFactorizationModel = {
           train(ratings, rank, iterations, 0.01, -1)
      @@ -352,6 +378,7 @@ object ALS {
          * @param alpha      confidence parameter
          * @param seed       random seed
          */
      +  @Since("0.8.1")
         def trainImplicit(
             ratings: RDD[Rating],
             rank: Int,
      @@ -378,6 +405,7 @@ object ALS {
          * @param blocks     level of parallelism to split computation into
          * @param alpha      confidence parameter
          */
      +  @Since("0.8.1")
         def trainImplicit(
             ratings: RDD[Rating],
             rank: Int,
      @@ -402,6 +430,7 @@ object ALS {
          * @param lambda     regularization factor (recommended: 0.01)
          * @param alpha      confidence parameter
          */
      +  @Since("0.8.1")
         def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
           : MatrixFactorizationModel = {
           trainImplicit(ratings, rank, iterations, lambda, -1, alpha)
      @@ -419,6 +448,7 @@ object ALS {
          * @param rank       number of features to use
          * @param iterations number of iterations of ALS (recommended: 10-20)
          */
      +  @Since("0.8.1")
         def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
           : MatrixFactorizationModel = {
           trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
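
A compact sketch of the explicit-feedback ALS.train overload documented above; the toy ratings, rank, iteration count, and regularization value are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

object ALSTrainSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("ALSTrainSketch").setMaster("local[2]"))

    // (user, product, rating) triples wrapped in the Rating case class.
    val ratings = sc.parallelize(Seq(
      Rating(1, 10, 4.0), Rating(1, 20, 1.0),
      Rating(2, 10, 5.0), Rating(2, 30, 3.0),
      Rating(3, 20, 2.0), Rating(3, 30, 4.0)))

    // rank = 5 latent features, 10 iterations, lambda = 0.01 regularization.
    val model = ALS.train(ratings, rank = 5, iterations = 10, lambda = 0.01)

    println(model.predict(1, 30))   // predicted rating for one (user, product) pair
    sc.stop()
  }
}
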
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
      index 93aa41e49961..46562eb2ad0f 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
      @@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger}
       
       import scala.collection.mutable
       
      +import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
       import com.github.fommil.netlib.BLAS.{getInstance => blas}
       import org.apache.hadoop.fs.Path
       import org.json4s._
      @@ -29,6 +30,7 @@ import org.json4s.JsonDSL._
       import org.json4s.jackson.JsonMethods._
       
       import org.apache.spark.{Logging, SparkContext}
      +import org.apache.spark.annotation.Since
       import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
       import org.apache.spark.mllib.linalg._
       import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
      @@ -49,10 +51,11 @@ import org.apache.spark.storage.StorageLevel
        * @param productFeatures RDD of tuples where each tuple represents the productId
        *                        and the features computed for this product.
        */
      -class MatrixFactorizationModel(
      -    val rank: Int,
      -    val userFeatures: RDD[(Int, Array[Double])],
      -    val productFeatures: RDD[(Int, Array[Double])])
      +@Since("0.8.0")
      +class MatrixFactorizationModel @Since("0.8.0") (
      +    @Since("0.8.0") val rank: Int,
      +    @Since("0.8.0") val userFeatures: RDD[(Int, Array[Double])],
      +    @Since("0.8.0") val productFeatures: RDD[(Int, Array[Double])])
         extends Saveable with Serializable with Logging {
       
         require(rank > 0)
      @@ -73,12 +76,37 @@ class MatrixFactorizationModel(
         }
       
         /** Predict the rating of one user for one product. */
      +  @Since("0.8.0")
         def predict(user: Int, product: Int): Double = {
           val userVector = userFeatures.lookup(user).head
           val productVector = productFeatures.lookup(product).head
           blas.ddot(rank, userVector, 1, productVector, 1)
         }
       
      +  /**
      +   * Returns approximate numbers of distinct users and products in the given usersProducts pairs.
      +   * This method is based on `countApproxDistinct` in class `RDD`.
      +   *
      +   * @param usersProducts  RDD of (user, product) pairs.
      +   * @return approximate numbers of users and products.
      +   */
      +  private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = {
      +    val zeroCounterUser = new HyperLogLogPlus(4, 0)
      +    val zeroCounterProduct = new HyperLogLogPlus(4, 0)
      +    val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))(
      +      (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => {
      +        hllTuple._1.offer(v._1)
      +        hllTuple._2.offer(v._2)
      +        hllTuple
      +      },
      +      (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => {
      +        h1._1.addAll(h2._1)
      +        h1._2.addAll(h2._2)
      +        h1
      +      })
      +    (aggregated._1.cardinality(), aggregated._2.cardinality())
      +  }
      +
         /**
          * Predict the rating of many users for many products.
          * The output RDD has an element per each element in the input RDD (including all duplicates)
      @@ -87,19 +115,39 @@ class MatrixFactorizationModel(
          * @param usersProducts  RDD of (user, product) pairs.
          * @return RDD of Ratings.
          */
      +  @Since("0.9.0")
         def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
      -    val users = userFeatures.join(usersProducts).map {
      -      case (user, (uFeatures, product)) => (product, (user, uFeatures))
      -    }
      -    users.join(productFeatures).map {
      -      case (product, ((user, uFeatures), pFeatures)) =>
      -        Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
      +    // Previously, the partitioning of the generated ratings was based only on the given
      +    // products. So if the usersProducts passed in for prediction contained only a few
      +    // products, or even a single product, the ratings were pushed into one or a few
      +    // partitions and could not exploit much parallelism.
      +    // Here we compute approximate distinct counts of users and products, and then decide
      +    // whether the partitioning should be driven by users or by products.
      +    val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts)
      +
      +    if (usersCount < productsCount) {
      +      val users = userFeatures.join(usersProducts).map {
      +        case (user, (uFeatures, product)) => (product, (user, uFeatures))
      +      }
      +      users.join(productFeatures).map {
      +        case (product, ((user, uFeatures), pFeatures)) =>
      +          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
      +      }
      +    } else {
      +      val products = productFeatures.join(usersProducts.map(_.swap)).map {
      +        case (product, (pFeatures, user)) => (user, (product, pFeatures))
      +      }
      +      products.join(userFeatures).map {
      +        case (user, ((product, pFeatures), uFeatures)) =>
      +          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
      +      }
           }
         }
       
         /**
          * Java-friendly version of [[MatrixFactorizationModel.predict]].
          */
      +  @Since("1.2.0")
         def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
           predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
         }
      @@ -115,6 +163,7 @@ class MatrixFactorizationModel(
          *  recommended to the user. The score is an opaque value that indicates how strongly
          *  recommended the product is.
          */
      +  @Since("1.1.0")
         def recommendProducts(user: Int, num: Int): Array[Rating] =
           MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num)
             .map(t => Rating(user, t._1, t._2))
      @@ -131,12 +180,27 @@ class MatrixFactorizationModel(
          *  recommended to the product. The score is an opaque value that indicates how strongly
          *  recommended the user is.
          */
      +  @Since("1.1.0")
         def recommendUsers(product: Int, num: Int): Array[Rating] =
           MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num)
             .map(t => Rating(t._1, product, t._2))
       
         protected override val formatVersion: String = "1.0"
       
      +  /**
      +   * Save this model to the given path.
      +   *
      +   * This saves:
      +   *  - human-readable (JSON) model metadata to path/metadata/
      +   *  - Parquet formatted data to path/data/
      +   *
      +   * The model may be loaded using [[Loader.load]].
      +   *
      +   * @param sc  Spark context used to save model data.
      +   * @param path  Path specifying the directory in which to save this model.
      +   *              If the directory already exists, this method throws an exception.
      +   */
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
         }
      @@ -149,6 +213,7 @@ class MatrixFactorizationModel(
          * rating objects which contains the same userId, recommended productID and a "score" in the
          * rating field. Semantics of score is same as recommendProducts API
          */
      +  @Since("1.4.0")
         def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = {
           MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map {
             case (user, top) =>
      @@ -166,6 +231,7 @@ class MatrixFactorizationModel(
          * of rating objects which contains the recommended userId, same productID and a "score" in the
          * rating field. Semantics of score is same as recommendUsers API
          */
      +  @Since("1.4.0")
         def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = {
           MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map {
             case (product, top) =>
      @@ -175,6 +241,7 @@ class MatrixFactorizationModel(
         }
       }
       
      +@Since("1.3.0")
       object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
       
         import org.apache.spark.mllib.util.Loader._
      @@ -249,6 +316,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
           }
         }
       
      +  /**
      +   * Load a model from the given path.
      +   *
      +   * The model should have been saved by [[Saveable.save]].
      +   *
      +   * @param sc  Spark context used for loading model files.
      +   * @param path  Path specifying the directory to which the model was saved.
      +   * @return  Model instance
      +   */
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
           val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
           val classNameV1_0 = SaveLoadV1_0.thisClassName
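
A sketch of batch prediction against a MatrixFactorizationModel. The revised predict(RDD[(Int, Int)]) picks its join order from the approximate distinct counts computed above, but callers pass (user, product) pairs exactly as before; the model training and toy pairs here are assumptions carried over from the ALS sketch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

object BatchPredictSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("BatchPredictSketch").setMaster("local[2]"))
    val ratings = sc.parallelize(Seq(
      Rating(1, 10, 4.0), Rating(2, 10, 5.0), Rating(2, 30, 3.0), Rating(3, 30, 4.0)))
    val model = ALS.train(ratings, 5, 10, 0.01)

    // Many distinct users but a single product: the revised predict() keys its final
    // join by user, so the work is no longer squeezed onto one product key.
    val usersProducts = sc.parallelize(Seq((1, 30), (2, 30), (3, 30)))
    model.predict(usersProducts).collect().foreach(println)

    // Top-2 product recommendations for every user.
    model.recommendProductsForUsers(2).collect().foreach { case (user, recs) =>
      println(s"user $user -> " + recs.map(r => s"${r.product}:${r.rating}").mkString(", "))
    }
    sc.stop()
  }
}
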
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
      index 6709bd79bc82..8f657bfb9c73 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.regression
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.feature.StandardScaler
       import org.apache.spark.{Logging, SparkException}
       import org.apache.spark.rdd.RDD
      @@ -34,9 +34,13 @@ import org.apache.spark.storage.StorageLevel
        *
        * @param weights Weights computed for every feature.
        * @param intercept Intercept computed for this model.
      + *
        */
      +@Since("0.8.0")
       @DeveloperApi
      -abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double)
      +abstract class GeneralizedLinearModel @Since("1.0.0") (
      +    @Since("1.0.0") val weights: Vector,
      +    @Since("0.8.0") val intercept: Double)
         extends Serializable {
       
         /**
      @@ -53,7 +57,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
          *
          * @param testData RDD representing data points to be predicted
          * @return RDD[Double] where each entry contains the corresponding prediction
      +   *
          */
      +  @Since("1.0.0")
         def predict(testData: RDD[Vector]): RDD[Double] = {
           // A small optimization to avoid serializing the entire model. Only the weightsMatrix
           // and intercept is needed.
      @@ -71,7 +77,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
          *
          * @param testData array representing a single data point
          * @return Double prediction from the trained model
      +   *
          */
      +  @Since("1.0.0")
         def predict(testData: Vector): Double = {
           predictPoint(testData, weights, intercept)
         }
      @@ -88,14 +96,20 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
        * :: DeveloperApi ::
        * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM).
        * This class should be extended with an Optimizer to create a new GLM.
      + *
        */
      +@Since("0.8.0")
       @DeveloperApi
       abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
         extends Logging with Serializable {
       
         protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List()
       
      -  /** The optimizer to solve the problem. */
      +  /**
      +   * The optimizer to solve the problem.
      +   *
      +   */
      +  @Since("0.8.0")
         def optimizer: Optimizer
       
         /** Whether to add intercept (default: false). */
      @@ -130,7 +144,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       
         /**
          * The dimension of training features.
      +   *
          */
      +  @Since("1.4.0")
         def getNumFeatures: Int = this.numFeatures
       
         /**
      @@ -153,13 +169,17 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       
         /**
          * Get if the algorithm uses addIntercept
      +   *
          */
      +  @Since("1.4.0")
         def isAddIntercept: Boolean = this.addIntercept
       
         /**
          * Set if the algorithm should add an intercept. Default false.
          * We set the default to false because adding the intercept will cause memory allocation.
      +   *
          */
      +  @Since("0.8.0")
         def setIntercept(addIntercept: Boolean): this.type = {
           this.addIntercept = addIntercept
           this
      @@ -167,7 +187,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       
         /**
          * Set if the algorithm should validate data before training. Default true.
      +   *
          */
      +  @Since("0.8.0")
         def setValidateData(validateData: Boolean): this.type = {
           this.validateData = validateData
           this
      @@ -176,7 +198,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
         /**
          * Run the algorithm with the configured parameters on an input
          * RDD of LabeledPoint entries.
      +   *
          */
      +  @Since("0.8.0")
         def run(input: RDD[LabeledPoint]): M = {
           if (numFeatures < 0) {
             numFeatures = input.map(_.features.size).first()
      @@ -208,7 +232,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
         /**
          * Run the algorithm with the configured parameters on an input RDD
          * of LabeledPoint entries starting from the initial weights provided.
      +   *
          */
      +  @Since("1.0.0")
         def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
       
           if (numFeatures < 0) {
      @@ -333,6 +359,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
               + " parent RDDs are also uncached.")
           }
       
      +    // Unpersist cached data
      +    if (data.getStorageLevel != StorageLevel.NONE) {
      +      data.unpersist(false)
      +    }
      +
           createModel(weights, intercept)
         }
       }
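
The hunk above makes GeneralizedLinearAlgorithm.run release its internally cached training data once the model is built. Below is a minimal sketch of the same cache-then-unpersist pattern from calling code, using LinearRegressionWithSGD as a stand-in for any GLM subclass; the toy points and iteration count are assumptions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

object GlmCachingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GlmCachingSketch").setMaster("local[2]"))
    val points = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(0.5, 1.0)),
      LabeledPoint(2.0, Vectors.dense(1.0, 2.1)),
      LabeledPoint(3.0, Vectors.dense(1.5, 3.0))))

    val data = points.persist(StorageLevel.MEMORY_ONLY)   // cache for the iterative SGD passes
    val model = LinearRegressionWithSGD.train(data, numIterations = 50)

    if (data.getStorageLevel != StorageLevel.NONE) {
      data.unpersist(blocking = false)                     // mirrors the cleanup added in run()
    }
    println(model.weights)
    sc.stop()
  }
}
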
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
      index f3b46c75c05f..877d31ba4130 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
      @@ -29,7 +29,7 @@ import org.json4s.JsonDSL._
       import org.json4s.jackson.JsonMethods._
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
       import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.util.{Loader, Saveable}
      @@ -46,12 +46,14 @@ import org.apache.spark.sql.SQLContext
        * @param predictions Array of predictions associated to the boundaries at the same index.
        *                    Results of isotonic regression and therefore monotone.
        * @param isotonic indicates whether this is isotonic or antitonic.
      + *
        */
      +@Since("1.3.0")
       @Experimental
      -class IsotonicRegressionModel (
      -    val boundaries: Array[Double],
      -    val predictions: Array[Double],
      -    val isotonic: Boolean) extends Serializable with Saveable {
      +class IsotonicRegressionModel @Since("1.3.0") (
      +    @Since("1.3.0") val boundaries: Array[Double],
      +    @Since("1.3.0") val predictions: Array[Double],
      +    @Since("1.3.0") val isotonic: Boolean) extends Serializable with Saveable {
       
         private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse
       
      @@ -59,7 +61,10 @@ class IsotonicRegressionModel (
         assertOrdered(boundaries)
         assertOrdered(predictions)(predictionOrd)
       
      -  /** A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. */
      +  /**
      +   * A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter.
      +   */
      +  @Since("1.4.0")
         def this(boundaries: java.lang.Iterable[Double],
             predictions: java.lang.Iterable[Double],
             isotonic: java.lang.Boolean) = {
      @@ -83,7 +88,9 @@ class IsotonicRegressionModel (
          *
          * @param testData Features to be labeled.
          * @return Predicted labels.
      +   *
          */
      +  @Since("1.3.0")
         def predict(testData: RDD[Double]): RDD[Double] = {
           testData.map(predict)
         }
      @@ -94,7 +101,9 @@ class IsotonicRegressionModel (
          *
          * @param testData Features to be labeled.
          * @return Predicted labels.
      +   *
          */
      +  @Since("1.3.0")
         def predict(testData: JavaDoubleRDD): JavaDoubleRDD = {
           JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]]))
         }
      @@ -114,7 +123,9 @@ class IsotonicRegressionModel (
          *         3) If testData falls between two values in boundary array then prediction is treated
          *           as piecewise linear function and interpolated value is returned. In case there are
          *           multiple values with the same boundary then the same rules as in 2) are used.
      +   *
          */
      +  @Since("1.3.0")
         def predict(testData: Double): Double = {
       
           def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = {
      @@ -148,6 +159,7 @@ class IsotonicRegressionModel (
         /** A convenient method for boundaries called by the Python API. */
         private[mllib] def predictionVector: Vector = Vectors.dense(predictions)
       
      +  @Since("1.4.0")
         override def save(sc: SparkContext, path: String): Unit = {
           IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
         }
      @@ -155,6 +167,7 @@ class IsotonicRegressionModel (
         override protected def formatVersion: String = "1.0"
       }
       
      +@Since("1.4.0")
       object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
       
         import org.apache.spark.mllib.util.Loader._
      @@ -200,6 +213,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
           }
         }
       
      +  @Since("1.4.0")
         override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
           implicit val formats = DefaultFormats
           val (loadedClassName, version, metadata) = loadMetadata(sc, path)
      @@ -239,6 +253,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
        * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]]
        */
       @Experimental
      +@Since("1.3.0")
       class IsotonicRegression private (private var isotonic: Boolean) extends Serializable {
       
         /**
      @@ -246,6 +261,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
          *
          * @return New instance of IsotonicRegression.
          */
      +  @Since("1.3.0")
         def this() = this(true)
       
         /**
      @@ -254,6 +270,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
          * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence.
          * @return This instance of IsotonicRegression.
          */
      +  @Since("1.3.0")
         def setIsotonic(isotonic: Boolean): this.type = {
           this.isotonic = isotonic
           this
      @@ -269,6 +286,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
          *              the algorithm is executed.
          * @return Isotonic regression model.
          */
      +  @Since("1.3.0")
         def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = {
           val preprocessedInput = if (isotonic) {
             input
      @@ -294,6 +312,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
          *              the algorithm is executed.
          * @return Isotonic regression model.
          */
      +  @Since("1.3.0")
         def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = {
           run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]])
         }
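
A brief sketch of the IsotonicRegression API whose methods are annotated above, fitting on (label, feature, weight) triples and exercising the documented prediction rules; the toy data is an assumption.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.IsotonicRegression

object IsotonicSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("IsotonicSketch").setMaster("local[2]"))

    // (label, feature, weight) triples with a roughly increasing trend.
    val data = sc.parallelize(Seq(
      (1.0, 1.0, 1.0), (2.0, 2.0, 1.0), (1.5, 3.0, 1.0), (4.0, 4.0, 1.0)))

    val model = new IsotonicRegression().setIsotonic(true).run(data)

    println(model.predict(2.0))   // feature equal to a boundary: its fitted prediction
    println(model.predict(2.5))   // between two boundaries: piecewise-linear interpolation
    sc.stop()
  }
}
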
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
      index d5fea822ad77..c284ad232537 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
      @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression
       
       import scala.beans.BeanInfo
       
      +import org.apache.spark.annotation.Since
       import org.apache.spark.mllib.linalg.{Vectors, Vector}
       import org.apache.spark.mllib.util.NumericParser
       import org.apache.spark.SparkException
      @@ -29,8 +30,11 @@ import org.apache.spark.SparkException
        * @param label Label for this data point.
        * @param features List of features for this data point.
        */
      +@Since("0.8.0")
       @BeanInfo
      -case class LabeledPoint(label: Double, features: Vector) {
      +case class LabeledPoint @Since("1.0.0") (
      +    @Since("0.8.0") label: Double,
      +    @Since("1.0.0") features: Vector) {
         override def toString: String = {
           s"($label,$features)"
         }
      @@ -38,12 +42,16 @@ case class LabeledPoint(label: Double, features: Vector) {
       
       /**
        * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]].
      + *
        */
      +@Since("1.1.0")
       object LabeledPoint {
         /**
          * Parses a string resulted from `LabeledPoint#toString` into
          * an [[org.apache.spark.mllib.regression.LabeledPoint]].
      +   *
          */
      +  @Since("1.1.0")
         def parse(s: String): LabeledPoint = {
           if (s.startsWith("(")) {
             NumericParser.parse(s) match {
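
A tiny sketch of the LabeledPoint.parse round trip that the @Since("1.1.0")-annotated method above supports; the feature values are arbitrary.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object LabeledPointParseSketch {
  def main(args: Array[String]): Unit = {
    val p = LabeledPoint(1.0, Vectors.dense(0.5, -2.0, 3.0))
    val s = p.toString                  // "(1.0,[0.5,-2.0,3.0])"
    val parsed = LabeledPoint.parse(s)  // parses the "(label,[f1,f2,...])" form
    assert(parsed == p)
    println(parsed)
  }
}
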
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
      index 4f482384f0f3..a9aba173fa0e 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
      @@ -18,6 +18,7 @@
       package org.apache.spark.mllib.regression
       
       import org.apache.spark.SparkContext
      +import org.apache.spark.annotation.Since
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.optimization._
       import org.apache.spark.mllib.pmml.PMMLExportable
      @@ -30,10 +31,12 @@ import org.apache.spark.rdd.RDD
        *
        * @param weights Weights computed for every feature.
        * @param intercept Intercept computed for this model.
      + *
        */
      -class LassoModel (
      -    override val weights: Vector,
      -    override val intercept: Double)
      +@Since("0.8.0")
      +class LassoModel @Since("1.1.0") (
      +    @Since("1.0.0") override val weights: Vector,
      +    @Since("0.8.0") override val intercept: Double)
         extends GeneralizedLinearModel(weights, intercept)
         with RegressionModel with Serializable with Saveable with PMMLExportable {
       
      @@ -44,6 +47,7 @@ class LassoModel (
           weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
         }
       
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
         }
      @@ -51,8 +55,10 @@ class LassoModel (
         override protected def formatVersion: String = "1.0"
       }
       
      +@Since("1.3.0")
       object LassoModel extends Loader[LassoModel] {
       
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): LassoModel = {
           val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
           // Hard-code class name string in case it changes in the future
      @@ -78,6 +84,7 @@ object LassoModel extends Loader[LassoModel] {
        * its corresponding right hand side label y.
        * See also the documentation for the precise formulation.
        */
      +@Since("0.8.0")
       class LassoWithSGD private (
           private var stepSize: Double,
           private var numIterations: Int,
      @@ -87,6 +94,7 @@ class LassoWithSGD private (
       
         private val gradient = new LeastSquaresGradient()
         private val updater = new L1Updater()
      +  @Since("0.8.0")
         override val optimizer = new GradientDescent(gradient, updater)
           .setStepSize(stepSize)
           .setNumIterations(numIterations)
      @@ -97,6 +105,7 @@ class LassoWithSGD private (
          * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100,
          * regParam: 0.01, miniBatchFraction: 1.0}.
          */
      +  @Since("0.8.0")
         def this() = this(1.0, 100, 0.01, 1.0)
       
         override protected def createModel(weights: Vector, intercept: Double) = {
      @@ -106,7 +115,9 @@ class LassoWithSGD private (
       
       /**
        * Top-level methods for calling Lasso.
      + *
        */
      +@Since("0.8.0")
       object LassoWithSGD {
       
         /**
      @@ -123,7 +134,9 @@ object LassoWithSGD {
          * @param miniBatchFraction Fraction of data to be used per iteration.
          * @param initialWeights Initial set of weights to be used. Array should be equal in size to
          *        the number of features in the data.
      +   *
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -146,7 +159,9 @@ object LassoWithSGD {
          * @param stepSize Step size to be used for each iteration of gradient descent.
          * @param regParam Regularization parameter.
          * @param miniBatchFraction Fraction of data to be used per iteration.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -167,7 +182,9 @@ object LassoWithSGD {
          * @param regParam Regularization parameter.
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a LassoModel which has the weights and offset from training.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -185,7 +202,9 @@ object LassoWithSGD {
          *              matrix A as well as the corresponding right hand side label y
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a LassoModel which has the weights and offset from training.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int): LassoModel = {
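
As a rough usage sketch of the API annotated above (assuming an existing `SparkContext` named `sc`, toy data, and an illustrative save path), training, saving, and reloading a Lasso model might look like:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LassoModel, LassoWithSGD}

// Toy training data: a label plus a feature vector per point.
val points = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(1.0, 0.1)),
  LabeledPoint(0.0, Vectors.dense(0.2, 1.0))))

// Train with 100 iterations of SGD via the simplest overload.
val model: LassoModel = LassoWithSGD.train(points, 100)

// Persist and reload through the Loader companion (the path is illustrative).
model.save(sc, "/tmp/lassoModel")
val reloaded = LassoModel.load(sc, "/tmp/lassoModel")
```
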
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
      index 9453c4f66c21..4996ace5df85 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
      @@ -18,6 +18,7 @@
       package org.apache.spark.mllib.regression
       
       import org.apache.spark.SparkContext
      +import org.apache.spark.annotation.Since
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.optimization._
       import org.apache.spark.mllib.pmml.PMMLExportable
      @@ -30,10 +31,12 @@ import org.apache.spark.rdd.RDD
        *
        * @param weights Weights computed for every feature.
        * @param intercept Intercept computed for this model.
      + *
        */
      -class LinearRegressionModel (
      -    override val weights: Vector,
      -    override val intercept: Double)
      +@Since("0.8.0")
      +class LinearRegressionModel @Since("1.1.0") (
      +    @Since("1.0.0") override val weights: Vector,
      +    @Since("0.8.0") override val intercept: Double)
         extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable
         with Saveable with PMMLExportable {
       
      @@ -44,6 +47,7 @@ class LinearRegressionModel (
           weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
         }
       
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
         }
      @@ -51,8 +55,10 @@ class LinearRegressionModel (
         override protected def formatVersion: String = "1.0"
       }
       
      +@Since("1.3.0")
       object LinearRegressionModel extends Loader[LinearRegressionModel] {
       
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): LinearRegressionModel = {
           val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
           // Hard-code class name string in case it changes in the future
      @@ -79,6 +85,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {
        * its corresponding right hand side label y.
        * See also the documentation for the precise formulation.
        */
      +@Since("0.8.0")
       class LinearRegressionWithSGD private[mllib] (
           private var stepSize: Double,
           private var numIterations: Int,
      @@ -87,6 +94,7 @@ class LinearRegressionWithSGD private[mllib] (
       
         private val gradient = new LeastSquaresGradient()
         private val updater = new SimpleUpdater()
      +  @Since("0.8.0")
         override val optimizer = new GradientDescent(gradient, updater)
           .setStepSize(stepSize)
           .setNumIterations(numIterations)
      @@ -96,6 +104,7 @@ class LinearRegressionWithSGD private[mllib] (
          * Construct a LinearRegression object with default parameters: {stepSize: 1.0,
          * numIterations: 100, miniBatchFraction: 1.0}.
          */
      +  @Since("0.8.0")
         def this() = this(1.0, 100, 1.0)
       
         override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
      @@ -105,7 +114,9 @@ class LinearRegressionWithSGD private[mllib] (
       
       /**
        * Top-level methods for calling LinearRegression.
      + *
        */
      +@Since("0.8.0")
       object LinearRegressionWithSGD {
       
         /**
      @@ -121,7 +132,9 @@ object LinearRegressionWithSGD {
          * @param miniBatchFraction Fraction of data to be used per iteration.
          * @param initialWeights Initial set of weights to be used. Array should be equal in size to
          *        the number of features in the data.
      +   *
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -142,7 +155,9 @@ object LinearRegressionWithSGD {
          * @param numIterations Number of iterations of gradient descent to run.
          * @param stepSize Step size to be used for each iteration of gradient descent.
          * @param miniBatchFraction Fraction of data to be used per iteration.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -161,7 +176,9 @@ object LinearRegressionWithSGD {
          * @param stepSize Step size to be used for each iteration of Gradient Descent.
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a LinearRegressionModel which has the weights and offset from training.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -178,7 +195,9 @@ object LinearRegressionWithSGD {
          *              matrix A as well as the corresponding right hand side label y
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a LinearRegressionModel which has the weights and offset from training.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int): LinearRegressionModel = {
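
In the same spirit, a minimal sketch of the `LinearRegressionWithSGD.train` overload that also sets the step size (again assuming an existing `sc` and made-up data; the remaining parameters keep their documented defaults):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

val training = sc.parallelize(Seq(
  LabeledPoint(3.0, Vectors.dense(1.0, 2.0)),
  LabeledPoint(5.0, Vectors.dense(3.0, 4.0))))

// numIterations = 200, stepSize = 0.01; miniBatchFraction stays at its default.
val model = LinearRegressionWithSGD.train(training, 200, 0.01)

// The fitted coefficients are exposed on the returned model.
println(s"weights=${model.weights} intercept=${model.intercept}")
```
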
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
      index 214ac4d0ed7d..0e72d6591ce8 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
      @@ -19,11 +19,12 @@ package org.apache.spark.mllib.regression
       
       import org.json4s.{DefaultFormats, JValue}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.rdd.RDD
       
      +@Since("0.8.0")
       @Experimental
       trait RegressionModel extends Serializable {
         /**
      @@ -31,7 +32,9 @@ trait RegressionModel extends Serializable {
          *
          * @param testData RDD representing data points to be predicted
          * @return RDD[Double] where each entry contains the corresponding prediction
      +   *
          */
      +  @Since("1.0.0")
         def predict(testData: RDD[Vector]): RDD[Double]
       
         /**
      @@ -39,14 +42,18 @@ trait RegressionModel extends Serializable {
          *
          * @param testData array representing a single data point
          * @return Double prediction from the trained model
      +   *
          */
      +  @Since("1.0.0")
         def predict(testData: Vector): Double
       
         /**
          * Predict values for examples stored in a JavaRDD.
          * @param testData JavaRDD representing data points to be predicted
          * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
      +   *
          */
      +  @Since("1.0.0")
         def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
           predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
       }
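
The trait above is the common prediction surface for the SGD models in this patch. A small sketch of how the two Scala `predict` overloads are typically used (the model argument can be any trained subclass):

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.RegressionModel
import org.apache.spark.rdd.RDD

// Bulk prediction: one Double per feature vector.
def scoreAll(model: RegressionModel, features: RDD[Vector]): RDD[Double] =
  model.predict(features)

// Single-point prediction via the Vector overload.
def scoreOne(model: RegressionModel, point: Vector): Double =
  model.predict(point)
```
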
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
      index 7d28ffad45c9..0a44ff559d55 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
      @@ -18,6 +18,7 @@
       package org.apache.spark.mllib.regression
       
       import org.apache.spark.SparkContext
      +import org.apache.spark.annotation.Since
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.optimization._
       import org.apache.spark.mllib.pmml.PMMLExportable
      @@ -31,10 +32,12 @@ import org.apache.spark.rdd.RDD
        *
        * @param weights Weights computed for every feature.
        * @param intercept Intercept computed for this model.
      + *
        */
      -class RidgeRegressionModel (
      -    override val weights: Vector,
      -    override val intercept: Double)
      +@Since("0.8.0")
      +class RidgeRegressionModel @Since("1.1.0") (
      +    @Since("1.0.0") override val weights: Vector,
      +    @Since("0.8.0") override val intercept: Double)
         extends GeneralizedLinearModel(weights, intercept)
         with RegressionModel with Serializable with Saveable with PMMLExportable {
       
      @@ -45,6 +48,7 @@ class RidgeRegressionModel (
           weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
         }
       
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
         }
      @@ -52,8 +56,10 @@ class RidgeRegressionModel (
         override protected def formatVersion: String = "1.0"
       }
       
      +@Since("1.3.0")
       object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
       
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): RidgeRegressionModel = {
           val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
           // Hard-code class name string in case it changes in the future
      @@ -79,6 +85,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
        * its corresponding right hand side label y.
        * See also the documentation for the precise formulation.
        */
      +@Since("0.8.0")
       class RidgeRegressionWithSGD private (
           private var stepSize: Double,
           private var numIterations: Int,
      @@ -88,7 +95,7 @@ class RidgeRegressionWithSGD private (
       
         private val gradient = new LeastSquaresGradient()
         private val updater = new SquaredL2Updater()
      -
      +  @Since("0.8.0")
         override val optimizer = new GradientDescent(gradient, updater)
           .setStepSize(stepSize)
           .setNumIterations(numIterations)
      @@ -99,6 +106,7 @@ class RidgeRegressionWithSGD private (
          * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100,
          * regParam: 0.01, miniBatchFraction: 1.0}.
          */
      +  @Since("0.8.0")
         def this() = this(1.0, 100, 0.01, 1.0)
       
         override protected def createModel(weights: Vector, intercept: Double) = {
      @@ -108,7 +116,9 @@ class RidgeRegressionWithSGD private (
       
       /**
        * Top-level methods for calling RidgeRegression.
      + *
        */
      +@Since("0.8.0")
       object RidgeRegressionWithSGD {
       
         /**
      @@ -124,7 +134,9 @@ object RidgeRegressionWithSGD {
          * @param miniBatchFraction Fraction of data to be used per iteration.
          * @param initialWeights Initial set of weights to be used. Array should be equal in size to
          *        the number of features in the data.
      +   *
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -146,7 +158,9 @@ object RidgeRegressionWithSGD {
          * @param stepSize Step size to be used for each iteration of gradient descent.
          * @param regParam Regularization parameter.
          * @param miniBatchFraction Fraction of data to be used per iteration.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -166,7 +180,9 @@ object RidgeRegressionWithSGD {
          * @param regParam Regularization parameter.
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a RidgeRegressionModel which has the weights and offset from training.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int,
      @@ -183,7 +199,9 @@ object RidgeRegressionWithSGD {
          * @param input RDD of (label, array of features) pairs.
          * @param numIterations Number of iterations of gradient descent to run.
          * @return a RidgeRegressionModel which has the weights and offset from training.
      +   *
          */
      +  @Since("0.8.0")
         def train(
             input: RDD[LabeledPoint],
             numIterations: Int): RidgeRegressionModel = {
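
A corresponding sketch for ridge regression, this time using the overload that also takes the regularization parameter (values mirror the documented defaults; `sc` and the data are assumptions for illustration):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, RidgeRegressionWithSGD}

val training = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(2.0, Vectors.dense(0.0, 1.0))))

// numIterations = 100, stepSize = 1.0, regParam = 0.01 (the documented defaults).
val model = RidgeRegressionWithSGD.train(training, 100, 1.0, 0.01)
```
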
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
      index 141052ba813e..73948b2d9851 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
      @@ -20,9 +20,9 @@ package org.apache.spark.mllib.regression
       import scala.reflect.ClassTag
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
      -import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream}
       import org.apache.spark.streaming.dstream.DStream
       
      @@ -53,7 +53,9 @@ import org.apache.spark.streaming.dstream.DStream
        * It is also ok to call trainOn on different streams; this will update
        * the model using each of the different sources, in sequence.
        *
      + *
        */
      +@Since("1.1.0")
       @DeveloperApi
       abstract class StreamingLinearAlgorithm[
           M <: GeneralizedLinearModel,
      @@ -65,7 +67,11 @@ abstract class StreamingLinearAlgorithm[
         /** The algorithm to use for updating. */
         protected val algorithm: A
       
      -  /** Return the latest model. */
      +  /**
      +   * Return the latest model.
      +   *
      +   */
      +  @Since("1.1.0")
         def latestModel(): M = {
           model.get
         }
      @@ -78,6 +84,7 @@ abstract class StreamingLinearAlgorithm[
          *
          * @param data DStream containing labeled data
          */
      +  @Since("1.1.0")
         def trainOn(data: DStream[LabeledPoint]): Unit = {
           if (model.isEmpty) {
             throw new IllegalArgumentException("Model must be initialized before starting training.")
      @@ -95,7 +102,10 @@ abstract class StreamingLinearAlgorithm[
           }
         }
       
      -  /** Java-friendly version of `trainOn`. */
      +  /**
      +   * Java-friendly version of `trainOn`.
      +   */
      +  @Since("1.3.0")
         def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream)
       
         /**
      @@ -103,7 +113,9 @@ abstract class StreamingLinearAlgorithm[
          *
          * @param data DStream containing feature vectors
          * @return DStream containing predictions
      +   *
          */
      +  @Since("1.1.0")
         def predictOn(data: DStream[Vector]): DStream[Double] = {
           if (model.isEmpty) {
             throw new IllegalArgumentException("Model must be initialized before starting prediction.")
      @@ -111,7 +123,11 @@ abstract class StreamingLinearAlgorithm[
           data.map{x => model.get.predict(x)}
         }
       
      -  /** Java-friendly version of `predictOn`. */
      +  /**
      +   * Java-friendly version of `predictOn`.
      +   *
      +   */
      +  @Since("1.3.0")
         def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = {
           JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]])
         }
      @@ -121,7 +137,9 @@ abstract class StreamingLinearAlgorithm[
          * @param data DStream containing feature vectors
          * @tparam K key type
          * @return DStream containing the input keys and the predictions as values
      +   *
          */
      +  @Since("1.1.0")
         def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
           if (model.isEmpty) {
             throw new IllegalArgumentException("Model must be initialized before starting prediction")
      @@ -130,7 +148,11 @@ abstract class StreamingLinearAlgorithm[
         }
       
       
      -  /** Java-friendly version of `predictOnValues`. */
      +  /**
      +   * Java-friendly version of `predictOnValues`.
      +   *
      +   */
      +  @Since("1.3.0")
         def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = {
           implicit val tag = fakeClassTag[K]
           JavaPairDStream.fromPairDStream(
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
      index 235e043c7754..fe1d487cdd07 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.regression
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.linalg.Vector
       
       /**
      @@ -39,9 +39,9 @@ import org.apache.spark.mllib.linalg.Vector
        *    .setNumIterations(10)
        *    .setInitialWeights(Vectors.dense(...))
        *    .trainOn(DStream)
      - *
        */
       @Experimental
      +@Since("1.1.0")
       class StreamingLinearRegressionWithSGD private[mllib] (
           private var stepSize: Double,
           private var numIterations: Int,
      @@ -55,34 +55,56 @@ class StreamingLinearRegressionWithSGD private[mllib] (
          * Initial weights must be set before using trainOn or predictOn
          * (see `StreamingLinearAlgorithm`)
          */
      +  @Since("1.1.0")
         def this() = this(0.1, 50, 1.0)
       
      +  @Since("1.1.0")
         val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
       
         protected var model: Option[LinearRegressionModel] = None
       
      -  /** Set the step size for gradient descent. Default: 0.1. */
      +  /**
      +   * Set the step size for gradient descent. Default: 0.1.
      +   */
      +  @Since("1.1.0")
         def setStepSize(stepSize: Double): this.type = {
           this.algorithm.optimizer.setStepSize(stepSize)
           this
         }
       
      -  /** Set the number of iterations of gradient descent to run per update. Default: 50. */
      +  /**
      +   * Set the number of iterations of gradient descent to run per update. Default: 50.
      +   */
      +  @Since("1.1.0")
         def setNumIterations(numIterations: Int): this.type = {
           this.algorithm.optimizer.setNumIterations(numIterations)
           this
         }
       
      -  /** Set the fraction of each batch to use for updates. Default: 1.0. */
      +  /**
      +   * Set the fraction of each batch to use for updates. Default: 1.0.
      +   */
      +  @Since("1.1.0")
         def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
           this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction)
           this
         }
       
      -  /** Set the initial weights. */
      +  /**
      +   * Set the initial weights.
      +   */
      +  @Since("1.1.0")
         def setInitialWeights(initialWeights: Vector): this.type = {
           this.model = Some(algorithm.createModel(initialWeights, 0.0))
           this
         }
       
      +  /**
      +   * Set the convergence tolerance. Default: 0.001.
      +   */
      +  @Since("1.5.0")
      +  def setConvergenceTol(tolerance: Double): this.type = {
      +    this.algorithm.optimizer.setConvergenceTol(tolerance)
      +    this
      +  }
       }
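
A hedged sketch of how the builder-style setters above, including the newly added `setConvergenceTol`, wire into a streaming job; the two DStreams and the surrounding `StreamingContext` are assumed to exist elsewhere:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
import org.apache.spark.streaming.dstream.DStream

def wire(trainingStream: DStream[LabeledPoint],
         testStream: DStream[(Long, Vector)]): Unit = {
  val model = new StreamingLinearRegressionWithSGD()
    .setStepSize(0.1)
    .setNumIterations(50)
    .setConvergenceTol(0.001)               // the setter added in this patch
    .setInitialWeights(Vectors.zeros(3))    // dimension must match the feature vectors
  model.trainOn(trainingStream)             // update the model on every training batch
  model.predictOnValues(testStream).print() // emit (key, prediction) pairs per batch
}
```
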
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
      index 58a50f9c19f1..4a856f7f3434 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat
       
       import com.github.fommil.netlib.BLAS.{getInstance => blas}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.rdd.RDD
       
      @@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD
        * val densities = kd.estimate(Array(-1.0, 2.0, 5.0))
        * }}}
        */
      +@Since("1.4.0")
       @Experimental
       class KernelDensity extends Serializable {
       
      @@ -52,6 +53,7 @@ class KernelDensity extends Serializable {
         /**
          * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`).
          */
      +  @Since("1.4.0")
         def setBandwidth(bandwidth: Double): this.type = {
           require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.")
           this.bandwidth = bandwidth
      @@ -61,6 +63,7 @@ class KernelDensity extends Serializable {
         /**
          * Sets the sample to use for density estimation.
          */
      +  @Since("1.4.0")
         def setSample(sample: RDD[Double]): this.type = {
           this.sample = sample
           this
      @@ -69,6 +72,7 @@ class KernelDensity extends Serializable {
         /**
          * Sets the sample to use for density estimation (for Java users).
          */
      +  @Since("1.4.0")
         def setSample(sample: JavaRDD[java.lang.Double]): this.type = {
           this.sample = sample.rdd.asInstanceOf[RDD[Double]]
           this
      @@ -77,6 +81,7 @@ class KernelDensity extends Serializable {
         /**
          * Estimates probability density function at the given array of points.
          */
      +  @Since("1.4.0")
         def estimate(points: Array[Double]): Array[Double] = {
           val sample = this.sample
           val bandwidth = this.bandwidth
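
For reference, the setters annotated above compose as in the class-level example; a self-contained sketch with made-up sample values (and an existing `sc`):

```scala
import org.apache.spark.mllib.stat.KernelDensity

val sample = sc.parallelize(Seq(1.0, 1.5, 2.0, 2.2, 5.0))

// Estimate the density at three evaluation points with a Gaussian kernel.
val densities: Array[Double] = new KernelDensity()
  .setSample(sample)
  .setBandwidth(0.5)
  .estimate(Array(-1.0, 2.0, 5.0))
```
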
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
      index d321cc554c1c..201333c3690d 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
      @@ -17,23 +17,27 @@
       
       package org.apache.spark.mllib.stat
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.linalg.{Vectors, Vector}
       
       /**
        * :: DeveloperApi ::
        * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
      - * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
      + * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector
        * format in a online fashion.
        *
        * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of
        * the corresponding joint dataset.
        *
      - * A numerically stable algorithm is implemented to compute sample mean and variance:
      + * A numerically stable algorithm is implemented to compute the mean and variance of instances:
        * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
        * Zero elements (including explicit zero values) are skipped when calling add(),
        * to have time complexity O(nnz) instead of O(n) for each column.
      + *
      + * For weighted instances, the unbiased estimation of variance is defined by the reliability
      + * weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]].
        */
      +@Since("1.1.0")
       @DeveloperApi
       class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
       
      @@ -43,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         private var currM2: Array[Double] = _
         private var currL1: Array[Double] = _
         private var totalCnt: Long = 0
      +  private var weightSum: Double = 0.0
      +  private var weightSquareSum: Double = 0.0
         private var nnz: Array[Double] = _
         private var currMax: Array[Double] = _
         private var currMin: Array[Double] = _
      @@ -53,10 +59,16 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
          * @param sample The sample in dense/sparse vector format to be added into this summarizer.
          * @return This MultivariateOnlineSummarizer object.
          */
      -  def add(sample: Vector): this.type = {
      +  @Since("1.1.0")
      +  def add(sample: Vector): this.type = add(sample, 1.0)
      +
      +  private[spark] def add(instance: Vector, weight: Double): this.type = {
+    require(weight >= 0.0, s"sample weight $weight has to be >= 0.0")
      +    if (weight == 0.0) return this
      +
           if (n == 0) {
      -      require(sample.size > 0, s"Vector should have dimension larger than zero.")
      -      n = sample.size
      +      require(instance.size > 0, s"Vector should have dimension larger than zero.")
      +      n = instance.size
       
             currMean = Array.ofDim[Double](n)
             currM2n = Array.ofDim[Double](n)
      @@ -67,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
             currMin = Array.fill[Double](n)(Double.MaxValue)
           }
       
      -    require(n == sample.size, s"Dimensions mismatch when adding new sample." +
      -      s" Expecting $n but got ${sample.size}.")
      +    require(n == instance.size, s"Dimensions mismatch when adding new sample." +
      +      s" Expecting $n but got ${instance.size}.")
       
           val localCurrMean = currMean
           val localCurrM2n = currM2n
      @@ -77,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
           val localNnz = nnz
           val localCurrMax = currMax
           val localCurrMin = currMin
      -    sample.foreachActive { (index, value) =>
      +    instance.foreachActive { (index, value) =>
             if (value != 0.0) {
               if (localCurrMax(index) < value) {
                 localCurrMax(index) = value
      @@ -88,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       
               val prevMean = localCurrMean(index)
               val diff = value - prevMean
      -        localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0)
      -        localCurrM2n(index) += (value - localCurrMean(index)) * diff
      -        localCurrM2(index) += value * value
      -        localCurrL1(index) += math.abs(value)
      +        localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight)
      +        localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
      +        localCurrM2(index) += weight * value * value
      +        localCurrL1(index) += weight * math.abs(value)
       
      -        localNnz(index) += 1.0
      +        localNnz(index) += weight
             }
           }
       
      +    weightSum += weight
      +    weightSquareSum += weight * weight
           totalCnt += 1
           this
         }
      @@ -108,11 +122,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
          * @param other The other MultivariateOnlineSummarizer to be merged.
          * @return This MultivariateOnlineSummarizer object.
          */
      +  @Since("1.1.0")
         def merge(other: MultivariateOnlineSummarizer): this.type = {
      -    if (this.totalCnt != 0 && other.totalCnt != 0) {
      +    if (this.weightSum != 0.0 && other.weightSum != 0.0) {
             require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
               s"Expecting $n but got ${other.n}.")
             totalCnt += other.totalCnt
      +      weightSum += other.weightSum
      +      weightSquareSum += other.weightSquareSum
             var i = 0
             while (i < n) {
               val thisNnz = nnz(i)
      @@ -135,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
               nnz(i) = totalNnz
               i += 1
             }
      -    } else if (totalCnt == 0 && other.totalCnt != 0) {
      +    } else if (weightSum == 0.0 && other.weightSum != 0.0) {
             this.n = other.n
             this.currMean = other.currMean.clone()
             this.currM2n = other.currM2n.clone()
             this.currM2 = other.currM2.clone()
             this.currL1 = other.currL1.clone()
             this.totalCnt = other.totalCnt
      +      this.weightSum = other.weightSum
      +      this.weightSquareSum = other.weightSquareSum
             this.nnz = other.nnz.clone()
             this.currMax = other.currMax.clone()
             this.currMin = other.currMin.clone()
      @@ -149,24 +168,34 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
           this
         }
       
      +  /**
      +   * Sample mean of each dimension.
      +   *
      +   */
      +  @Since("1.1.0")
         override def mean: Vector = {
      -    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
      +    require(weightSum > 0, s"Nothing has been added to this summarizer.")
       
           val realMean = Array.ofDim[Double](n)
           var i = 0
           while (i < n) {
      -      realMean(i) = currMean(i) * (nnz(i) / totalCnt)
      +      realMean(i) = currMean(i) * (nnz(i) / weightSum)
             i += 1
           }
           Vectors.dense(realMean)
         }
       
      +  /**
      +   * Unbiased estimate of sample variance of each dimension.
      +   *
      +   */
      +  @Since("1.1.0")
         override def variance: Vector = {
      -    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
      +    require(weightSum > 0, s"Nothing has been added to this summarizer.")
       
           val realVariance = Array.ofDim[Double](n)
       
      -    val denominator = totalCnt - 1.0
      +    val denominator = weightSum - (weightSquareSum / weightSum)
       
           // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
           if (denominator > 0.0) {
      @@ -174,47 +203,71 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
             var i = 0
             val len = currM2n.length
             while (i < len) {
      -        realVariance(i) =
      -          currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
      -        realVariance(i) /= denominator
      +        realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) *
      +          (weightSum - nnz(i)) / weightSum) / denominator
               i += 1
             }
           }
           Vectors.dense(realVariance)
         }
       
      +  /**
      +   * Sample size.
      +   *
      +   */
      +  @Since("1.1.0")
         override def count: Long = totalCnt
       
      +  /**
      +   * Number of nonzero elements in each dimension.
      +   *
      +   */
      +  @Since("1.1.0")
         override def numNonzeros: Vector = {
      -    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
      +    require(weightSum > 0, s"Nothing has been added to this summarizer.")
       
           Vectors.dense(nnz)
         }
       
      +  /**
      +   * Maximum value of each dimension.
      +   *
      +   */
      +  @Since("1.1.0")
         override def max: Vector = {
      -    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
      +    require(weightSum > 0, s"Nothing has been added to this summarizer.")
       
           var i = 0
           while (i < n) {
      -      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
      +      if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0
             i += 1
           }
           Vectors.dense(currMax)
         }
       
      +  /**
      +   * Minimum value of each dimension.
      +   *
      +   */
      +  @Since("1.1.0")
         override def min: Vector = {
      -    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
      +    require(weightSum > 0, s"Nothing has been added to this summarizer.")
       
           var i = 0
           while (i < n) {
      -      if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
      +      if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0
             i += 1
           }
           Vectors.dense(currMin)
         }
       
      +  /**
+   * L2 (Euclidean) norm of each dimension.
      +   *
      +   */
      +  @Since("1.2.0")
         override def normL2: Vector = {
      -    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
      +    require(weightSum > 0, s"Nothing has been added to this summarizer.")
       
           val realMagnitude = Array.ofDim[Double](n)
       
      @@ -227,8 +280,13 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
           Vectors.dense(realMagnitude)
         }
       
      +  /**
      +   * L1 norm of each dimension.
      +   *
      +   */
      +  @Since("1.2.0")
         override def normL1: Vector = {
      -    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
      +    require(weightSum > 0, s"Nothing has been added to this summarizer.")
       
           Vectors.dense(currL1)
         }
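
The weighted `add(instance, weight)` overload introduced here is `private[spark]`, so user code still goes through the unweighted `add`; a short sketch of the accumulate-then-merge pattern with made-up vectors:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

// Accumulate statistics online; add() returns the summarizer, so calls chain.
val left = new MultivariateOnlineSummarizer()
  .add(Vectors.dense(1.0, 0.0, 3.0))
  .add(Vectors.sparse(3, Seq((1, 2.0))))

// Summarize another chunk independently, then merge the two summaries.
val right = new MultivariateOnlineSummarizer().add(Vectors.dense(0.0, 4.0, 1.0))
val merged = left.merge(right)

println(merged.mean)      // per-column mean
println(merged.variance)  // unbiased per-column variance
```
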
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
      index 6a364c93284a..39a16fb743d6 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
      @@ -17,50 +17,60 @@
       
       package org.apache.spark.mllib.stat
       
      +import org.apache.spark.annotation.Since
       import org.apache.spark.mllib.linalg.Vector
       
       /**
        * Trait for multivariate statistical summary of a data matrix.
        */
      +@Since("1.0.0")
       trait MultivariateStatisticalSummary {
       
         /**
          * Sample mean vector.
          */
      +  @Since("1.0.0")
         def mean: Vector
       
         /**
          * Sample variance vector. Should return a zero vector if the sample size is 1.
          */
      +  @Since("1.0.0")
         def variance: Vector
       
         /**
          * Sample size.
          */
      +  @Since("1.0.0")
         def count: Long
       
         /**
          * Number of nonzero elements (including explicitly presented zero values) in each column.
          */
      +  @Since("1.0.0")
         def numNonzeros: Vector
       
         /**
          * Maximum value of each column.
          */
      +  @Since("1.0.0")
         def max: Vector
       
         /**
          * Minimum value of each column.
          */
      +  @Since("1.0.0")
         def min: Vector
       
         /**
          * Euclidean magnitude of each column
          */
      +  @Since("1.2.0")
         def normL2: Vector
       
         /**
          * L1 norm of each column
          */
      +  @Since("1.2.0")
         def normL1: Vector
       }
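
In practice this trait is usually obtained through `Statistics.colStats`; a small sketch (existing `sc`, made-up rows):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}

val observations = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0, 100.0),
  Vectors.dense(2.0, 20.0, 200.0),
  Vectors.dense(3.0, 30.0, 300.0)))

val summary: MultivariateStatisticalSummary = Statistics.colStats(observations)
println(summary.mean)         // column-wise means
println(summary.numNonzeros)  // column-wise nonzero counts
println(summary.normL1)       // column-wise L1 norms
```
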
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
      index 900007ec6bc7..84d64a5bfb38 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
      @@ -17,19 +17,23 @@
       
       package org.apache.spark.mllib.stat
       
      -import org.apache.spark.annotation.Experimental
      -import org.apache.spark.api.java.JavaRDD
      +import scala.annotation.varargs
      +
      +import org.apache.spark.annotation.{Experimental, Since}
      +import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD}
       import org.apache.spark.mllib.linalg.distributed.RowMatrix
       import org.apache.spark.mllib.linalg.{Matrix, Vector}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.stat.correlation.Correlations
      -import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult}
      +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest,
      +  KolmogorovSmirnovTestResult}
       import org.apache.spark.rdd.RDD
       
       /**
        * :: Experimental ::
        * API for statistical functions in MLlib.
        */
      +@Since("1.1.0")
       @Experimental
       object Statistics {
       
      @@ -39,6 +43,7 @@ object Statistics {
          * @param X an RDD[Vector] for which column-wise summary statistics are to be computed.
          * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics.
          */
      +  @Since("1.1.0")
         def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = {
           new RowMatrix(X).computeColumnSummaryStatistics()
         }
      @@ -50,6 +55,7 @@ object Statistics {
          * @param X an RDD[Vector] for which the correlation matrix is to be computed.
          * @return Pearson correlation matrix comparing columns in X.
          */
      +  @Since("1.1.0")
         def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X)
       
         /**
      @@ -66,6 +72,7 @@ object Statistics {
          *               Supported: `pearson` (default), `spearman`
          * @return Correlation matrix comparing columns in X.
          */
      +  @Since("1.1.0")
         def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method)
       
         /**
      @@ -79,9 +86,13 @@ object Statistics {
          * @param y RDD[Double] of the same cardinality as x.
          * @return A Double containing the Pearson correlation between the two input RDD[Double]s
          */
      +  @Since("1.1.0")
         def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
       
      -  /** Java-friendly version of [[corr()]] */
      +  /**
      +   * Java-friendly version of [[corr()]]
      +   */
      +  @Since("1.4.1")
         def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
           corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
       
      @@ -99,9 +110,13 @@ object Statistics {
          * @return A Double containing the correlation between the two input RDD[Double]s using the
          *         specified method.
          */
      +  @Since("1.1.0")
         def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
       
      -  /** Java-friendly version of [[corr()]] */
      +  /**
      +   * Java-friendly version of [[corr()]]
      +   */
      +  @Since("1.4.1")
         def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
           corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
       
      @@ -119,6 +134,7 @@ object Statistics {
          * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
          *         the method used, and the null hypothesis.
          */
      +  @Since("1.1.0")
         def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = {
           ChiSqTest.chiSquared(observed, expected)
         }
      @@ -133,6 +149,7 @@ object Statistics {
          * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
          *         the method used, and the null hypothesis.
          */
      +  @Since("1.1.0")
         def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed)
       
         /**
      @@ -143,6 +160,7 @@ object Statistics {
          * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value,
          *         the method used, and the null hypothesis.
          */
      +  @Since("1.1.0")
         def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed)
       
         /**
      @@ -155,7 +173,59 @@ object Statistics {
          * @return an array containing the ChiSquaredTestResult for every feature against the label.
          *         The order of the elements in the returned array reflects the order of input features.
          */
      +  @Since("1.1.0")
         def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
           ChiSqTest.chiSquaredFeatures(data)
         }
      +
      +  /** Java-friendly version of [[chiSqTest()]] */
      +  @Since("1.5.0")
      +  def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd)
      +
      +  /**
      +   * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
      +   * continuous distribution. By comparing the largest difference between the empirical cumulative
      +   * distribution of the sample data and the theoretical distribution we can provide a test for the
+   * null hypothesis that the sample data comes from that theoretical distribution.
      +   * For more information on KS Test:
      +   * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
      +   *
      +   * @param data an `RDD[Double]` containing the sample of data to test
      +   * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value
      +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
      +   *        statistic, p-value, and null hypothesis.
      +   */
      +  @Since("1.5.0")
      +  def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double)
      +    : KolmogorovSmirnovTestResult = {
      +    KolmogorovSmirnovTest.testOneSample(data, cdf)
      +  }
      +
      +  /**
      +   * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability
      +   * distribution equality. Currently supports the normal distribution, taking as parameters
      +   * the mean and standard deviation.
      +   * (distName = "norm")
      +   * @param data an `RDD[Double]` containing the sample of data to test
      +   * @param distName a `String` name for a theoretical distribution
      +   * @param params `Double*` specifying the parameters to be used for the theoretical distribution
      +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
      +   *        statistic, p-value, and null hypothesis.
      +   */
      +  @Since("1.5.0")
      +  @varargs
      +  def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*)
      +    : KolmogorovSmirnovTestResult = {
      +    KolmogorovSmirnovTest.testOneSample(data, distName, params: _*)
      +  }
      +
      +  /** Java-friendly version of [[kolmogorovSmirnovTest()]] */
      +  @Since("1.5.0")
      +  @varargs
      +  def kolmogorovSmirnovTest(
      +      data: JavaDoubleRDD,
      +      distName: String,
      +      params: Double*): KolmogorovSmirnovTestResult = {
      +    kolmogorovSmirnovTest(data.rdd.asInstanceOf[RDD[Double]], distName, params: _*)
      +  }
       }
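
A tentative sketch of the two new `kolmogorovSmirnovTest` entry points (existing `sc`, made-up sample; the inline CDF is a stand-in that is only meaningful on [0, 1]):

```scala
import org.apache.spark.mllib.stat.Statistics

val data = sc.parallelize(Seq(0.1, 0.15, 0.2, 0.25, 0.3))

// Against a user-supplied CDF (here a toy Uniform(0, 1) CDF, for illustration only).
val cdf = (x: Double) => x
val resultFromCdf = Statistics.kolmogorovSmirnovTest(data, cdf)

// Against the named "norm" distribution with mean 0.0 and standard deviation 1.0.
val resultFromName = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0)
println(resultFromName.pValue)
```
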
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
      index cf51b24ff777..92a5af708d04 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat.distribution
       
       import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV}
       
      -import org.apache.spark.annotation.DeveloperApi;
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
       import org.apache.spark.mllib.util.MLUtils
       
      @@ -33,10 +33,11 @@ import org.apache.spark.mllib.util.MLUtils
        * @param mu The mean vector of the distribution
        * @param sigma The covariance matrix of the distribution
        */
      +@Since("1.3.0")
       @DeveloperApi
      -class MultivariateGaussian (
      -    val mu: Vector,
      -    val sigma: Matrix) extends Serializable {
      +class MultivariateGaussian @Since("1.3.0") (
      +    @Since("1.3.0") val mu: Vector,
      +    @Since("1.3.0") val sigma: Matrix) extends Serializable {
       
         require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
         require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
      @@ -60,12 +61,16 @@ class MultivariateGaussian (
          */
         private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
       
      -  /** Returns density of this multivariate Gaussian at given point, x */
+  /** Returns density of this multivariate Gaussian at given point, x
+   */
+  @Since("1.3.0")
         def pdf(x: Vector): Double = {
           pdf(x.toBreeze)
         }
       
      -  /** Returns the log-density of this multivariate Gaussian at given point, x */
+  /** Returns the log-density of this multivariate Gaussian at given point, x
+   */
+  @Since("1.3.0")
         def logpdf(x: Vector): Double = {
           logpdf(x.toBreeze)
         }
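
A quick sketch of constructing the distribution and evaluating the two densities annotated above (zero mean, identity covariance, made-up evaluation point):

```scala
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

// A 2-dimensional Gaussian with zero mean and identity covariance.
val mu = Vectors.dense(0.0, 0.0)
val sigma = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
val gaussian = new MultivariateGaussian(mu, sigma)

println(gaussian.pdf(Vectors.dense(0.5, -0.5)))     // density at a point
println(gaussian.logpdf(Vectors.dense(0.5, -0.5)))  // log-density at the same point
```
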
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
      new file mode 100644
      index 000000000000..2b3ed6df486c
      --- /dev/null
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
      @@ -0,0 +1,194 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.stat.test
      +
      +import scala.annotation.varargs
      +
      +import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution}
      +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => CommonMathKolmogorovSmirnovTest}
      +
      +import org.apache.spark.Logging
      +import org.apache.spark.rdd.RDD
      +
      +/**
+ * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
      + * continuous distribution. By comparing the largest difference between the empirical cumulative
      + * distribution of the sample data and the theoretical distribution we can provide a test for the
+ * null hypothesis that the sample data comes from that theoretical distribution.
      + * For more information on KS Test:
      + * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
      + *
      + * Implementation note: We seek to implement the KS test with a minimal number of distributed
      + * passes. We sort the RDD, and then perform the following operations on a per-partition basis:
      + * calculate an empirical cumulative distribution value for each observation, and a theoretical
      + * cumulative distribution value. We know the latter to be correct, while the former will be off by
      + * a constant (how large the constant is depends on how many values precede it in other partitions).
      + * However, given that this constant simply shifts the empirical CDF upwards, but doesn't
      + * change its shape, and furthermore, that constant is the same within a given partition, we can
      + * pick 2 values in each partition that can potentially resolve to the largest global distance.
      + * Namely, we pick the minimum distance and the maximum distance. Additionally, we keep track of how
      + * many elements are in each partition. Once these three values have been returned for every
      + * partition, we can collect and operate locally. Locally, we can now adjust each distance by the
      + * appropriate constant (the cumulative sum of number of elements in the prior partitions divided by
+ * the data set size). Finally, we take the maximum absolute value, and this is the statistic.
      + */
      +private[stat] object KolmogorovSmirnovTest extends Logging {
      +
      +  // Null hypothesis for the type of KS test to be included in the result.
      +  object NullHypothesis extends Enumeration {
      +    type NullHypothesis = Value
      +    val OneSampleTwoSided = Value("Sample follows theoretical distribution")
      +  }
      +
      +  /**
      +   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
      +   * @param data `RDD[Double]` data on which to run test
      +   * @param cdf `Double => Double` function to calculate the theoretical CDF
      +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
      +   *        results (p-value, statistic, and null hypothesis)
      +   */
      +  def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = {
      +    val n = data.count().toDouble
      +    val localData = data.sortBy(x => x).mapPartitions { part =>
      +      val partDiffs = oneSampleDifferences(part, n, cdf) // local distances
      +      searchOneSampleCandidates(partDiffs) // candidates: local extrema
      +    }.collect()
      +    val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme
      +    evalOneSampleP(ksStat, n.toLong)
      +  }
      +
      +  /**
      +   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
      +   * @param data `RDD[Double]` data on which to run test
      +   * @param distObj `RealDistribution` a theoretical distribution
      +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
      +   *        results (p-value, statistic, and null hypothesis)
      +   */
      +  def testOneSample(data: RDD[Double], distObj: RealDistribution): KolmogorovSmirnovTestResult = {
      +    val cdf = (x: Double) => distObj.cumulativeProbability(x)
      +    testOneSample(data, cdf)
      +  }
      +
      +  /**
      +   * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a
      +   * partition
      +   * @param partData `Iterator[Double]` 1 partition of a sorted RDD
      +   * @param n `Double` the total size of the RDD
+   * @param cdf `Double => Double` a function that calculates the theoretical CDF of a value
+   * @return `Iterator[(Double, Double)]` Unadjusted (i.e. off by a constant) potential extrema
      +   *        in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF,
      +   *        the second element corresponds to empirical CDF - CDF.  We can then search the resulting
      +   *        iterator for the minimum of the first and the maximum of the second element, and provide
      +   *        this as a partition's candidate extrema
      +   */
      +  private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double)
      +    : Iterator[(Double, Double)] = {
      +    // zip data with index (within that partition)
      +    // calculate local (unadjusted) empirical CDF and subtract CDF
      +    partData.zipWithIndex.map { case (v, ix) =>
      +      // dp and dl are later adjusted by constant, when global info is available
      +      val dp = (ix + 1) / n
      +      val dl = ix / n
      +      val cdfVal = cdf(v)
      +      (dl - cdfVal, dp - cdfVal)
      +    }
      +  }
      +
      +  /**
      +   * Search the unadjusted differences in a partition and return the
      +   * two extrema (furthest below and furthest above CDF), along with a count of elements in that
      +   * partition
      +   * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF
+   *                 and CDF in a partition, which come as a tuple of
      +   *                 (empirical CDF - 1/N - CDF, empirical CDF - CDF)
      +   * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements
      +   */
      +  private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)])
      +    : Iterator[(Double, Double, Double)] = {
      +    val initAcc = (Double.MaxValue, Double.MinValue, 0.0)
      +    val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) =>
      +      (math.min(pMin, dl), math.max(pMax, dp), pCt + 1)
      +    }
      +    val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults)
      +    results.iterator
      +  }
      +
      +  /**
      +   * Find the global maximum distance between empirical CDF and CDF (i.e. the KS statistic) after
       +   * adjusting local extrema estimates from individual partitions with the number of elements in
      +   * preceding partitions
      +   * @param localData `Array[(Double, Double, Double)]` A local array containing the collected
      +   *                 results of `searchOneSampleCandidates` across all partitions
       +   * @param n `Double` The size of the RDD
       +   * @return The one-sample Kolmogorov-Smirnov statistic
      +   */
      +  private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double)
      +    : Double = {
      +    val initAcc = (Double.MinValue, 0.0)
      +    // adjust differences based on the number of elements preceding it, which should provide
      +    // the correct distance between empirical CDF and CDF
      +    val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) =>
      +      val adjConst = prevCt / n
      +      val dist1 = math.abs(minCand + adjConst)
      +      val dist2 = math.abs(maxCand + adjConst)
      +      val maxVal = Array(prevMax, dist1, dist2).max
      +      (maxVal, prevCt + ct)
      +    }
      +    results._1
      +  }
      +
      +  /**
      +   * A convenience function that allows running the KS test for 1 set of sample data against
      +   * a named distribution
      +   * @param data the sample data that we wish to evaluate
      +   * @param distName the name of the theoretical distribution
       +   * @param params Variable-length parameter for the distribution's parameters
      +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the
      +   *        test results (p-value, statistic, and null hypothesis)
      +   */
      +  @varargs
      +  def testOneSample(data: RDD[Double], distName: String, params: Double*)
      +    : KolmogorovSmirnovTestResult = {
      +    val distObj =
      +      distName match {
      +        case "norm" => {
      +          if (params.nonEmpty) {
       +            // if parameters are passed, there must be exactly 2
      +            require(params.length == 2, "Normal distribution requires mean and standard " +
      +              "deviation as parameters")
      +            new NormalDistribution(params(0), params(1))
      +          } else {
       +            // if no parameters are passed, initialize to standard normal
       +            logInfo("No parameters specified for normal distribution, " +
       +              "initialized to standard normal (i.e. N(0, 1))")
      +            new NormalDistribution(0, 1)
      +          }
      +        }
       +        case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" +
       +          s" the convenience method. Current options are: ['norm'].")
      +      }
      +
      +    testOneSample(data, distObj)
      +  }
      +
      +  private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = {
      +    val pval = 1 - new CommonMathKolmogorovSmirnovTest().cdf(ksStat, n.toInt)
      +    new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString)
      +  }
      +}
      +
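The distributed statistic assembled above reduces, on a single machine, to the maximum distance between the empirical CDF (evaluated just below and just above each sorted point) and the theoretical CDF. A minimal local sketch of that quantity, assuming commons-math3 on the classpath and a toy sample (an illustration, not the MLlib code itself):

```scala
import org.apache.commons.math3.distribution.NormalDistribution

// Toy sorted sample; in the code above the data is a sorted RDD[Double].
val sample = Array(-1.2, -0.3, 0.1, 0.8, 1.5).sorted
val n = sample.length.toDouble
val stdNormal = new NormalDistribution(0, 1)

// For each x_i compare CDF(x_i) with the empirical CDF just below (i / n) and
// just above ((i + 1) / n) the point, mirroring dl and dp in oneSampleDifferences.
val ksStat = sample.zipWithIndex.map { case (x, i) =>
  val cdf = stdNormal.cumulativeProbability(x)
  math.max(math.abs(i / n - cdf), math.abs((i + 1) / n - cdf))
}.max
```

The partitioned version reaches the same value by collecting only per-partition extrema plus element counts and adding back `precedingCount / n` in `searchOneSampleStatistic`.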
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
      index 4784f9e94790..d01b3707be94 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.stat.test
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       
       /**
        * :: Experimental ::
      @@ -25,28 +25,33 @@ import org.apache.spark.annotation.Experimental
        * @tparam DF Return type of `degreesOfFreedom`.
        */
       @Experimental
      +@Since("1.1.0")
       trait TestResult[DF] {
       
         /**
          * The probability of obtaining a test statistic result at least as extreme as the one that was
          * actually observed, assuming that the null hypothesis is true.
          */
      +  @Since("1.1.0")
         def pValue: Double
       
         /**
          * Returns the degree(s) of freedom of the hypothesis test.
          * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility.
          */
      +  @Since("1.1.0")
         def degreesOfFreedom: DF
       
         /**
          * Test statistic.
          */
      +  @Since("1.1.0")
         def statistic: Double
       
         /**
          * Null hypothesis of the test.
          */
      +  @Since("1.1.0")
         def nullHypothesis: String
       
         /**
      @@ -78,11 +83,12 @@ trait TestResult[DF] {
        * Object containing the test results for the chi-squared hypothesis test.
        */
       @Experimental
      +@Since("1.1.0")
       class ChiSqTestResult private[stat] (override val pValue: Double,
      -    override val degreesOfFreedom: Int,
      -    override val statistic: Double,
      -    val method: String,
      -    override val nullHypothesis: String) extends TestResult[Int] {
      +    @Since("1.1.0") override val degreesOfFreedom: Int,
      +    @Since("1.1.0") override val statistic: Double,
      +    @Since("1.1.0") val method: String,
      +    @Since("1.1.0") override val nullHypothesis: String) extends TestResult[Int] {
       
         override def toString: String = {
           "Chi squared test summary:\n" +
      @@ -90,3 +96,22 @@ class ChiSqTestResult private[stat] (override val pValue: Double,
             super.toString
         }
       }
      +
      +/**
      + * :: Experimental ::
      + * Object containing the test results for the Kolmogorov-Smirnov test.
      + */
      +@Experimental
      +@Since("1.5.0")
      +class KolmogorovSmirnovTestResult private[stat] (
      +    @Since("1.5.0") override val pValue: Double,
      +    @Since("1.5.0") override val statistic: Double,
      +    @Since("1.5.0") override val nullHypothesis: String) extends TestResult[Int] {
      +
      +  @Since("1.5.0")
      +  override val degreesOfFreedom = 0
      +
      +  override def toString: String = {
      +    "Kolmogorov-Smirnov test summary:\n" + super.toString
      +  }
      +}
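The new result class exposes only `pValue`, `statistic`, and `nullHypothesis` (with `degreesOfFreedom` fixed at 0), so it plugs into the existing `TestResult` machinery. A hypothetical consumer, with 0.05 as an arbitrary example threshold:

```scala
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult

def summarize(result: KolmogorovSmirnovTestResult, alpha: Double = 0.05): String = {
  if (result.pValue < alpha) {
    s"reject the null hypothesis (D = ${result.statistic}, p = ${result.pValue})"
  } else {
    s"fail to reject the null hypothesis (D = ${result.statistic}, p = ${result.pValue})"
  }
}
```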
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
      index cecd1fed896d..4a77d4adcd86 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
      @@ -22,7 +22,7 @@ import scala.collection.mutable
       import scala.collection.mutable.ArrayBuilder
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
      @@ -44,8 +44,10 @@ import org.apache.spark.util.random.XORShiftRandom
        *                 of algorithm (classification, regression, etc.), feature type (continuous,
        *                 categorical), depth of the tree, quantile calculation strategy, etc.
        */
      +@Since("1.0.0")
       @Experimental
      -class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
      +class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
      +  extends Serializable with Logging {
       
         strategy.assertValid()
       
      @@ -54,6 +56,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
          * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
          * @return DecisionTreeModel that can be used for prediction
          */
      +  @Since("1.2.0")
         def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
           // Note: random seed will not be used since numTrees = 1.
           val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
      @@ -62,6 +65,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
         }
       }
       
      +@Since("1.0.0")
       object DecisionTree extends Serializable with Logging {
       
         /**
      @@ -79,7 +83,8 @@ object DecisionTree extends Serializable with Logging {
          *                 of algorithm (classification, regression, etc.), feature type (continuous,
          *                 categorical), depth of the tree, quantile calculation strategy, etc.
          * @return DecisionTreeModel that can be used for prediction
      -  */
      +   */
       +  @Since("1.0.0")
         def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
           new DecisionTree(strategy).run(input)
         }
      @@ -101,6 +106,7 @@ object DecisionTree extends Serializable with Logging {
          *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
          * @return DecisionTreeModel that can be used for prediction
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             algo: Algo,
      @@ -128,6 +134,7 @@ object DecisionTree extends Serializable with Logging {
          * @param numClasses number of classes for classification. Default value of 2.
          * @return DecisionTreeModel that can be used for prediction
          */
      +  @Since("1.2.0")
         def train(
             input: RDD[LabeledPoint],
             algo: Algo,
      @@ -161,6 +168,7 @@ object DecisionTree extends Serializable with Logging {
          *                                with k categories indexed from 0: {0, 1, ..., k-1}.
          * @return DecisionTreeModel that can be used for prediction
          */
      +  @Since("1.0.0")
         def train(
             input: RDD[LabeledPoint],
             algo: Algo,
      @@ -193,6 +201,7 @@ object DecisionTree extends Serializable with Logging {
          *                 (suggested value: 32)
          * @return DecisionTreeModel that can be used for prediction
          */
      +  @Since("1.1.0")
         def trainClassifier(
             input: RDD[LabeledPoint],
             numClasses: Int,
      @@ -208,6 +217,7 @@ object DecisionTree extends Serializable with Logging {
         /**
          * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
          */
      +  @Since("1.1.0")
         def trainClassifier(
             input: JavaRDD[LabeledPoint],
             numClasses: Int,
      @@ -237,6 +247,7 @@ object DecisionTree extends Serializable with Logging {
          *                 (suggested value: 32)
          * @return DecisionTreeModel that can be used for prediction
          */
      +  @Since("1.1.0")
         def trainRegressor(
             input: RDD[LabeledPoint],
             categoricalFeaturesInfo: Map[Int, Int],
      @@ -250,6 +261,7 @@ object DecisionTree extends Serializable with Logging {
         /**
          * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
          */
      +  @Since("1.1.0")
         def trainRegressor(
             input: JavaRDD[LabeledPoint],
             categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
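For orientation, the `trainClassifier` overloads annotated above are typically called as follows; the parameter values and the `trainingData` RDD are placeholders, not recommendations:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.rdd.RDD

// `trainingData` is assumed to already exist as an RDD[LabeledPoint].
def buildTree(trainingData: RDD[LabeledPoint]) = {
  val numClasses = 2
  val categoricalFeaturesInfo = Map[Int, Int]()  // empty map: all features are continuous
  DecisionTree.trainClassifier(
    trainingData, numClasses, categoricalFeaturesInfo,
    impurity = "gini", maxDepth = 5, maxBins = 32)
}
```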
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
      index a835f96d5d0e..95ed48cea671 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
      @@ -18,8 +18,9 @@
       package org.apache.spark.mllib.tree
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
      +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.configuration.BoostingStrategy
       import org.apache.spark.mllib.tree.configuration.Algo._
      @@ -48,8 +49,9 @@ import org.apache.spark.storage.StorageLevel
        *
        * @param boostingStrategy Parameters for the gradient boosting algorithm.
        */
      +@Since("1.2.0")
       @Experimental
      -class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
      +class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
         extends Serializable with Logging {
       
         /**
      @@ -57,6 +59,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
          * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
          * @return a gradient boosted trees model that can be used for prediction
          */
      +  @Since("1.2.0")
         def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
           val algo = boostingStrategy.treeStrategy.algo
           algo match {
      @@ -74,6 +77,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
         /**
          * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
          */
      +  @Since("1.2.0")
         def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
           run(input.rdd)
         }
      @@ -88,6 +92,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
          *                        by using [[org.apache.spark.rdd.RDD.randomSplit()]]
          * @return a gradient boosted trees model that can be used for prediction
          */
      +  @Since("1.4.0")
         def runWithValidation(
             input: RDD[LabeledPoint],
             validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
      @@ -111,6 +116,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
         /**
          * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]].
          */
      +  @Since("1.4.0")
         def runWithValidation(
             input: JavaRDD[LabeledPoint],
             validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
      @@ -118,6 +124,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
         }
       }
       
      +@Since("1.2.0")
       object GradientBoostedTrees extends Logging {
       
         /**
      @@ -129,6 +136,7 @@ object GradientBoostedTrees extends Logging {
          * @param boostingStrategy Configuration options for the boosting algorithm.
          * @return a gradient boosted trees model that can be used for prediction
          */
      +  @Since("1.2.0")
         def train(
             input: RDD[LabeledPoint],
             boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
      @@ -138,6 +146,7 @@ object GradientBoostedTrees extends Logging {
         /**
          * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
          */
      +  @Since("1.2.0")
         def train(
             input: JavaRDD[LabeledPoint],
             boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
      @@ -184,22 +193,28 @@ object GradientBoostedTrees extends Logging {
             false
           }
       
      +    // Prepare periodic checkpointers
      +    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      +      treeStrategy.getCheckpointInterval, input.sparkContext)
      +    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      +      treeStrategy.getCheckpointInterval, input.sparkContext)
      +
           timer.stop("init")
       
           logDebug("##########")
           logDebug("Building tree 0")
           logDebug("##########")
      -    var data = input
       
           // Initialize tree
           timer.start("building tree 0")
      -    val firstTreeModel = new DecisionTree(treeStrategy).run(data)
      +    val firstTreeModel = new DecisionTree(treeStrategy).run(input)
           val firstTreeWeight = 1.0
           baseLearners(0) = firstTreeModel
           baseLearnerWeights(0) = firstTreeWeight
       
           var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
             computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
      +    predErrorCheckpointer.update(predError)
           logDebug("error of gbt = " + predError.values.mean())
       
           // Note: A model of type regression is used since we require raw prediction
      @@ -207,35 +222,34 @@ object GradientBoostedTrees extends Logging {
       
           var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
             computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
      +    if (validate) validatePredErrorCheckpointer.update(validatePredError)
           var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
           var bestM = 1
       
      -    // pseudo-residual for second iteration
      -    data = predError.zip(input).map { case ((pred, _), point) =>
      -      LabeledPoint(-loss.gradient(pred, point.label), point.features)
      -    }
      -
           var m = 1
      -    while (m < numIterations) {
      +    var doneLearning = false
      +    while (m < numIterations && !doneLearning) {
      +      // Update data with pseudo-residuals
      +      val data = predError.zip(input).map { case ((pred, _), point) =>
      +        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      +      }
      +
             timer.start(s"building tree $m")
             logDebug("###################################################")
             logDebug("Gradient boosting tree iteration " + m)
             logDebug("###################################################")
             val model = new DecisionTree(treeStrategy).run(data)
             timer.stop(s"building tree $m")
      -      // Create partial model
      +      // Update partial model
             baseLearners(m) = model
             // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
             //       Technically, the weight should be optimized for the particular loss.
             //       However, the behavior should be reasonable, though not optimal.
             baseLearnerWeights(m) = learningRate
      -      // Note: A model of type regression is used since we require raw prediction
      -      val partialModel = new GradientBoostedTreesModel(
      -        Regression, baseLearners.slice(0, m + 1),
      -        baseLearnerWeights.slice(0, m + 1))
       
             predError = GradientBoostedTreesModel.updatePredictionError(
               input, predError, baseLearnerWeights(m), baseLearners(m), loss)
      +      predErrorCheckpointer.update(predError)
             logDebug("error of gbt = " + predError.values.mean())
       
             if (validate) {
      @@ -246,21 +260,15 @@ object GradientBoostedTrees extends Logging {
       
               validatePredError = GradientBoostedTreesModel.updatePredictionError(
                 validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
      +        validatePredErrorCheckpointer.update(validatePredError)
               val currentValidateError = validatePredError.values.mean()
               if (bestValidateError - currentValidateError < validationTol) {
      -          return new GradientBoostedTreesModel(
      -            boostingStrategy.treeStrategy.algo,
      -            baseLearners.slice(0, bestM),
      -            baseLearnerWeights.slice(0, bestM))
      +          doneLearning = true
               } else if (currentValidateError < bestValidateError) {
      -            bestValidateError = currentValidateError
      -            bestM = m + 1
      +          bestValidateError = currentValidateError
      +          bestM = m + 1
               }
             }
      -      // Update data with pseudo-residuals
      -      data = predError.zip(input).map { case ((pred, _), point) =>
      -        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      -      }
             m += 1
           }
       
      @@ -269,6 +277,8 @@ object GradientBoostedTrees extends Logging {
           logInfo("Internal timing for DecisionTree:")
           logInfo(s"$timer")
       
      +    predErrorCheckpointer.deleteAllCheckpoints()
      +    validatePredErrorCheckpointer.deleteAllCheckpoints()
           if (persistedInput) input.unpersist()
       
           if (validate) {
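The checkpointer and early-stopping changes above are internal; callers still go through `run` or `runWithValidation`. A minimal sketch of the validated path (the split ratio and iteration cap are arbitrary placeholders):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.rdd.RDD

def trainWithEarlyStopping(data: RDD[LabeledPoint]) = {
  val Array(train, validation) = data.randomSplit(Array(0.8, 0.2))
  val boostingStrategy = BoostingStrategy.defaultParams("Regression")
  boostingStrategy.setNumIterations(50)  // upper bound; validation error may stop training earlier
  new GradientBoostedTrees(boostingStrategy).runWithValidation(train, validation)
}
```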
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
      index 069959976a18..63a902f3eb51 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
      @@ -23,7 +23,7 @@ import scala.collection.mutable
       import scala.collection.JavaConverters._
       
       import org.apache.spark.Logging
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.configuration.Strategy
      @@ -260,6 +260,7 @@ private class RandomForest (
       
       }
       
      +@Since("1.2.0")
       object RandomForest extends Serializable with Logging {
       
         /**
      @@ -277,6 +278,7 @@ object RandomForest extends Serializable with Logging {
          * @param seed  Random seed for bootstrapping and choosing feature subsets.
          * @return a random forest model that can be used for prediction
          */
      +  @Since("1.2.0")
         def trainClassifier(
             input: RDD[LabeledPoint],
             strategy: Strategy,
      @@ -314,6 +316,7 @@ object RandomForest extends Serializable with Logging {
          * @param seed  Random seed for bootstrapping and choosing feature subsets.
          * @return a random forest model  that can be used for prediction
          */
      +  @Since("1.2.0")
         def trainClassifier(
             input: RDD[LabeledPoint],
             numClasses: Int,
      @@ -333,6 +336,7 @@ object RandomForest extends Serializable with Logging {
         /**
          * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]]
          */
      +  @Since("1.2.0")
         def trainClassifier(
             input: JavaRDD[LabeledPoint],
             numClasses: Int,
      @@ -363,6 +367,7 @@ object RandomForest extends Serializable with Logging {
          * @param seed  Random seed for bootstrapping and choosing feature subsets.
          * @return a random forest model that can be used for prediction
          */
      +  @Since("1.2.0")
         def trainRegressor(
             input: RDD[LabeledPoint],
             strategy: Strategy,
      @@ -399,6 +404,7 @@ object RandomForest extends Serializable with Logging {
          * @param seed  Random seed for bootstrapping and choosing feature subsets.
          * @return a random forest model that can be used for prediction
          */
      +  @Since("1.2.0")
         def trainRegressor(
             input: RDD[LabeledPoint],
             categoricalFeaturesInfo: Map[Int, Int],
      @@ -417,6 +423,7 @@ object RandomForest extends Serializable with Logging {
         /**
          * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]]
          */
      +  @Since("1.2.0")
         def trainRegressor(
             input: JavaRDD[LabeledPoint],
             categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
      @@ -434,6 +441,7 @@ object RandomForest extends Serializable with Logging {
         /**
          * List of supported feature subset sampling strategies.
          */
      +  @Since("1.2.0")
         val supportedFeatureSubsetStrategies: Array[String] =
           Array("auto", "all", "sqrt", "log2", "onethird")
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
      index b6099259971b..853c7319ec44 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
      @@ -17,15 +17,18 @@
       
       package org.apache.spark.mllib.tree.configuration
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       
       /**
        * :: Experimental ::
        * Enum to select the algorithm for the decision tree
        */
      +@Since("1.0.0")
       @Experimental
       object Algo extends Enumeration {
      +  @Since("1.0.0")
         type Algo = Value
      +  @Since("1.0.0")
         val Classification, Regression = Value
       
         private[mllib] def fromString(name: String): Algo = name match {
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
      index 2d6b01524ff3..b5c72fba3ede 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.configuration
       
       import scala.beans.BeanProperty
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.tree.configuration.Algo._
       import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
       
      @@ -36,17 +36,19 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
        *                     learning rate should be between in the interval (0, 1]
        * @param validationTol Useful when runWithValidation is used. If the error rate on the
        *                      validation input between two iterations is less than the validationTol
      - *                      then stop. Ignored when [[run]] is used.
      + *                      then stop.  Ignored when
      + *                      [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
        */
      +@Since("1.2.0")
       @Experimental
      -case class BoostingStrategy(
      +case class BoostingStrategy @Since("1.4.0") (
           // Required boosting parameters
      -    @BeanProperty var treeStrategy: Strategy,
      -    @BeanProperty var loss: Loss,
      +    @Since("1.2.0") @BeanProperty var treeStrategy: Strategy,
      +    @Since("1.2.0") @BeanProperty var loss: Loss,
           // Optional boosting parameters
      -    @BeanProperty var numIterations: Int = 100,
      -    @BeanProperty var learningRate: Double = 0.1,
      -    @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
      +    @Since("1.2.0") @BeanProperty var numIterations: Int = 100,
      +    @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
      +    @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
       
         /**
          * Check validity of parameters.
      @@ -69,6 +71,7 @@ case class BoostingStrategy(
         }
       }
       
      +@Since("1.2.0")
       @Experimental
       object BoostingStrategy {
       
      @@ -77,6 +80,7 @@ object BoostingStrategy {
          * @param algo Learning goal.  Supported: "Classification" or "Regression"
          * @return Configuration for boosting algorithm
          */
      +  @Since("1.2.0")
         def defaultParams(algo: String): BoostingStrategy = {
           defaultParams(Algo.fromString(algo))
         }
      @@ -88,8 +92,9 @@ object BoostingStrategy {
          *             [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
          * @return Configuration for boosting algorithm
          */
      +  @Since("1.3.0")
         def defaultParams(algo: Algo): BoostingStrategy = {
      -    val treeStrategy = Strategy.defaultStategy(algo)
      +    val treeStrategy = Strategy.defaultStrategy(algo)
           treeStrategy.maxDepth = 3
           algo match {
             case Algo.Classification =>
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
      index f4c877232750..4e0cd473def0 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
      @@ -17,14 +17,17 @@
       
       package org.apache.spark.mllib.tree.configuration
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       
       /**
        * :: Experimental ::
        * Enum to describe whether a feature is "continuous" or "categorical"
        */
      +@Since("1.0.0")
       @Experimental
       object FeatureType extends Enumeration {
      +  @Since("1.0.0")
         type FeatureType = Value
      +  @Since("1.0.0")
         val Continuous, Categorical = Value
       }
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
      index 7da976e55a72..8262db8a4f11 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
      @@ -17,14 +17,17 @@
       
       package org.apache.spark.mllib.tree.configuration
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       
       /**
        * :: Experimental ::
        * Enum for selecting the quantile calculation strategy
        */
      +@Since("1.0.0")
       @Experimental
       object QuantileStrategy extends Enumeration {
      +  @Since("1.0.0")
         type QuantileStrategy = Value
      +  @Since("1.0.0")
         val Sort, MinMax, ApproxHist = Value
       }
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
      index ada227c200a7..89cc13b7c06c 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
      @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
       import scala.beans.BeanProperty
       import scala.collection.JavaConverters._
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
       import org.apache.spark.mllib.tree.configuration.Algo._
       import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
      @@ -67,26 +67,33 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
        *                           the checkpoint directory is not set in
        *                           [[org.apache.spark.SparkContext]], this setting is ignored.
        */
      +@Since("1.0.0")
       @Experimental
      -class Strategy (
      -    @BeanProperty var algo: Algo,
      -    @BeanProperty var impurity: Impurity,
      -    @BeanProperty var maxDepth: Int,
      -    @BeanProperty var numClasses: Int = 2,
      -    @BeanProperty var maxBins: Int = 32,
      -    @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
      -    @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
      -    @BeanProperty var minInstancesPerNode: Int = 1,
      -    @BeanProperty var minInfoGain: Double = 0.0,
      -    @BeanProperty var maxMemoryInMB: Int = 256,
      -    @BeanProperty var subsamplingRate: Double = 1,
      -    @BeanProperty var useNodeIdCache: Boolean = false,
      -    @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
      +class Strategy @Since("1.3.0") (
      +    @Since("1.0.0") @BeanProperty var algo: Algo,
      +    @Since("1.0.0") @BeanProperty var impurity: Impurity,
      +    @Since("1.0.0") @BeanProperty var maxDepth: Int,
      +    @Since("1.2.0") @BeanProperty var numClasses: Int = 2,
      +    @Since("1.0.0") @BeanProperty var maxBins: Int = 32,
      +    @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
      +    @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
      +    @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
      +    @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
      +    @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
      +    @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
      +    @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
      +    @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
       
       +  /**
       +   * Whether the configured algorithm is multiclass classification (more than 2 classes).
       +   */
      +  @Since("1.2.0")
         def isMulticlassClassification: Boolean = {
           algo == Classification && numClasses > 2
         }
       
       +  /**
       +   * Whether the problem is multiclass classification and uses categorical features.
       +   */
      +  @Since("1.2.0")
         def isMulticlassWithCategoricalFeatures: Boolean = {
           isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
         }
      @@ -94,6 +101,7 @@ class Strategy (
         /**
          * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
          */
      +  @Since("1.1.0")
         def this(
             algo: Algo,
             impurity: Impurity,
      @@ -108,6 +116,7 @@ class Strategy (
         /**
          * Sets Algorithm using a String.
          */
      +  @Since("1.2.0")
         def setAlgo(algo: String): Unit = algo match {
           case "Classification" => setAlgo(Classification)
           case "Regression" => setAlgo(Regression)
      @@ -116,6 +125,7 @@ class Strategy (
         /**
          * Sets categoricalFeaturesInfo using a Java Map.
          */
      +  @Since("1.2.0")
         def setCategoricalFeaturesInfo(
             categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
           this.categoricalFeaturesInfo =
      @@ -148,11 +158,6 @@ class Strategy (
             s"  Valid values are integers >= 0.")
           require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." +
             s"  Valid values are integers >= 2.")
      -    categoricalFeaturesInfo.foreach { case (feature, arity) =>
      -      require(arity >= 2,
      -        s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
      -        s" feature $feature has $arity categories.  The number of categories should be >= 2.")
      -    }
           require(minInstancesPerNode >= 1,
             s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
           require(maxMemoryInMB <= 10240,
      @@ -162,7 +167,10 @@ class Strategy (
             s"$subsamplingRate")
         }
       
      -  /** Returns a shallow copy of this instance. */
      +  /**
      +   * Returns a shallow copy of this instance.
      +   */
      +  @Since("1.2.0")
         def copy: Strategy = {
           new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
             quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
      @@ -170,6 +178,7 @@ class Strategy (
         }
       }
       
      +@Since("1.2.0")
       @Experimental
       object Strategy {
       
      @@ -177,15 +186,17 @@ object Strategy {
          * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
          * @param algo  "Classification" or "Regression"
          */
      +  @Since("1.2.0")
         def defaultStrategy(algo: String): Strategy = {
      -    defaultStategy(Algo.fromString(algo))
      +    defaultStrategy(Algo.fromString(algo))
         }
       
         /**
          * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
          * @param algo Algo.Classification or Algo.Regression
          */
      -  def defaultStategy(algo: Algo): Strategy = algo match {
      +  @Since("1.3.0")
      +  def defaultStrategy(algo: Algo): Strategy = algo match {
           case Algo.Classification =>
             new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
               numClasses = 2)
      @@ -193,4 +204,9 @@ object Strategy {
             new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
               numClasses = 0)
         }
      +
      +  @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0")
      +  @Since("1.2.0")
      +  def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo)
      +
       }
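The rename keeps the old misspelled name as a deprecated forwarder, so existing callers continue to compile. Roughly:

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}

val strategy = Strategy.defaultStrategy(Algo.Classification)  // corrected spelling
strategy.setMaxBins(64)                                       // BeanProperty setter

// Still compiles, but now emits a deprecation warning and simply forwards:
val legacy = Strategy.defaultStategy(Algo.Classification)
```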
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
      index 089010c81ffb..572815df0bc4 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
      @@ -38,10 +38,10 @@ import org.apache.spark.util.random.XORShiftRandom
        * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
        *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
        */
      -private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
      +private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
         extends Serializable
       
      -private[tree] object BaggedPoint {
      +private[spark] object BaggedPoint {
       
         /**
          * Convert an input dataset into its BaggedPoint representation,
      @@ -60,7 +60,7 @@ private[tree] object BaggedPoint {
             subsamplingRate: Double,
             numSubsamples: Int,
             withReplacement: Boolean,
      -      seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
      +      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
           if (withReplacement) {
             convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
           } else {
      @@ -76,7 +76,7 @@ private[tree] object BaggedPoint {
             input: RDD[Datum],
             subsamplingRate: Double,
             numSubsamples: Int,
      -      seed: Int): RDD[BaggedPoint[Datum]] = {
      +      seed: Long): RDD[BaggedPoint[Datum]] = {
           input.mapPartitionsWithIndex { (partitionIndex, instances) =>
             // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
             val rng = new XORShiftRandom
      @@ -100,7 +100,7 @@ private[tree] object BaggedPoint {
             input: RDD[Datum],
             subsample: Double,
             numSubsamples: Int,
      -      seed: Int): RDD[BaggedPoint[Datum]] = {
      +      seed: Long): RDD[BaggedPoint[Datum]] = {
           input.mapPartitionsWithIndex { (partitionIndex, instances) =>
             // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
             val poisson = new PoissonDistribution(subsample)
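Widening `seed` to `Long` and deriving per-partition seeds as `seed + partitionIndex + 1` keeps the subsampling reproducible. A rough sketch of the with-replacement weights for a single partition (the concrete values are placeholders, not the actual implementation):

```scala
import org.apache.commons.math3.distribution.PoissonDistribution

val seed: Long = 42L          // placeholder for the caller-supplied seed
val partitionIndex = 0        // placeholder for the RDD partition index
val numSubsamples = 3         // e.g. the number of trees being grown
val subsamplingRate = 1.0

val poisson = new PoissonDistribution(subsamplingRate)
poisson.reseedRandomGenerator(seed + partitionIndex + 1)  // reproducible per partition
// One Poisson-distributed weight per subsample, i.e. bootstrap sampling with replacement.
val subsampleWeights = Array.fill(numSubsamples)(poisson.sample().toDouble)
```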
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
      index ce8825cc0322..7985ed4b4c0f 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
      @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.impurity._
        * and helps with indexing.
        * This class is abstract to support learning with and without feature subsampling.
        */
      -private[tree] class DTStatsAggregator(
      +private[spark] class DTStatsAggregator(
           val metadata: DecisionTreeMetadata,
           featureSubset: Option[Array[Int]]) extends Serializable {
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
      index f73896e37c05..21ee49c45788 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
      @@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD
        *                      I.e., the feature takes values in {0, ..., arity - 1}.
        * @param numBins  Number of bins for each feature.
        */
      -private[tree] class DecisionTreeMetadata(
      +private[spark] class DecisionTreeMetadata(
           val numFeatures: Int,
           val numExamples: Long,
           val numClasses: Int,
      @@ -94,7 +94,7 @@ private[tree] class DecisionTreeMetadata(
       
       }
       
      -private[tree] object DecisionTreeMetadata extends Logging {
      +private[spark] object DecisionTreeMetadata extends Logging {
       
         /**
          * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
      @@ -128,9 +128,13 @@ private[tree] object DecisionTreeMetadata extends Logging {
           // based on the number of training examples.
           if (strategy.categoricalFeaturesInfo.nonEmpty) {
             val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
      +      val maxCategory =
      +        strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
             require(maxCategoriesPerFeature <= maxPossibleBins,
      -        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
      -          s"in categorical features (= $maxCategoriesPerFeature)")
      +        s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
      +        s"number of values in each categorical feature, but categorical feature $maxCategory " +
      +        s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
      +        "features with a large number of values, or add more training examples.")
           }
       
           val unorderedFeatures = new mutable.HashSet[Int]()
      @@ -140,21 +144,28 @@ private[tree] object DecisionTreeMetadata extends Logging {
             val maxCategoriesForUnorderedFeature =
               ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
             strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
      -        // Decide if some categorical features should be treated as unordered features,
      -        //  which require 2 * ((1 << numCategories - 1) - 1) bins.
      -        // We do this check with log values to prevent overflows in case numCategories is large.
      -        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
      -        if (numCategories <= maxCategoriesForUnorderedFeature) {
      -          unorderedFeatures.add(featureIndex)
      -          numBins(featureIndex) = numUnorderedBins(numCategories)
      -        } else {
      -          numBins(featureIndex) = numCategories
      +        // Hack: If a categorical feature has only 1 category, we treat it as continuous.
      +        // TODO(SPARK-9957): Handle this properly by filtering out those features.
      +        if (numCategories > 1) {
      +          // Decide if some categorical features should be treated as unordered features,
      +          //  which require 2 * ((1 << numCategories - 1) - 1) bins.
      +          // We do this check with log values to prevent overflows in case numCategories is large.
      +          // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
      +          if (numCategories <= maxCategoriesForUnorderedFeature) {
      +            unorderedFeatures.add(featureIndex)
      +            numBins(featureIndex) = numUnorderedBins(numCategories)
      +          } else {
      +            numBins(featureIndex) = numCategories
      +          }
               }
             }
           } else {
             // Binary classification or regression
             strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
      -        numBins(featureIndex) = numCategories
      +        // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
      +        if (numCategories > 1) {
      +          numBins(featureIndex) = numCategories
      +        }
             }
           }
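To see why the unordered branch is guarded this way, the bin arithmetic works out as follows for the default `maxBins = 32` (a standalone re-derivation of the bound used above, not the metadata code itself):

```scala
val maxPossibleBins = 32
// Largest arity M such that 2 * ((1 << (M - 1)) - 1) <= maxPossibleBins,
// computed with logs exactly as in DecisionTreeMetadata.
val maxCategoriesForUnorderedFeature =
  ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt  // = 5

def numUnorderedBins(arity: Int): Int = 2 * ((1 << (arity - 1)) - 1)

numUnorderedBins(5)  // 30 <= 32: treated as unordered in multiclass problems
numUnorderedBins(6)  // 62 >  32: falls back to an ordered categorical feature
```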
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
      index bdd0f576b048..8f9eb24b57b5 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
      @@ -75,7 +75,7 @@ private[tree] case class NodeIndexUpdater(
        *                           (how often should the cache be checkpointed.).
        */
       @DeveloperApi
      -private[tree] class NodeIdCache(
      +private[spark] class NodeIdCache(
         var nodeIdsForInstances: RDD[Array[Int]],
         val checkpointInterval: Int) {
       
      @@ -170,7 +170,7 @@ private[tree] class NodeIdCache(
       }
       
       @DeveloperApi
      -private[tree] object NodeIdCache {
      +private[spark] object NodeIdCache {
         /**
          * Initialize the node Id cache with initial node Id values.
          * @param data The RDD of training rows.
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
      index d215d68c4279..aac84243d5ce 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
      @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental
        * Time tracker implementation which holds labeled timers.
        */
       @Experimental
      -private[tree] class TimeTracker extends Serializable {
      +private[spark] class TimeTracker extends Serializable {
       
         private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
      index 50b292e71b06..21919d69a38a 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
      @@ -37,11 +37,11 @@ import org.apache.spark.rdd.RDD
        * @param binnedFeatures  Binned feature values.
        *                        Same length as LabeledPoint.features, but values are bin indices.
        */
      -private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
      +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
         extends Serializable {
       }
       
      -private[tree] object TreePoint {
      +private[spark] object TreePoint {
       
         /**
          * Convert an input dataset into its TreePoint representation,
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
      index 5ac10f3fd32d..73df6b054a8c 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
      @@ -17,13 +17,14 @@
       
       package org.apache.spark.mllib.tree.impurity
       
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       
       /**
        * :: Experimental ::
        * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
        * binary classification.
        */
      +@Since("1.0.0")
       @Experimental
       object Entropy extends Impurity {
       
      @@ -36,6 +37,7 @@ object Entropy extends Impurity {
          * @param totalCount sum of counts for all labels
          * @return information value, or 0 if totalCount = 0
          */
      +  @Since("1.1.0")
         @DeveloperApi
         override def calculate(counts: Array[Double], totalCount: Double): Double = {
           if (totalCount == 0) {
      @@ -63,6 +65,7 @@ object Entropy extends Impurity {
          * @param sumSquares summation of squares of the labels
          * @return information value, or 0 if count = 0
          */
      +  @Since("1.0.0")
         @DeveloperApi
         override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
           throw new UnsupportedOperationException("Entropy.calculate")
      @@ -71,6 +74,7 @@ object Entropy extends Impurity {
          * Get this impurity instance.
          * This is useful for passing impurity parameters to a Strategy in Java.
          */
      +  @Since("1.1.0")
         def instance: this.type = this
       
       }
      @@ -118,7 +122,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
        * (node, feature, bin).
        * @param stats  Array of sufficient statistics for a (node, feature, bin).
        */
      -private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
      +private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
       
         /**
          * Make a deep copy of this [[ImpurityCalculator]].
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
      index 19d318203c34..f21845b21a80 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.impurity
       
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       
       /**
        * :: Experimental ::
      @@ -25,6 +25,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
        * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
        * during binary classification.
        */
      +@Since("1.0.0")
       @Experimental
       object Gini extends Impurity {
       
      @@ -35,6 +36,7 @@ object Gini extends Impurity {
          * @param totalCount sum of counts for all labels
          * @return information value, or 0 if totalCount = 0
          */
      +  @Since("1.1.0")
         @DeveloperApi
         override def calculate(counts: Array[Double], totalCount: Double): Double = {
           if (totalCount == 0) {
      @@ -59,6 +61,7 @@ object Gini extends Impurity {
          * @param sumSquares summation of squares of the labels
          * @return information value, or 0 if count = 0
          */
      +  @Since("1.0.0")
         @DeveloperApi
         override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
           throw new UnsupportedOperationException("Gini.calculate")
      @@ -67,6 +70,7 @@ object Gini extends Impurity {
          * Get this impurity instance.
          * This is useful for passing impurity parameters to a Strategy in Java.
          */
      +  @Since("1.1.0")
         def instance: this.type = this
       
       }
      @@ -114,7 +118,7 @@ private[tree] class GiniAggregator(numClasses: Int)
        * (node, feature, bin).
        * @param stats  Array of sufficient statistics for a (node, feature, bin).
        */
      -private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
      +private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
       
         /**
          * Make a deep copy of this [[ImpurityCalculator]].
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
      index 72eb24c49264..4637dcceea7f 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.impurity
       
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       
       /**
        * :: Experimental ::
      @@ -26,6 +26,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
        *  (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]]
        *  (b) calculating impurity values from sufficient statistics.
        */
      +@Since("1.0.0")
       @Experimental
       trait Impurity extends Serializable {
       
      @@ -36,6 +37,7 @@ trait Impurity extends Serializable {
          * @param totalCount sum of counts for all labels
          * @return information value, or 0 if totalCount = 0
          */
      +  @Since("1.1.0")
         @DeveloperApi
         def calculate(counts: Array[Double], totalCount: Double): Double
       
      @@ -47,6 +49,7 @@ trait Impurity extends Serializable {
          * @param sumSquares summation of squares of the labels
          * @return information value, or 0 if count = 0
          */
      +  @Since("1.0.0")
         @DeveloperApi
         def calculate(count: Double, sum: Double, sumSquares: Double): Double
       }
      @@ -57,7 +60,7 @@ trait Impurity extends Serializable {
        * Note: Instances of this class do not hold the data; they operate on views of the data.
        * @param statsSize  Length of the vector of sufficient statistics for one bin.
        */
      -private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
      +private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
       
         /**
          * Merge the stats from one bin into another.
      @@ -95,7 +98,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri
        * (node, feature, bin).
        * @param stats  Array of sufficient statistics for a (node, feature, bin).
        */
      -private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
      +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
       
         /**
          * Make a deep copy of this [[ImpurityCalculator]].
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
      index 7104a7fa4dd4..a74197278d6f 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
      @@ -17,12 +17,13 @@
       
       package org.apache.spark.mllib.tree.impurity
       
      -import org.apache.spark.annotation.{DeveloperApi, Experimental}
      +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
       
       /**
        * :: Experimental ::
        * Class for calculating variance during regression
        */
      +@Since("1.0.0")
       @Experimental
       object Variance extends Impurity {
       
      @@ -33,6 +34,7 @@ object Variance extends Impurity {
          * @param totalCount sum of counts for all labels
          * @return information value, or 0 if totalCount = 0
          */
      +  @Since("1.1.0")
         @DeveloperApi
         override def calculate(counts: Array[Double], totalCount: Double): Double =
            throw new UnsupportedOperationException("Variance.calculate")
      @@ -45,6 +47,7 @@ object Variance extends Impurity {
          * @param sumSquares summation of squares of the labels
          * @return information value, or 0 if count = 0
          */
      +  @Since("1.0.0")
         @DeveloperApi
         override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
           if (count == 0) {
      @@ -58,6 +61,7 @@ object Variance extends Impurity {
          * Get this impurity instance.
          * This is useful for passing impurity parameters to a Strategy in Java.
          */
      +  @Since("1.0.0")
         def instance: this.type = this
       
       }
      @@ -98,7 +102,7 @@ private[tree] class VarianceAggregator()
        * (node, feature, bin).
        * @param stats  Array of sufficient statistics for a (node, feature, bin).
        */
      -private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
      +private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
       
         require(stats.size == 3,
           s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
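
The regression overload computes the variance of the labels from (count, sum, sumSquares); a small worked check (labels 1.0, 2.0, 3.0 are illustrative):

    import org.apache.spark.mllib.tree.impurity.Variance

    // sumSquares/count - (sum/count)^2 = 14/3 - 4 = 2/3
    val v: Double = Variance.calculate(3.0, 6.0, 14.0)
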
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
      index 2bdef73c4a8f..bab7b8c6cadf 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.loss
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.model.TreeEnsembleModel
       
      @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel
        *  |y - F(x)|
        * where y is the label and F(x) is the model prediction for features x.
        */
      +@Since("1.2.0")
       @DeveloperApi
       object AbsoluteError extends Loss {
       
      @@ -41,6 +42,7 @@ object AbsoluteError extends Loss {
          * @param label True label.
          * @return Loss gradient
          */
      +  @Since("1.2.0")
         override def gradient(prediction: Double, label: Double): Double = {
           if (label - prediction < 0) 1.0 else -1.0
         }
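
The gradient of |y - F(x)| with respect to F(x) is simply the sign of (F(x) - y); a quick sketch with illustrative values:

    import org.apache.spark.mllib.tree.loss.AbsoluteError

    AbsoluteError.gradient(3.0, 5.0)  // prediction below the label -> -1.0
    AbsoluteError.gradient(5.0, 3.0)  // prediction above the label ->  1.0
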
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
      index 778c24526de7..b2b4594712f0 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.loss
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.model.TreeEnsembleModel
       import org.apache.spark.mllib.util.MLUtils
      @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils
        *   2 log(1 + exp(-2 y F(x)))
        * where y is a label in {-1, 1} and F(x) is the model prediction for features x.
        */
      +@Since("1.2.0")
       @DeveloperApi
       object LogLoss extends Loss {
       
      @@ -43,6 +44,7 @@ object LogLoss extends Loss {
          * @param label True label.
          * @return Loss gradient
          */
      +  @Since("1.2.0")
         override def gradient(prediction: Double, label: Double): Double = {
           - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
         }
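
For log loss, the gradient -4 y / (1 + exp(2 y F(x))) shrinks toward zero as a prediction becomes confidently correct; an illustrative evaluation:

    import org.apache.spark.mllib.tree.loss.LogLoss

    // label = 1, F(x) = 2: -4 / (1 + exp(4)) ≈ -0.072
    val g: Double = LogLoss.gradient(2.0, 1.0)
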
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
      index 64ffccbce073..687cde325ffe 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.loss
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.model.TreeEnsembleModel
       import org.apache.spark.rdd.RDD
      @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
        * :: DeveloperApi ::
        * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
        */
      +@Since("1.2.0")
       @DeveloperApi
       trait Loss extends Serializable {
       
      @@ -36,6 +37,7 @@ trait Loss extends Serializable {
          * @param label true label.
          * @return Loss gradient.
          */
      +  @Since("1.2.0")
         def gradient(prediction: Double, label: Double): Double
       
         /**
      @@ -46,6 +48,7 @@ trait Loss extends Serializable {
          * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
          * @return Measure of model error on data
          */
      +  @Since("1.2.0")
         def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
           data.map(point => computeError(model.predict(point.features), point.label)).mean()
         }
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
      index 42c9ead9884b..2b112fbe1220 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
      @@ -17,8 +17,12 @@
       
       package org.apache.spark.mllib.tree.loss
       
      +import org.apache.spark.annotation.Since
      +
      +@Since("1.2.0")
       object Losses {
       
      +  @Since("1.2.0")
         def fromString(name: String): Loss = name match {
           case "leastSquaresError" => SquaredError
           case "leastAbsoluteError" => AbsoluteError
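
A minimal sketch of the factory in use, covering the two loss names shown above:

    import org.apache.spark.mllib.tree.loss.{AbsoluteError, Losses, SquaredError}

    assert(Losses.fromString("leastSquaresError") == SquaredError)
    assert(Losses.fromString("leastAbsoluteError") == AbsoluteError)
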
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
      index a5582d3ef332..3f7d3d38be16 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.loss
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.model.TreeEnsembleModel
       
      @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel
        *   (y - F(x))**2
        * where y is the label and F(x) is the model prediction for features x.
        */
      +@Since("1.2.0")
       @DeveloperApi
       object SquaredError extends Loss {
       
      @@ -41,12 +42,13 @@ object SquaredError extends Loss {
          * @param label True label.
          * @return Loss gradient
          */
      +  @Since("1.2.0")
         override def gradient(prediction: Double, label: Double): Double = {
      -    2.0 * (prediction - label)
      +    - 2.0 * (label - prediction)
         }
       
         override private[mllib] def computeError(prediction: Double, label: Double): Double = {
      -    val err = prediction - label
      +    val err = label - prediction
           err * err
         }
       }
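
The rewritten gradient -2 (y - F(x)) is algebraically identical to the previous 2 (F(x) - y); a quick check with illustrative values:

    import org.apache.spark.mllib.tree.loss.SquaredError

    // prediction = 1.5, label = 1.0: -2 * (1.0 - 1.5) = 1.0, same as 2 * (1.5 - 1.0)
    val g: Double = SquaredError.gradient(1.5, 1.0)
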
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
      index 25bb1453db40..e1bf23f4c34b 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
      @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
       import org.json4s.jackson.JsonMethods._
       
       import org.apache.spark.{Logging, SparkContext}
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
      @@ -41,8 +41,11 @@ import org.apache.spark.util.Utils
        * @param topNode root node
        * @param algo algorithm type -- classification or regression
        */
      +@Since("1.0.0")
       @Experimental
      -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
      +class DecisionTreeModel @Since("1.0.0") (
      +    @Since("1.0.0") val topNode: Node,
      +    @Since("1.0.0") val algo: Algo) extends Serializable with Saveable {
       
         /**
          * Predict values for a single data point using the model trained.
      @@ -50,6 +53,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
          * @param features array representing a single data point
          * @return Double prediction from the trained model
          */
      +  @Since("1.0.0")
         def predict(features: Vector): Double = {
           topNode.predict(features)
         }
      @@ -60,6 +64,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
          * @param features RDD representing data points to be predicted
          * @return RDD of predictions for each of the given data points
          */
      +  @Since("1.0.0")
         def predict(features: RDD[Vector]): RDD[Double] = {
           features.map(x => predict(x))
         }
      @@ -70,6 +75,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
          * @param features JavaRDD representing data points to be predicted
          * @return JavaRDD of predictions for each of the given data points
          */
      +  @Since("1.2.0")
         def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
           predict(features.rdd)
         }
      @@ -77,6 +83,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
         /**
          * Get number of nodes in tree, including leaf nodes.
          */
      +  @Since("1.1.0")
         def numNodes: Int = {
           1 + topNode.numDescendants
         }
      @@ -85,6 +92,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
          * Get depth of tree.
          * E.g.: Depth 0 means 1 leaf node.  Depth 1 means 1 internal node and 2 leaf nodes.
          */
      +  @Since("1.1.0")
         def depth: Int = {
           topNode.subtreeDepth
         }
      @@ -104,11 +112,18 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
         /**
          * Print the full model to a string.
          */
      +  @Since("1.2.0")
         def toDebugString: String = {
           val header = toString + "\n"
           header + topNode.subtreeToString(2)
         }
       
      +  /**
      +   * @param sc  Spark context used to save model data.
      +   * @param path  Path specifying the directory in which to save this model.
      +   *              If the directory already exists, this method throws an exception.
      +   */
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
         }
      @@ -116,6 +131,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
         override protected def formatVersion: String = DecisionTreeModel.formatVersion
       }
       
      +@Since("1.3.0")
       object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
       
         private[spark] def formatVersion: String = "1.0"
      @@ -198,7 +214,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
               val driverMemory = sc.getConf.getOption("spark.driver.memory")
                 .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
                 .map(Utils.memoryStringToMb)
      -          .getOrElse(512)
      +          .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB)
               if (driverMemory <= memThreshold) {
                 logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
                   s" driver memory (${driverMemory}m)." +
      @@ -297,6 +313,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
           }
         }
       
      +  /**
      +   *
      +   * @param sc  Spark context used for loading model files.
      +   * @param path  Path specifying the directory to which the model was saved.
      +   * @return  Model instance
      +   */
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): DecisionTreeModel = {
           implicit val formats = DefaultFormats
           val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
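
A minimal end-to-end sketch of the save/load round trip documented above (assumes a live sc: SparkContext; the training data and output path are illustrative):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.model.DecisionTreeModel

    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0))))
    // numClasses = 2, no categorical features, Gini impurity, maxDepth = 2, maxBins = 32
    val model = DecisionTree.trainClassifier(training, 2, Map[Int, Int](), "gini", 2, 32)
    model.save(sc, "/tmp/myDecisionTreeModel")
    val sameModel = DecisionTreeModel.load(sc, "/tmp/myDecisionTreeModel")
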
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
      index 2d087c967f67..091a0462c204 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.mllib.tree.model
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
      +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
       
       /**
        * :: DeveloperApi ::
      @@ -29,6 +30,7 @@ import org.apache.spark.annotation.DeveloperApi
        * @param leftPredict left node predict
        * @param rightPredict right node predict
        */
      +@Since("1.0.0")
       @DeveloperApi
       class InformationGainStats(
           val gain: Double,
      @@ -66,8 +68,7 @@ class InformationGainStats(
         }
       }
       
      -
      -private[tree] object InformationGainStats {
      +private[spark] object InformationGainStats {
         /**
          * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
          * denote that current split doesn't satisfies minimum info gain or
      @@ -76,3 +77,62 @@ private[tree] object InformationGainStats {
         val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
           new Predict(0.0, 0.0), new Predict(0.0, 0.0))
       }
      +
      +/**
      + * :: DeveloperApi ::
      + * Impurity statistics for each split
      + * @param gain information gain value
      + * @param impurity current node impurity
      + * @param impurityCalculator impurity statistics for current node
      + * @param leftImpurityCalculator impurity statistics for left child node
      + * @param rightImpurityCalculator impurity statistics for right child node
      + * @param valid whether the current split satisfies minimum info gain or
      + *              minimum number of instances per node
      + */
      +@DeveloperApi
      +private[spark] class ImpurityStats(
      +    val gain: Double,
      +    val impurity: Double,
      +    val impurityCalculator: ImpurityCalculator,
      +    val leftImpurityCalculator: ImpurityCalculator,
      +    val rightImpurityCalculator: ImpurityCalculator,
      +    val valid: Boolean = true) extends Serializable {
      +
      +  override def toString: String = {
      +    s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " +
      +      s"right impurity = $rightImpurity"
      +  }
      +
      +  def leftImpurity: Double = if (leftImpurityCalculator != null) {
      +    leftImpurityCalculator.calculate()
      +  } else {
      +    -1.0
      +  }
      +
      +  def rightImpurity: Double = if (rightImpurityCalculator != null) {
      +    rightImpurityCalculator.calculate()
      +  } else {
      +    -1.0
      +  }
      +}
      +
      +private[spark] object ImpurityStats {
      +
      +  /**
      +   * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to
       +   * denote that the current split doesn't satisfy the minimum info gain or
      +   * minimum number of instances per node.
      +   */
      +  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
      +    new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
      +      impurityCalculator, null, null, false)
      +  }
      +
      +  /**
      +   * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object
       +   * in which only 'impurity' and 'impurityCalculator' are defined.
      +   */
      +  def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
      +    new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null)
      +  }
      +}
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
      index a6d1398fc267..ea6e5aa5d94e 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.model
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.Logging
       import org.apache.spark.mllib.tree.configuration.FeatureType._
       import org.apache.spark.mllib.linalg.Vector
      @@ -39,16 +39,17 @@ import org.apache.spark.mllib.linalg.Vector
        * @param rightNode right child
        * @param stats information gain stats
        */
      +@Since("1.0.0")
       @DeveloperApi
      -class Node (
      -    val id: Int,
      -    var predict: Predict,
      -    var impurity: Double,
      -    var isLeaf: Boolean,
      -    var split: Option[Split],
      -    var leftNode: Option[Node],
      -    var rightNode: Option[Node],
      -    var stats: Option[InformationGainStats]) extends Serializable with Logging {
      +class Node @Since("1.2.0") (
      +    @Since("1.0.0") val id: Int,
      +    @Since("1.0.0") var predict: Predict,
      +    @Since("1.2.0") var impurity: Double,
      +    @Since("1.0.0") var isLeaf: Boolean,
      +    @Since("1.0.0") var split: Option[Split],
      +    @Since("1.0.0") var leftNode: Option[Node],
      +    @Since("1.0.0") var rightNode: Option[Node],
      +    @Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging {
       
         override def toString: String = {
           s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " +
      @@ -59,6 +60,7 @@ class Node (
          * build the left node and right nodes if not leaf
          * @param nodes array of nodes
          */
      +  @Since("1.0.0")
         @deprecated("build should no longer be used since trees are constructed on-the-fly in training",
           "1.2.0")
         def build(nodes: Array[Node]): Unit = {
      @@ -80,6 +82,7 @@ class Node (
          * @param features feature value
          * @return predicted value
          */
      +  @Since("1.1.0")
         def predict(features: Vector) : Double = {
           if (isLeaf) {
             predict.predict
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
      index 5cbe7c280dbe..06ceff19d863 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
      @@ -17,17 +17,18 @@
       
       package org.apache.spark.mllib.tree.model
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       
       /**
        * Predicted value for a node
        * @param predict predicted value
        * @param prob probability of the label (classification only)
        */
      +@Since("1.2.0")
       @DeveloperApi
      -class Predict(
      -    val predict: Double,
      -    val prob: Double = 0.0) extends Serializable {
      +class Predict @Since("1.2.0") (
      +    @Since("1.2.0") val predict: Double,
      +    @Since("1.2.0") val prob: Double = 0.0) extends Serializable {
       
         override def toString: String = s"$predict (prob = $prob)"
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
      index be6c9b3de547..b85a66c05a81 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree.model
       
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
       import org.apache.spark.mllib.tree.configuration.FeatureType
       import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
      @@ -31,12 +31,13 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
        * @param featureType type of feature -- categorical or continuous
        * @param categories Split left if categorical feature value is in this set, else right.
        */
      +@Since("1.0.0")
       @DeveloperApi
       case class Split(
      -    feature: Int,
      -    threshold: Double,
      -    featureType: FeatureType,
      -    categories: List[Double]) {
      +    @Since("1.0.0") feature: Int,
      +    @Since("1.0.0") threshold: Double,
      +    @Since("1.0.0") featureType: FeatureType,
      +    @Since("1.0.0") categories: List[Double]) {
       
         override def toString: String = {
           s"Feature = $feature, threshold = $threshold, featureType = $featureType, " +
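
A continuous split can be constructed directly through the annotated case-class constructor (values are illustrative; categories is unused for continuous features):

    import org.apache.spark.mllib.tree.configuration.FeatureType
    import org.apache.spark.mllib.tree.model.Split

    // "go left" when feature 0 is <= 0.5
    val split = Split(0, 0.5, FeatureType.Continuous, List.empty[Double])
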
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
      index 1e3333d8d81d..df5b8feab5d5 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
      @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._
       import org.json4s.jackson.JsonMethods._
       
       import org.apache.spark.{Logging, SparkContext}
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.api.java.JavaRDD
       import org.apache.spark.mllib.linalg.Vector
       import org.apache.spark.mllib.regression.LabeledPoint
      @@ -46,14 +46,24 @@ import org.apache.spark.util.Utils
        * @param algo algorithm for the ensemble model, either Classification or Regression
        * @param trees tree ensembles
        */
      +@Since("1.2.0")
       @Experimental
      -class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
      +class RandomForestModel @Since("1.2.0") (
      +    @Since("1.2.0") override val algo: Algo,
      +    @Since("1.2.0") override val trees: Array[DecisionTreeModel])
         extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
           combiningStrategy = if (algo == Classification) Vote else Average)
         with Saveable {
       
         require(trees.forall(_.algo == algo))
       
      +  /**
      +   *
      +   * @param sc  Spark context used to save model data.
      +   * @param path  Path specifying the directory in which to save this model.
      +   *              If the directory already exists, this method throws an exception.
      +   */
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
             RandomForestModel.SaveLoadV1_0.thisClassName)
      @@ -62,10 +72,18 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis
         override protected def formatVersion: String = RandomForestModel.formatVersion
       }
       
      +@Since("1.3.0")
       object RandomForestModel extends Loader[RandomForestModel] {
       
         private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
       
      +  /**
      +   *
      +   * @param sc  Spark context used for loading model files.
      +   * @param path  Path specifying the directory to which the model was saved.
      +   * @return  Model instance
      +   */
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): RandomForestModel = {
           val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
           val classNameV1_0 = SaveLoadV1_0.thisClassName
      @@ -97,16 +115,23 @@ object RandomForestModel extends Loader[RandomForestModel] {
        * @param trees tree ensembles
        * @param treeWeights tree ensemble weights
        */
      +@Since("1.2.0")
       @Experimental
      -class GradientBoostedTreesModel(
      -    override val algo: Algo,
      -    override val trees: Array[DecisionTreeModel],
      -    override val treeWeights: Array[Double])
      +class GradientBoostedTreesModel @Since("1.2.0") (
      +    @Since("1.2.0") override val algo: Algo,
      +    @Since("1.2.0") override val trees: Array[DecisionTreeModel],
      +    @Since("1.2.0") override val treeWeights: Array[Double])
         extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
         with Saveable {
       
         require(trees.length == treeWeights.length)
       
      +  /**
      +   * @param sc  Spark context used to save model data.
      +   * @param path  Path specifying the directory in which to save this model.
      +   *              If the directory already exists, this method throws an exception.
      +   */
      +  @Since("1.3.0")
         override def save(sc: SparkContext, path: String): Unit = {
           TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
             GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
      @@ -119,6 +144,7 @@ class GradientBoostedTreesModel(
          * @return an array with index i having the losses or errors for the ensemble
          *         containing the first i+1 trees
          */
      +  @Since("1.4.0")
         def evaluateEachIteration(
             data: RDD[LabeledPoint],
             loss: Loss): Array[Double] = {
      @@ -159,6 +185,9 @@ class GradientBoostedTreesModel(
         override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion
       }
       
      +@Since("1.3.0")
       object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
       
         /**
      @@ -171,6 +200,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
          * @return a RDD with each element being a zip of the prediction and error
          *         corresponding to every sample.
          */
      +  @Since("1.4.0")
         def computeInitialPredictionAndError(
             data: RDD[LabeledPoint],
             initTreeWeight: Double,
      @@ -194,6 +224,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
          * @return a RDD with each element being a zip of the prediction and error
          *         corresponding to each sample.
          */
      +  @Since("1.4.0")
         def updatePredictionError(
           data: RDD[LabeledPoint],
           predictionAndError: RDD[(Double, Double)],
      @@ -213,6 +244,12 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
       
         private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
       
      +  /**
      +   * @param sc  Spark context used for loading model files.
      +   * @param path  Path specifying the directory to which the model was saved.
      +   * @return  Model instance
      +   */
      +  @Since("1.3.0")
         override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
           val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
           val classNameV1_0 = SaveLoadV1_0.thisClassName
      @@ -387,7 +424,7 @@ private[tree] object TreeEnsembleModel extends Logging {
               val driverMemory = sc.getConf.getOption("spark.driver.memory")
                 .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
                 .map(Utils.memoryStringToMb)
      -          .getOrElse(512)
      +          .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB)
               if (driverMemory <= memThreshold) {
                 logWarning(s"$className.save() was called, but it may fail because of too little" +
                   s" driver memory (${driverMemory}m)." +
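
A minimal sketch tying together training, the per-iteration evaluation annotated above, and the save/load round trip (assumes a live sc: SparkContext and a training: RDD[LabeledPoint]; the path is illustrative):

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 3
    val model = GradientBoostedTrees.train(training, boostingStrategy)

    // training error of the ensemble truncated to the first i+1 trees
    val errors: Array[Double] = model.evaluateEachIteration(training, boostingStrategy.loss)

    model.save(sc, "/tmp/myGbtModel")
    val sameModel = GradientBoostedTreesModel.load(sc, "/tmp/myGbtModel")
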
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
      index be335a1aca58..dffe6e78939e 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
      @@ -17,16 +17,17 @@
       
       package org.apache.spark.mllib.util
       
      -import org.apache.spark.annotation.DeveloperApi
       import org.apache.spark.Logging
      -import org.apache.spark.rdd.RDD
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.rdd.RDD
       
       /**
        * :: DeveloperApi ::
        * A collection of methods used to validate data before applying ML algorithms.
        */
       @DeveloperApi
      +@Since("0.8.0")
       object DataValidators extends Logging {
       
         /**
      @@ -34,6 +35,7 @@ object DataValidators extends Logging {
          *
          * @return True if labels are all zero or one, false otherwise.
          */
      +  @Since("1.0.0")
         val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
           val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
           if (numInvalid != 0) {
      @@ -48,6 +50,7 @@ object DataValidators extends Logging {
          *
          * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise.
          */
      +  @Since("1.3.0")
         def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data =>
           val numInvalid = data.filter(x =>
             x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count()
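
A short sketch of the validators in use (assumes a live sc: SparkContext; the data is illustrative):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.util.DataValidators

    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(1.0)),
      LabeledPoint(1.0, Vectors.dense(2.0))))

    DataValidators.binaryLabelValidator(data)    // true: labels are all 0 or 1
    DataValidators.multiLabelValidator(3)(data)  // true: labels are all in {0, 1, 2}
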
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
      index 6eaebaf7dba9..00fd1606a369 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
      @@ -19,8 +19,8 @@ package org.apache.spark.mllib.util
       
       import scala.util.Random
       
      -import org.apache.spark.annotation.DeveloperApi
       import org.apache.spark.SparkContext
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.rdd.RDD
       
       /**
      @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
        * cluster with scale 1 around each center.
        */
       @DeveloperApi
      +@Since("0.8.0")
       object KMeansDataGenerator {
       
         /**
      @@ -42,6 +43,7 @@ object KMeansDataGenerator {
          * @param r Scaling factor for the distribution of the initial centers
          * @param numPartitions Number of partitions of the generated RDD; default 2
          */
      +  @Since("0.8.0")
         def generateKMeansRDD(
             sc: SparkContext,
             numPoints: Int,
      @@ -62,10 +64,13 @@ object KMeansDataGenerator {
           }
         }
       
      +  @Since("0.8.0")
         def main(args: Array[String]) {
           if (args.length < 6) {
      +      // scalastyle:off println
             println("Usage: KMeansGenerator " +
               "<master> <output_dir> <num_points> <k> <d> <r> [<num_partitions>]")
      +      // scalastyle:on println
             System.exit(1)
           }
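
An illustrative call to the generator (assumes a live sc: SparkContext):

    import org.apache.spark.mllib.util.KMeansDataGenerator

    // 1000 points around 5 centers in 10 dimensions, centers scaled by r = 2.0, 4 partitions
    val points = KMeansDataGenerator.generateKMeansRDD(sc, 1000, 5, 10, 2.0, 4)
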
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
      index b4e33c98ba7e..d0ba454f379a 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
      @@ -17,16 +17,16 @@
       
       package org.apache.spark.mllib.util
       
      -import scala.collection.JavaConversions._
      +import scala.collection.JavaConverters._
       import scala.util.Random
       
       import com.github.fommil.netlib.BLAS.{getInstance => blas}
       
      -import org.apache.spark.annotation.DeveloperApi
       import org.apache.spark.SparkContext
      -import org.apache.spark.rdd.RDD
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.rdd.RDD
       
       /**
        * :: DeveloperApi ::
      @@ -35,6 +35,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
        * response variable `Y`.
        */
       @DeveloperApi
      +@Since("0.8.0")
       object LinearDataGenerator {
       
         /**
      @@ -46,13 +47,14 @@ object LinearDataGenerator {
          * @param seed Random seed
          * @return Java List of input.
          */
      +  @Since("0.8.0")
         def generateLinearInputAsList(
             intercept: Double,
             weights: Array[Double],
             nPoints: Int,
             seed: Int,
             eps: Double): java.util.List[LabeledPoint] = {
      -    seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps))
      +    generateLinearInput(intercept, weights, nPoints, seed, eps).asJava
         }
       
         /**
      @@ -68,6 +70,7 @@ object LinearDataGenerator {
          * @param eps Epsilon scaling factor.
          * @return Seq of input.
          */
      +  @Since("0.8.0")
         def generateLinearInput(
             intercept: Double,
             weights: Array[Double],
      @@ -92,6 +95,7 @@ object LinearDataGenerator {
          * @param eps Epsilon scaling factor.
          * @return Seq of input.
          */
      +  @Since("0.8.0")
         def generateLinearInput(
             intercept: Double,
             weights: Array[Double],
      @@ -132,6 +136,7 @@ object LinearDataGenerator {
          *
          * @return RDD of LabeledPoint containing sample data.
          */
      +  @Since("0.8.0")
         def generateLinearRDD(
             sc: SparkContext,
             nexamples: Int,
      @@ -151,10 +156,13 @@ object LinearDataGenerator {
           data
         }
       
      +  @Since("0.8.0")
         def main(args: Array[String]) {
           if (args.length < 2) {
      +      // scalastyle:off println
             println("Usage: LinearDataGenerator " +
               "<master> <output_dir> [num_examples] [num_features] [num_partitions]")
      +      // scalastyle:on println
             System.exit(1)
           }
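
An illustrative call to the RDD generator (assumes a live sc: SparkContext):

    import org.apache.spark.mllib.util.LinearDataGenerator

    // 1000 examples with 10 features and noise scale eps = 0.1
    val examples = LinearDataGenerator.generateLinearRDD(sc, 1000, 10, 0.1)
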
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
      index 9d802678c4a7..33477ee20ebb 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
      @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util
       
       import scala.util.Random
       
      -import org.apache.spark.annotation.DeveloperApi
       +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.SparkContext
       import org.apache.spark.rdd.RDD
       import org.apache.spark.mllib.regression.LabeledPoint
      @@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.Vectors
        * with probability `probOne` and scales features for positive examples by `eps`.
        */
       @DeveloperApi
      +@Since("0.8.0")
       object LogisticRegressionDataGenerator {
       
         /**
      @@ -43,6 +44,7 @@ object LogisticRegressionDataGenerator {
          * @param nparts Number of partitions of the generated RDD. Default value is 2.
          * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5.
          */
      +  @Since("0.8.0")
         def generateLogisticRDD(
           sc: SparkContext,
           nexamples: Int,
      @@ -62,10 +64,13 @@ object LogisticRegressionDataGenerator {
           data
         }
       
      +  @Since("0.8.0")
         def main(args: Array[String]) {
           if (args.length != 5) {
      +      // scalastyle:off println
             println("Usage: LogisticRegressionGenerator " +
               "<master> <output_dir> <num_examples> <num_features> <num_partitions>")
      +      // scalastyle:on println
             System.exit(1)
           }
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
      index bd73a866c8a8..906bd30563bd 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
      @@ -23,7 +23,7 @@ import scala.language.postfixOps
       import scala.util.Random
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.DeveloperApi
       +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix}
       import org.apache.spark.rdd.RDD
       
      @@ -52,11 +52,15 @@ import org.apache.spark.rdd.RDD
        *   testSampFact   (Double) Percentage of training data to use as test data.
        */
       @DeveloperApi
      +@Since("0.8.0")
       object MFDataGenerator {
      +  @Since("0.8.0")
         def main(args: Array[String]) {
           if (args.length < 2) {
      +      // scalastyle:off println
             println("Usage: MFDataGenerator " +
               "<master> <outputDir> [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]")
      +      // scalastyle:on println
             System.exit(1)
           }
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
      index 7c5cfa7bd84c..81c2f0ce6e12 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
      @@ -21,7 +21,7 @@ import scala.reflect.ClassTag
       
       import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
       
      -import org.apache.spark.annotation.Experimental
      +import org.apache.spark.annotation.{Experimental, Since}
       import org.apache.spark.SparkContext
       import org.apache.spark.rdd.RDD
       import org.apache.spark.rdd.PartitionwiseSampledRDD
      @@ -36,6 +36,7 @@ import org.apache.spark.streaming.dstream.DStream
       /**
        * Helper methods to load, save and pre-process data used in ML Lib.
        */
      +@Since("0.8.0")
       object MLUtils {
       
         private[mllib] lazy val EPSILON = {
      @@ -65,6 +66,7 @@ object MLUtils {
          * @param minPartitions min number of partitions
          * @return labeled data stored as an RDD[LabeledPoint]
          */
      +  @Since("1.0.0")
         def loadLibSVMFile(
             sc: SparkContext,
             path: String,
      @@ -114,6 +116,7 @@ object MLUtils {
       
         // Convenient methods for `loadLibSVMFile`.
       
      +  @Since("1.0.0")
         @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0")
         def loadLibSVMFile(
             sc: SparkContext,
      @@ -127,12 +130,14 @@ object MLUtils {
          * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
          * partitions.
          */
      +  @Since("1.0.0")
         def loadLibSVMFile(
             sc: SparkContext,
             path: String,
             numFeatures: Int): RDD[LabeledPoint] =
           loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions)
       
      +  @Since("1.0.0")
         @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0")
         def loadLibSVMFile(
             sc: SparkContext,
      @@ -141,6 +146,7 @@ object MLUtils {
             numFeatures: Int): RDD[LabeledPoint] =
           loadLibSVMFile(sc, path, numFeatures)
       
      +  @Since("1.0.0")
         @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0")
         def loadLibSVMFile(
             sc: SparkContext,
      @@ -152,6 +158,7 @@ object MLUtils {
          * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of
          * features determined automatically and the default number of partitions.
          */
      +  @Since("1.0.0")
         def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] =
           loadLibSVMFile(sc, path, -1)
       
      @@ -162,6 +169,7 @@ object MLUtils {
          *
          * @see [[org.apache.spark.mllib.util.MLUtils#loadLibSVMFile]]
          */
      +  @Since("1.0.0")
         def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) {
           // TODO: allow to specify label precision and feature precision.
           val dataStr = data.map { case LabeledPoint(label, features) =>
      @@ -182,12 +190,14 @@ object MLUtils {
          * @param minPartitions min number of partitions
          * @return vectors stored as an RDD[Vector]
          */
      +  @Since("1.1.0")
         def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] =
           sc.textFile(path, minPartitions).map(Vectors.parse)
       
         /**
          * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions.
          */
      +  @Since("1.1.0")
         def loadVectors(sc: SparkContext, path: String): RDD[Vector] =
           sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse)
       
      @@ -198,6 +208,7 @@ object MLUtils {
          * @param minPartitions min number of partitions
          * @return labeled points stored as an RDD[LabeledPoint]
          */
      +  @Since("1.1.0")
         def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] =
           sc.textFile(path, minPartitions).map(LabeledPoint.parse)
       
      @@ -205,6 +216,7 @@ object MLUtils {
          * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of
          * partitions.
          */
      +  @Since("1.1.0")
         def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] =
           loadLabeledPoints(sc, dir, sc.defaultMinPartitions)
       
      @@ -221,6 +233,7 @@ object MLUtils {
          * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and
          *            [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading.
          */
      +  @Since("1.0.0")
         @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1")
         def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
           sc.textFile(dir).map { line =>
      @@ -242,6 +255,7 @@ object MLUtils {
          * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and
          *            [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading.
          */
      +  @Since("1.0.0")
         @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1")
         def saveLabeledData(data: RDD[LabeledPoint], dir: String) {
           val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" "))
      @@ -254,6 +268,7 @@ object MLUtils {
          * containing the training data, a complement of the validation data and the second
          * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds.
          */
      +  @Since("1.0.0")
         @Experimental
         def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
           val numFoldsF = numFolds.toFloat
      @@ -269,6 +284,7 @@ object MLUtils {
         /**
          * Returns a new vector with `1.0` (bias) appended to the input vector.
          */
      +  @Since("1.0.0")
         def appendBias(vector: Vector): Vector = {
           vector match {
             case dv: DenseVector =>
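
A short sketch of the most common entry points annotated above (assumes a live sc: SparkContext; the input path is illustrative):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.util.MLUtils

    val examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

    // 3-fold (training, validation) splits for cross-validation
    val folds = MLUtils.kFold(examples, 3, 11)

    // append a 1.0 bias term to a feature vector
    val withBias = MLUtils.appendBias(Vectors.dense(1.0, 2.0))
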
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
      index 308f7f3578e2..a841c5caf014 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
      @@ -98,6 +98,8 @@ private[mllib] object NumericParser {
               }
             } else if (token == ")") {
               parsing = false
      +      } else if (token.trim.isEmpty){
      +          // ignore whitespaces between delim chars, e.g. ", ["
             } else {
               // expecting a number
               items.append(parseDouble(token))
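
The new branch makes the parser tolerate whitespace between delimiters, which shows up, for example, when a labeled point is written with a space after the comma; an illustrative call (assumes this patched parser):

    import org.apache.spark.mllib.regression.LabeledPoint

    // previously failed on the space before '['; now parses to LabeledPoint(1.0, [1.0,0.0,-2.0])
    val p = LabeledPoint.parse("(1.0, [1.0,0.0,-2.0])")
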
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
      index a8e30cc9d730..cde597939617 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
      @@ -21,11 +21,11 @@ import scala.util.Random
       
       import com.github.fommil.netlib.BLAS.{getInstance => blas}
       
      -import org.apache.spark.annotation.DeveloperApi
       import org.apache.spark.SparkContext
      -import org.apache.spark.rdd.RDD
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.regression.LabeledPoint
      +import org.apache.spark.rdd.RDD
       
       /**
        * :: DeveloperApi ::
      @@ -33,12 +33,16 @@ import org.apache.spark.mllib.regression.LabeledPoint
        * for the features and adds Gaussian noise with weight 0.1 to generate labels.
        */
       @DeveloperApi
      +@Since("0.8.0")
       object SVMDataGenerator {
       
      +  @Since("0.8.0")
         def main(args: Array[String]) {
           if (args.length < 2) {
      +      // scalastyle:off println
             println("Usage: SVMGenerator " +
               "<master> <output_dir> [num_examples] [num_features] [num_partitions]")
      +      // scalastyle:on println
             System.exit(1)
           }
       
      diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
      index 30d642c754b7..4d71d534a077 100644
      --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
      +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
      @@ -24,7 +24,7 @@ import org.json4s._
       import org.json4s.jackson.JsonMethods._
       
       import org.apache.spark.SparkContext
      -import org.apache.spark.annotation.DeveloperApi
      +import org.apache.spark.annotation.{DeveloperApi, Since}
       import org.apache.spark.sql.catalyst.ScalaReflection
       import org.apache.spark.sql.types.{DataType, StructField, StructType}
       
      @@ -35,6 +35,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
        * This should be inherited by the class which implements model instances.
        */
       @DeveloperApi
      +@Since("1.3.0")
       trait Saveable {
       
         /**
      @@ -50,6 +51,7 @@ trait Saveable {
          * @param path  Path specifying the directory in which to save this model.
          *              If the directory already exists, this method throws an exception.
          */
      +  @Since("1.3.0")
         def save(sc: SparkContext, path: String): Unit
       
         /** Current version of model save/load format. */
      @@ -64,6 +66,7 @@ trait Saveable {
        * This should be inherited by an object paired with the model class.
        */
       @DeveloperApi
      +@Since("1.3.0")
       trait Loader[M <: Saveable] {
       
         /**
      @@ -75,6 +78,7 @@ trait Loader[M <: Saveable] {
          * @param path  Path specifying the directory to which the model was saved.
          * @return  Model instance
          */
      +  @Since("1.3.0")
         def load(sc: SparkContext, path: String): M
       
       }
      diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
      index f75e024a713e..fd22eb6dca01 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
      @@ -22,6 +22,7 @@
       import java.util.List;
       
       import org.junit.After;
      +import org.junit.Assert;
       import org.junit.Before;
       import org.junit.Test;
       
      @@ -63,16 +64,16 @@ public void tearDown() {
         @Test
         public void logisticRegressionDefaultParams() {
           LogisticRegression lr = new LogisticRegression();
      -    assert(lr.getLabelCol().equals("label"));
      +    Assert.assertEquals(lr.getLabelCol(), "label");
           LogisticRegressionModel model = lr.fit(dataset);
           model.transform(dataset).registerTempTable("prediction");
           DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
           predictions.collectAsList();
           // Check defaults
      -    assert(model.getThreshold() == 0.5);
      -    assert(model.getFeaturesCol().equals("features"));
      -    assert(model.getPredictionCol().equals("prediction"));
      -    assert(model.getProbabilityCol().equals("probability"));
      +    Assert.assertEquals(0.5, model.getThreshold(), eps);
      +    Assert.assertEquals("features", model.getFeaturesCol());
      +    Assert.assertEquals("prediction", model.getPredictionCol());
      +    Assert.assertEquals("probability", model.getProbabilityCol());
         }
       
         @Test
      @@ -85,17 +86,19 @@ public void logisticRegressionWithSetters() {
             .setProbabilityCol("myProbability");
           LogisticRegressionModel model = lr.fit(dataset);
           LogisticRegression parent = (LogisticRegression) model.parent();
      -    assert(parent.getMaxIter() == 10);
      -    assert(parent.getRegParam() == 1.0);
      -    assert(parent.getThreshold() == 0.6);
      -    assert(model.getThreshold() == 0.6);
      +    Assert.assertEquals(10, parent.getMaxIter());
      +    Assert.assertEquals(1.0, parent.getRegParam(), eps);
      +    Assert.assertEquals(0.4, parent.getThresholds()[0], eps);
      +    Assert.assertEquals(0.6, parent.getThresholds()[1], eps);
      +    Assert.assertEquals(0.6, parent.getThreshold(), eps);
      +    Assert.assertEquals(0.6, model.getThreshold(), eps);
       
           // Modify model params, and check that the params worked.
           model.setThreshold(1.0);
           model.transform(dataset).registerTempTable("predAllZero");
           DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
           for (Row r: predAllZero.collectAsList()) {
      -      assert(r.getDouble(0) == 0.0);
      +      Assert.assertEquals(0.0, r.getDouble(0), eps);
           }
           // Call transform with params, and check that the params worked.
           model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
      @@ -105,17 +108,17 @@ public void logisticRegressionWithSetters() {
           for (Row r: predNotAllZero.collectAsList()) {
             if (r.getDouble(0) != 0.0) foundNonZero = true;
           }
      -    assert(foundNonZero);
      +    Assert.assertTrue(foundNonZero);
       
           // Call fit() with new params, and check as many params as we can.
           LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
               lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
           LogisticRegression parent2 = (LogisticRegression) model2.parent();
      -    assert(parent2.getMaxIter() == 5);
      -    assert(parent2.getRegParam() == 0.1);
      -    assert(parent2.getThreshold() == 0.4);
      -    assert(model2.getThreshold() == 0.4);
      -    assert(model2.getProbabilityCol().equals("theProb"));
      +    Assert.assertEquals(5, parent2.getMaxIter());
      +    Assert.assertEquals(0.1, parent2.getRegParam(), eps);
      +    Assert.assertEquals(0.4, parent2.getThreshold(), eps);
      +    Assert.assertEquals(0.4, model2.getThreshold(), eps);
      +    Assert.assertEquals("theProb", model2.getProbabilityCol());
         }
       
         @SuppressWarnings("unchecked")
      @@ -123,18 +126,18 @@ public void logisticRegressionWithSetters() {
         public void logisticRegressionPredictorClassifierMethods() {
           LogisticRegression lr = new LogisticRegression();
           LogisticRegressionModel model = lr.fit(dataset);
      -    assert(model.numClasses() == 2);
      +    Assert.assertEquals(2, model.numClasses());
       
           model.transform(dataset).registerTempTable("transformed");
           DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
           for (Row row: trans1.collect()) {
             Vector raw = (Vector)row.get(0);
             Vector prob = (Vector)row.get(1);
      -      assert(raw.size() == 2);
      -      assert(prob.size() == 2);
+      Assert.assertEquals(2, raw.size());
+      Assert.assertEquals(2, prob.size());
             double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
      -      assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
      -      assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
      +      Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps);
      +      Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
           }
       
           DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
      @@ -143,8 +146,17 @@ public void logisticRegressionPredictorClassifierMethods() {
             Vector prob = (Vector)row.get(1);
             double probOfPred = prob.apply((int)pred);
             for (int i = 0; i < prob.size(); ++i) {
      -        assert(probOfPred >= prob.apply(i));
      +        Assert.assertTrue(probOfPred >= prob.apply(i));
             }
           }
         }
      +
      +  @Test
      +  public void logisticRegressionTrainingSummary() {
      +    LogisticRegression lr = new LogisticRegression();
      +    LogisticRegressionModel model = lr.fit(dataset);
      +
      +    LogisticRegressionTrainingSummary summary = model.summary();
      +    Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
      +  }
       }
      diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
      new file mode 100644
      index 000000000000..ec6b4bf3c0f8
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
      @@ -0,0 +1,74 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.classification;
      +
      +import java.io.Serializable;
      +import java.util.Arrays;
      +
      +import org.junit.After;
      +import org.junit.Assert;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.mllib.regression.LabeledPoint;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.Row;
      +import org.apache.spark.sql.SQLContext;
      +
      +public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
      +
      +  private transient JavaSparkContext jsc;
      +  private transient SQLContext sqlContext;
      +
      +  @Before
      +  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaMultilayerPerceptronClassifierSuite");
      +    sqlContext = new SQLContext(jsc);
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    jsc.stop();
      +    jsc = null;
      +    sqlContext = null;
      +  }
      +
      +  @Test
      +  public void testMLPC() {
      +    DataFrame dataFrame = sqlContext.createDataFrame(
      +      jsc.parallelize(Arrays.asList(
      +        new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
      +        new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
      +        new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      +        new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))),
      +      LabeledPoint.class);
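+    // The data encodes XOR; layers {2, 5, 2} mean 2 input features, 5 hidden units, 2 classes.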
      +    MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
      +      .setLayers(new int[] {2, 5, 2})
      +      .setBlockSize(1)
      +      .setSeed(11L)
      +      .setMaxIter(100);
      +    MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
      +    DataFrame result = model.transform(dataFrame);
      +    Row[] predictionAndLabels = result.select("prediction", "label").collect();
      +    for (Row r: predictionAndLabels) {
      +      Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
      new file mode 100644
      index 000000000000..075a62c493f1
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
      @@ -0,0 +1,99 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.classification;
      +
      +import java.io.Serializable;
      +import java.util.Arrays;
      +
      +import org.junit.After;
      +import org.junit.Before;
      +import org.junit.Test;
      +import static org.junit.Assert.assertEquals;
      +
      +import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.linalg.VectorUDT;
      +import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.Row;
      +import org.apache.spark.sql.RowFactory;
      +import org.apache.spark.sql.SQLContext;
      +import org.apache.spark.sql.types.DataTypes;
      +import org.apache.spark.sql.types.Metadata;
      +import org.apache.spark.sql.types.StructField;
      +import org.apache.spark.sql.types.StructType;
      +
      +public class JavaNaiveBayesSuite implements Serializable {
      +
      +  private transient JavaSparkContext jsc;
      +  private transient SQLContext jsql;
      +
      +  @Before
      +  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
      +    jsql = new SQLContext(jsc);
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    jsc.stop();
      +    jsc = null;
      +  }
      +
      +  public void validatePrediction(DataFrame predictionAndLabels) {
      +    for (Row r : predictionAndLabels.collect()) {
      +      double prediction = r.getAs(0);
      +      double label = r.getAs(1);
      +      assertEquals(label, prediction, 1E-5);
      +    }
      +  }
      +
      +  @Test
      +  public void naiveBayesDefaultParams() {
      +    NaiveBayes nb = new NaiveBayes();
      +    assertEquals("label", nb.getLabelCol());
      +    assertEquals("features", nb.getFeaturesCol());
      +    assertEquals("prediction", nb.getPredictionCol());
      +    assertEquals(1.0, nb.getSmoothing(), 1E-5);
      +    assertEquals("multinomial", nb.getModelType());
      +  }
      +
      +  @Test
      +  public void testNaiveBayes() {
+    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
      +      RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
      +      RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
      +      RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
      +      RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
      +      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
      +      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
      +    ));
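+    // Each label puts its weight on a distinct feature, so predictions should match the labels.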
      +
      +    StructType schema = new StructType(new StructField[]{
      +      new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
      +      new StructField("features", new VectorUDT(), false, Metadata.empty())
      +    });
      +
      +    DataFrame dataset = jsql.createDataFrame(jrdd, schema);
      +    NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
      +    NaiveBayesModel model = nb.fit(dataset);
      +
      +    DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
      +    validatePrediction(predictionAndLabels);
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
      index a1ee55415237..253cabf0133d 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
      @@ -20,7 +20,7 @@
       import java.io.Serializable;
       import java.util.List;
       
      -import static scala.collection.JavaConversions.seqAsJavaList;
      +import scala.collection.JavaConverters;
       
       import org.junit.After;
       import org.junit.Assert;
      @@ -55,8 +55,9 @@ public void setUp() {
       
               double[] xMean = {5.843, 3.057, 3.758, 1.199};
               double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
-        List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
      -                weights, xMean, xVariance, true, nPoints, 42));
+        List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
      +            generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
      +        ).asJava();
               datasetRDD = jsc.parallelize(points, 2);
               dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
           }
      diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
      index 32d0b3856b7e..a66a1e12927b 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
      @@ -29,6 +29,7 @@
       import org.apache.spark.api.java.JavaSparkContext;
       import org.apache.spark.ml.impl.TreeTests;
       import org.apache.spark.mllib.classification.LogisticRegressionSuite;
      +import org.apache.spark.mllib.linalg.Vector;
       import org.apache.spark.mllib.regression.LabeledPoint;
       import org.apache.spark.sql.DataFrame;
       
      @@ -85,6 +86,7 @@ public void runDT() {
           model.toDebugString();
           model.trees();
           model.treeWeights();
      +    Vector importances = model.featureImportances();
       
           /*
           // TODO: Add test once save/load are implemented.  SPARK-6725
      diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
      new file mode 100644
      index 000000000000..d09fa7fd5637
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
      @@ -0,0 +1,72 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.clustering;
      +
      +import java.io.Serializable;
      +import java.util.Arrays;
      +import java.util.List;
      +
      +import org.junit.After;
      +import org.junit.Before;
      +import org.junit.Test;
      +import static org.junit.Assert.assertArrayEquals;
      +import static org.junit.Assert.assertEquals;
      +import static org.junit.Assert.assertTrue;
      +
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.linalg.Vector;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.SQLContext;
      +
      +public class JavaKMeansSuite implements Serializable {
      +
      +  private transient int k = 5;
      +  private transient JavaSparkContext sc;
      +  private transient DataFrame dataset;
      +  private transient SQLContext sql;
      +
      +  @Before
      +  public void setUp() {
      +    sc = new JavaSparkContext("local", "JavaKMeansSuite");
      +    sql = new SQLContext(sc);
      +
      +    dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    sc.stop();
      +    sc = null;
      +  }
      +
      +  @Test
      +  public void fitAndTransform() {
      +    KMeans kmeans = new KMeans().setK(k).setSeed(1);
      +    KMeansModel model = kmeans.fit(dataset);
      +
      +    Vector[] centers = model.clusterCenters();
      +    assertEquals(k, centers.length);
      +
      +    DataFrame transformed = model.transform(dataset);
+    List<String> columns = Arrays.asList(transformed.columns());
+    List<String> expectedColumns = Arrays.asList("features", "prediction");
      +    for (String column: expectedColumns) {
      +      assertTrue(columns.contains(column));
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
      index d5bd230a957a..47d68de599da 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.ml.feature;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -54,7 +55,7 @@ public void tearDown() {
         public void bucketizerTest() {
           double[] splits = {-0.5, 0.0, 0.5};
       
-    JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
+    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
             RowFactory.create(-0.5),
             RowFactory.create(-0.3),
             RowFactory.create(0.0),
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
      new file mode 100644
      index 000000000000..0f6ec64d97d3
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
      @@ -0,0 +1,79 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature;
      +
      +import java.util.Arrays;
      +
      +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
      +import org.junit.After;
      +import org.junit.Assert;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.linalg.Vector;
      +import org.apache.spark.mllib.linalg.VectorUDT;
      +import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.Row;
      +import org.apache.spark.sql.RowFactory;
      +import org.apache.spark.sql.SQLContext;
      +import org.apache.spark.sql.types.Metadata;
      +import org.apache.spark.sql.types.StructField;
      +import org.apache.spark.sql.types.StructType;
      +
      +public class JavaDCTSuite {
      +  private transient JavaSparkContext jsc;
      +  private transient SQLContext jsql;
      +
      +  @Before
      +  public void setUp() {
      +    jsc = new JavaSparkContext("local", "JavaDCTSuite");
      +    jsql = new SQLContext(jsc);
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    jsc.stop();
      +    jsc = null;
      +  }
      +
      +  @Test
      +  public void javaCompatibilityTest() {
      +    double[] input = new double[] {1D, 2D, 3D, 4D};
+    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
      +      RowFactory.create(Vectors.dense(input))
      +    ));
      +    DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{
      +      new StructField("vec", (new VectorUDT()), false, Metadata.empty())
      +    }));
      +
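+    // The reference result is jtransforms' scaled DCT-II applied to the same input.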
      +    double[] expectedResult = input.clone();
      +    (new DoubleDCT_1D(input.length)).forward(expectedResult, true);
      +
      +    DCT dct = new DCT()
      +      .setInputCol("vec")
      +      .setOutputCol("resultVec");
      +
      +    Row[] result = dct.transform(dataset).select("resultVec").collect();
      +    Vector resultVec = result[0].getAs("resultVec");
      +
      +    Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
      index 599e9cfd23ad..03dd5369bddf 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.ml.feature;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -54,7 +55,7 @@ public void tearDown() {
       
         @Test
         public void hashingTF() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
+    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
             RowFactory.create(0.0, "Hi I heard about Spark"),
             RowFactory.create(0.0, "I wish Java could use case classes"),
             RowFactory.create(1.0, "Logistic regression models are neat")
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
      index d82f3b7e8c07..e17d549c5059 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
      @@ -17,15 +17,15 @@
       
       package org.apache.spark.ml.feature;
       
      -import java.util.List;
      +import java.util.Arrays;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
       
       import org.apache.spark.api.java.JavaSparkContext;
       import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.api.java.JavaRDD;
       import org.apache.spark.sql.DataFrame;
       import org.apache.spark.sql.SQLContext;
       
      @@ -48,13 +48,12 @@ public void tearDown() {
         @Test
         public void normalizer() {
           // The tests are to check Java compatibility.
-    List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
+    JavaRDD<VectorIndexerSuite.FeatureData> points = jsc.parallelize(Arrays.asList(
             new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
             new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
             new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
      -    );
      -    DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
      -      VectorIndexerSuite.FeatureData.class);
      +    ));
      +    DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
           Normalizer normalizer = new Normalizer()
             .setInputCol("features")
             .setOutputCol("normFeatures");
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
      new file mode 100644
      index 000000000000..e8f329f9cf29
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
      @@ -0,0 +1,114 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature;
      +
      +import java.io.Serializable;
      +import java.util.Arrays;
      +import java.util.List;
      +
      +import scala.Tuple2;
      +
      +import org.junit.After;
      +import org.junit.Assert;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.function.Function;
      +import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.linalg.distributed.RowMatrix;
      +import org.apache.spark.mllib.linalg.Matrix;
      +import org.apache.spark.mllib.linalg.Vector;
      +import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.Row;
      +import org.apache.spark.sql.SQLContext;
      +
      +public class JavaPCASuite implements Serializable {
      +  private transient JavaSparkContext jsc;
      +  private transient SQLContext sqlContext;
      +
      +  @Before
      +  public void setUp() {
      +    jsc = new JavaSparkContext("local", "JavaPCASuite");
      +    sqlContext = new SQLContext(jsc);
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    jsc.stop();
      +    jsc = null;
      +  }
      +
      +  public static class VectorPair implements Serializable {
      +    private Vector features = Vectors.dense(0.0);
      +    private Vector expected = Vectors.dense(0.0);
      +
      +    public void setFeatures(Vector features) {
      +      this.features = features;
      +    }
      +
      +    public Vector getFeatures() {
      +      return this.features;
      +    }
      +
      +    public void setExpected(Vector expected) {
      +      this.expected = expected;
      +    }
      +
      +    public Vector getExpected() {
      +      return this.expected;
      +    }
      +  }
      +
      +  @Test
      +  public void testPCA() {
+    List<Vector> points = Arrays.asList(
      +      Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}),
      +      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
      +      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
      +    );
+    JavaRDD<Vector> dataRDD = jsc.parallelize(points, 2);
      +
      +    RowMatrix mat = new RowMatrix(dataRDD.rdd());
      +    Matrix pc = mat.computePrincipalComponents(3);
+    JavaRDD<Vector> expected = mat.multiply(pc).rows().toJavaRDD();
      +
+    JavaRDD<VectorPair> featuresExpected = dataRDD.zip(expected).map(
+      new Function<Tuple2<Vector, Vector>, VectorPair>() {
+        public VectorPair call(Tuple2<Vector, Vector> pair) {
      +          VectorPair featuresExpected = new VectorPair();
      +          featuresExpected.setFeatures(pair._1());
      +          featuresExpected.setExpected(pair._2());
      +          return featuresExpected;
      +        }
      +      }
      +    );
      +
      +    DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
      +    PCAModel pca = new PCA()
      +      .setInputCol("features")
      +      .setOutputCol("pca_features")
      +      .setK(3)
      +      .fit(df);
+    List<Row> result = pca.transform(df).select("pca_features", "expected").toJavaRDD().collect();
      +    for (Row r : result) {
      +      Assert.assertEquals(r.get(1), r.get(0));
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
      index 5e8211c2c511..834fedbb59e1 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.ml.feature;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -59,7 +60,7 @@ public void polynomialExpansionTest() {
             .setOutputCol("polyFeatures")
             .setDegree(3);
       
-    JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
+    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
             RowFactory.create(
               Vectors.dense(-2.0, 2.3),
               Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
      index 74eb2733f06e..ed74363f59e3 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
      @@ -17,9 +17,9 @@
       
       package org.apache.spark.ml.feature;
       
      +import java.util.Arrays;
       import java.util.List;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
      @@ -48,7 +48,7 @@ public void tearDown() {
         @Test
         public void standardScaler() {
           // The tests are to check Java compatibility.
-    List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
+    List<VectorIndexerSuite.FeatureData> points = Arrays.asList(
             new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
             new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
             new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
      new file mode 100644
      index 000000000000..76cdd0fae84a
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
      @@ -0,0 +1,72 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature;
      +
      +import java.util.Arrays;
      +
      +import org.junit.After;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.Row;
      +import org.apache.spark.sql.RowFactory;
      +import org.apache.spark.sql.SQLContext;
      +import org.apache.spark.sql.types.DataTypes;
      +import org.apache.spark.sql.types.Metadata;
      +import org.apache.spark.sql.types.StructField;
      +import org.apache.spark.sql.types.StructType;
      +
      +
      +public class JavaStopWordsRemoverSuite {
      +
      +  private transient JavaSparkContext jsc;
      +  private transient SQLContext jsql;
      +
      +  @Before
      +  public void setUp() {
      +    jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite");
      +    jsql = new SQLContext(jsc);
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    jsc.stop();
      +    jsc = null;
      +  }
      +
      +  @Test
      +  public void javaCompatibilityTest() {
      +    StopWordsRemover remover = new StopWordsRemover()
      +      .setInputCol("raw")
      +      .setOutputCol("filtered");
      +
+    JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
      +      RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
      +      RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
      +    ));
      +    StructType schema = new StructType(new StructField[] {
      +      new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
      +    });
      +    DataFrame dataset = jsql.createDataFrame(rdd, schema);
      +
      +    remover.transform(dataset).collect();
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
      index 3806f650025b..02309ce63219 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.ml.feature;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -54,7 +55,8 @@ public void regexTokenizer() {
             .setGaps(true)
             .setMinTokenLength(3);
       
-    JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Lists.newArrayList(
      +
+    JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
             new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
             new TokenizerTestData("Te,st.  punct", new String[] {"Te,st.", "punct"})
           ));
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
      index c7ae5468b942..bfcca62fa1c9 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
      @@ -18,6 +18,7 @@
       package org.apache.spark.ml.feature;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       import java.util.List;
       import java.util.Map;
       
      @@ -26,8 +27,6 @@
       import org.junit.Before;
       import org.junit.Test;
       
      -import com.google.common.collect.Lists;
      -
       import org.apache.spark.api.java.JavaSparkContext;
       import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
       import org.apache.spark.mllib.linalg.Vectors;
      @@ -52,7 +51,7 @@ public void tearDown() {
         @Test
         public void vectorIndexerAPI() {
           // The tests are to check Java compatibility.
-    List<FeatureData> points = Lists.newArrayList(
+    List<FeatureData> points = Arrays.asList(
             new FeatureData(Vectors.dense(0.0, -2.0)),
             new FeatureData(Vectors.dense(1.0, 3.0)),
             new FeatureData(Vectors.dense(1.0, 4.0))
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
      new file mode 100644
      index 000000000000..f95336142758
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
      @@ -0,0 +1,85 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature;
      +
      +import java.util.Arrays;
      +
      +import org.junit.After;
      +import org.junit.Assert;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.ml.attribute.Attribute;
      +import org.apache.spark.ml.attribute.AttributeGroup;
      +import org.apache.spark.ml.attribute.NumericAttribute;
      +import org.apache.spark.mllib.linalg.Vector;
      +import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.Row;
      +import org.apache.spark.sql.RowFactory;
      +import org.apache.spark.sql.SQLContext;
      +import org.apache.spark.sql.types.StructType;
      +
      +
      +public class JavaVectorSlicerSuite {
      +  private transient JavaSparkContext jsc;
      +  private transient SQLContext jsql;
      +
      +  @Before
      +  public void setUp() {
      +    jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
      +    jsql = new SQLContext(jsc);
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    jsc.stop();
      +    jsc = null;
      +  }
      +
      +  @Test
      +  public void vectorSlice() {
      +    Attribute[] attrs = new Attribute[]{
      +      NumericAttribute.defaultAttr().withName("f1"),
      +      NumericAttribute.defaultAttr().withName("f2"),
      +      NumericAttribute.defaultAttr().withName("f3")
      +    };
      +    AttributeGroup group = new AttributeGroup("userFeatures", attrs);
      +
+    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
      +      RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
      +      RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
      +    ));
      +
      +    DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
      +
      +    VectorSlicer vectorSlicer = new VectorSlicer()
      +      .setInputCol("userFeatures").setOutputCol("features");
      +
      +    vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
      +
      +    DataFrame output = vectorSlicer.transform(dataset);
      +
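+    // One index (1) and one name ("f3") are selected, so each sliced vector has size 2.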
      +    for (Row r : output.select("userFeatures", "features").take(2)) {
      +      Vector features = r.getAs(1);
+      Assert.assertEquals(2, features.size());
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
      index 39c70157f83c..70f5ad943221 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.ml.feature;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -50,10 +51,10 @@ public void tearDown() {
       
         @Test
         public void testJavaWord2Vec() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
      -      RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))),
      -      RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))),
      -      RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" ")))
+    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
      +      RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
      +      RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
      +      RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))
           ));
           StructType schema = new StructType(new StructField[]{
             new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
      diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
      index 9890155e9f86..fa777f3d42a9 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.ml.param;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -61,7 +62,7 @@ public void testParamValidate() {
           ParamValidators.ltEq(1.0);
           ParamValidators.inRange(0, 1, true, false);
           ParamValidators.inRange(0, 1);
      -    ParamValidators.inArray(Lists.newArrayList(0, 1, 3));
      -    ParamValidators.inArray(Lists.newArrayList("a", "b"));
      +    ParamValidators.inArray(Arrays.asList(0, 1, 3));
      +    ParamValidators.inArray(Arrays.asList("a", "b"));
         }
       }
      diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
      index 3ae09d39ef50..65841182df9b 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
      @@ -17,10 +17,9 @@
       
       package org.apache.spark.ml.param;
       
      +import java.util.Arrays;
       import java.util.List;
       
      -import com.google.common.collect.Lists;
      -
       import org.apache.spark.ml.util.Identifiable$;
       
       /**
      @@ -89,18 +88,15 @@ private void init() {
           myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
           myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
             ParamValidators.inRange(0.0, 1.0));
-    List<String> validStrings = Lists.newArrayList("a", "b");
+    List<String> validStrings = Arrays.asList("a", "b");
     myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
             ParamValidators.inArray(validStrings));
           myDoubleArrayParam_ =
             new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
       
           setDefault(myIntParam(), 1);
      -    setDefault(myIntParam().w(1));
           setDefault(myDoubleParam(), 0.5);
      -    setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
           setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
      -    setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
         }
       
         @Override
      diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
      index 71b041818d7e..ebe800e749e0 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
      @@ -57,7 +57,7 @@ public void runDT() {
     JavaRDD<LabeledPoint> data = sc.parallelize(
             LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
      -    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
      +    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
       
           // This tests setters. Training with various options is tested in Scala.
           DecisionTreeRegressor dt = new DecisionTreeRegressor()
      diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
      index d591a456864e..91c589d00abd 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
      @@ -60,7 +60,7 @@ public void tearDown() {
         @Test
         public void linearRegressionDefaultParams() {
           LinearRegression lr = new LinearRegression();
      -    assert(lr.getLabelCol().equals("label"));
      +    assertEquals("label", lr.getLabelCol());
           LinearRegressionModel model = lr.fit(dataset);
           model.transform(dataset).registerTempTable("prediction");
           DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
      diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
      index e306ebadfe7c..a00ce5e249c3 100644
      --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
      @@ -29,6 +29,7 @@
       import org.apache.spark.api.java.JavaSparkContext;
       import org.apache.spark.mllib.classification.LogisticRegressionSuite;
       import org.apache.spark.ml.impl.TreeTests;
      +import org.apache.spark.mllib.linalg.Vector;
       import org.apache.spark.mllib.regression.LabeledPoint;
       import org.apache.spark.sql.DataFrame;
       
      @@ -85,6 +86,7 @@ public void runDT() {
           model.toDebugString();
           model.trees();
           model.treeWeights();
      +    Vector importances = model.featureImportances();
       
           /*
           // TODO: Add test once save/load are implemented.   SPARK-6725
      diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
      new file mode 100644
      index 000000000000..2976b38e4503
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
      @@ -0,0 +1,80 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.source.libsvm;
      +
      +import java.io.File;
      +import java.io.IOException;
      +
      +import com.google.common.base.Charsets;
      +import com.google.common.io.Files;
      +
      +import org.junit.After;
      +import org.junit.Assert;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.linalg.DenseVector;
      +import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.sql.DataFrame;
      +import org.apache.spark.sql.Row;
      +import org.apache.spark.sql.SQLContext;
      +import org.apache.spark.util.Utils;
      +
      +
      +/**
      + * Test LibSVMRelation in Java.
      + */
      +public class JavaLibSVMRelationSuite {
      +  private transient JavaSparkContext jsc;
      +  private transient SQLContext sqlContext;
      +
      +  private File tempDir;
      +  private String path;
      +
      +  @Before
      +  public void setUp() throws IOException {
      +    jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
      +    sqlContext = new SQLContext(jsc);
      +
      +    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
      +    File file = new File(tempDir, "part-00000");
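+    // LIBSVM format: "label index:value ...", 1-based indices; the middle line has no features.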
      +    String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
      +    Files.write(s, file, Charsets.US_ASCII);
      +    path = tempDir.toURI().toString();
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    jsc.stop();
      +    jsc = null;
      +    Utils.deleteRecursively(tempDir);
      +  }
      +
      +  @Test
      +  public void verifyLibSVMDF() {
      +    DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
      +      .load(path);
      +    Assert.assertEquals("label", dataset.columns()[0]);
      +    Assert.assertEquals("features", dataset.columns()[1]);
      +    Row r = dataset.first();
      +    Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
      +    DenseVector v = r.getAs(1);
      +    Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
      index 55787f8606d4..c9e5ee22f327 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
      @@ -18,11 +18,11 @@
       package org.apache.spark.mllib.classification;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       import java.util.List;
       
       import scala.Tuple2;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
      @@ -60,16 +60,16 @@ public void tearDown() {
         @Test
         @SuppressWarnings("unchecked")
         public void javaAPI() {
-    List<LabeledPoint> trainingBatch = Lists.newArrayList(
+    List<LabeledPoint> trainingBatch = Arrays.asList(
             new LabeledPoint(1.0, Vectors.dense(1.0)),
             new LabeledPoint(0.0, Vectors.dense(0.0)));
     JavaDStream<LabeledPoint> training =
      -      attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
-    List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
+      attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
+    List<Tuple2<Integer, Vector>> testBatch = Arrays.asList(
       new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
       new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
     JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
      -      attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
      +      attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
           StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD()
             .setNumIterations(2)
             .setInitialWeights(Vectors.dense(0.0));
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
      index 467a7a69e8f3..123f78da54e3 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
      @@ -18,9 +18,9 @@
       package org.apache.spark.mllib.clustering;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       import java.util.List;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
      @@ -48,7 +48,7 @@ public void tearDown() {
       
         @Test
         public void runGaussianMixture() {
-    List<Vector> points = Lists.newArrayList(
+    List<Vector> points = Arrays.asList(
             Vectors.dense(1.0, 2.0, 6.0),
             Vectors.dense(1.0, 3.0, 0.0),
             Vectors.dense(1.0, 4.0, 6.0)
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
      index 31676e64025d..ad06676c72ac 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
      @@ -18,6 +18,7 @@
       package org.apache.spark.mllib.clustering;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       import java.util.List;
       
       import org.junit.After;
      @@ -25,8 +26,6 @@
       import org.junit.Test;
       import static org.junit.Assert.*;
       
      -import com.google.common.collect.Lists;
      -
       import org.apache.spark.api.java.JavaRDD;
       import org.apache.spark.api.java.JavaSparkContext;
       import org.apache.spark.mllib.linalg.Vector;
      @@ -48,7 +47,7 @@ public void tearDown() {
       
         @Test
         public void runKMeansUsingStaticMethods() {
-    List<Vector> points = Lists.newArrayList(
+    List<Vector> points = Arrays.asList(
             Vectors.dense(1.0, 2.0, 6.0),
             Vectors.dense(1.0, 3.0, 0.0),
             Vectors.dense(1.0, 4.0, 6.0)
      @@ -67,7 +66,7 @@ public void runKMeansUsingStaticMethods() {
       
         @Test
         public void runKMeansUsingConstructor() {
-    List<Vector> points = Lists.newArrayList(
+    List<Vector> points = Arrays.asList(
             Vectors.dense(1.0, 2.0, 6.0),
             Vectors.dense(1.0, 3.0, 0.0),
             Vectors.dense(1.0, 4.0, 6.0)
      @@ -90,7 +89,7 @@ public void runKMeansUsingConstructor() {
       
         @Test
         public void testPredictJavaRDD() {
-    List<Vector> points = Lists.newArrayList(
+    List<Vector> points = Arrays.asList(
             Vectors.dense(1.0, 2.0, 6.0),
             Vectors.dense(1.0, 3.0, 0.0),
             Vectors.dense(1.0, 4.0, 6.0)
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
      index 581c033f08eb..3fea359a3b46 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
      @@ -19,21 +19,25 @@
       
       import java.io.Serializable;
       import java.util.ArrayList;
      +import java.util.Arrays;
       
       import scala.Tuple2;
      +import scala.Tuple3;
       
       import org.junit.After;
      -import static org.junit.Assert.assertEquals;
      -import static org.junit.Assert.assertArrayEquals;
       import org.junit.Before;
       import org.junit.Test;
      +import static org.junit.Assert.assertArrayEquals;
      +import static org.junit.Assert.assertEquals;
      +import static org.junit.Assert.assertTrue;
       
      +import org.apache.spark.api.java.function.Function;
       import org.apache.spark.api.java.JavaPairRDD;
       import org.apache.spark.api.java.JavaRDD;
       import org.apache.spark.api.java.JavaSparkContext;
       import org.apache.spark.mllib.linalg.Matrix;
       import org.apache.spark.mllib.linalg.Vector;
      -
      +import org.apache.spark.mllib.linalg.Vectors;
       
       public class JavaLDASuite implements Serializable {
         private transient JavaSparkContext sc;
      @@ -42,9 +46,9 @@ public class JavaLDASuite implements Serializable {
         public void setUp() {
           sc = new JavaSparkContext("local", "JavaLDA");
           ArrayList> tinyCorpus = new ArrayList>();
      -    for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) {
      -      tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(),
      -          LDASuite$.MODULE$.tinyCorpus()[i]._2()));
      +    for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
      +      tinyCorpus.add(new Tuple2<Long, Vector>((Long)LDASuite.tinyCorpus()[i]._1(),
      +          LDASuite.tinyCorpus()[i]._2()));
           }
           JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
           corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
      @@ -58,7 +62,10 @@ public void tearDown() {
       
         @Test
         public void localLDAModel() {
      -    LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
      +    Matrix topics = LDASuite.tinyTopics();
      +    double[] topicConcentration = new double[topics.numRows()];
      +    Arrays.fill(topicConcentration, 1.0D / topics.numRows());
      +    LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D);
       
           // Check: basic parameters
           assertEquals(model.k(), tinyK);
      @@ -105,12 +112,35 @@ public void distributedLDAModel() {
           assertEquals(roundedLocalTopicSummary.length, k);
       
           // Check: log probabilities
      -    assert(model.logLikelihood() < 0.0);
      -    assert(model.logPrior() < 0.0);
      +    assertTrue(model.logLikelihood() < 0.0);
      +    assertTrue(model.logPrior() < 0.0);
       
           // Check: topic distributions
           JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
      -    assertEquals(topicDistributions.count(), corpus.count());
      +    // SPARK-5562. since the topicDistribution returns the distribution of the non empty docs
      +    // over topics. Compare it against nonEmptyCorpus instead of corpus
      +    JavaPairRDD<Long, Vector> nonEmptyCorpus = corpus.filter(
      +      new Function<Tuple2<Long, Vector>, Boolean>() {
      +        public Boolean call(Tuple2<Long, Vector> tuple2) {
      +          return Vectors.norm(tuple2._2(), 1.0) != 0.0;
      +        }
      +    });
      +    assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
      +
      +    // Check: javaTopTopicsPerDocuments
      +    Tuple3<Long, int[], double[]> topTopics = model.javaTopTopicsPerDocument(3).first();
      +    Long docId = topTopics._1(); // confirm doc ID type
      +    int[] topicIndices = topTopics._2();
      +    double[] topicWeights = topTopics._3();
      +    assertEquals(3, topicIndices.length);
      +    assertEquals(3, topicWeights.length);
      +
      +    // Check: topTopicAssignments
      +    Tuple3<Long, int[], int[]> topicAssignment = model.javaTopicAssignments().first();
      +    Long docId2 = topicAssignment._1();
      +    int[] termIndices2 = topicAssignment._2();
      +    int[] topicIndices2 = topicAssignment._3();
      +    assertEquals(termIndices2.length, topicIndices2.length);
         }
       
         @Test
      @@ -147,11 +177,31 @@ public void OnlineOptimizerCompatibility() {
           assertEquals(roundedLocalTopicSummary.length, k);
         }
       
      -  private static int tinyK = LDASuite$.MODULE$.tinyK();
      -  private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
      -  private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
      +  @Test
      +  public void localLdaMethods() {
      +    JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2);
      +    JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
      +
      +    // check: topicDistributions
      +    assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count());
      +
      +    // check: logPerplexity
      +    double logPerplexity = toyModel.logPerplexity(pairedDocs);
      +
      +    // check: logLikelihood.
      +    ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<Tuple2<Long, Vector>>();
      +    docsSingleWord.add(new Tuple2<Long, Vector>(0L, Vectors.dense(1.0, 0.0, 0.0)));
      +    JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
      +    double logLikelihood = toyModel.logLikelihood(single);
      +  }
      +
      +  private static int tinyK = LDASuite.tinyK();
      +  private static int tinyVocabSize = LDASuite.tinyVocabSize();
      +  private static Matrix tinyTopics = LDASuite.tinyTopics();
         private static Tuple2<int[], double[]>[] tinyTopicDescription =
      -      LDASuite$.MODULE$.tinyTopicDescription();
      +      LDASuite.tinyTopicDescription();
         private JavaPairRDD<Long, Vector> corpus;
      +  private LocalLDAModel toyModel = LDASuite.toyModel();
      +  private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();
       
       }
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
      index 3b0e879eec77..d644766d1e54 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
      @@ -18,11 +18,11 @@
       package org.apache.spark.mllib.clustering;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       import java.util.List;
       
       import scala.Tuple2;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
      @@ -60,16 +60,16 @@ public void tearDown() {
         @Test
         @SuppressWarnings("unchecked")
         public void javaAPI() {
      -    List<Vector> trainingBatch = Lists.newArrayList(
      +    List<Vector> trainingBatch = Arrays.asList(
             Vectors.dense(1.0),
             Vectors.dense(0.0));
           JavaDStream<Vector> training =
      -      attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
      -    List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
      +      attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
      +    List<Tuple2<Integer, Vector>> testBatch = Arrays.asList(
             new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
             new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
           JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
      -      attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
      +      attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
           StreamingKMeans skmeans = new StreamingKMeans()
             .setK(1)
             .setDecayFactor(1.0)
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
      index effc8a1a6dab..fa4d334801ce 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
      @@ -18,12 +18,12 @@
       package org.apache.spark.mllib.evaluation;
       
       import java.io.Serializable;
      -import java.util.ArrayList;
      +import java.util.Arrays;
      +import java.util.List;
       
       import scala.Tuple2;
       import scala.Tuple2$;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -34,18 +34,18 @@
       
       public class JavaRankingMetricsSuite implements Serializable {
         private transient JavaSparkContext sc;
      -  private transient JavaRDD<Tuple2<ArrayList<Integer>, ArrayList<Integer>>> predictionAndLabels;
      +  private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
       
         @Before
         public void setUp() {
           sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
      -    predictionAndLabels = sc.parallelize(Lists.newArrayList(
      +    predictionAndLabels = sc.parallelize(Arrays.asList(
             Tuple2$.MODULE$.apply(
      -        Lists.newArrayList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Lists.newArrayList(1, 2, 3, 4, 5)),
      +        Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
             Tuple2$.MODULE$.apply(
      -        Lists.newArrayList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Lists.newArrayList(1, 2, 3)),
      +          Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
             Tuple2$.MODULE$.apply(
      -        Lists.newArrayList(1, 2, 3, 4, 5), Lists.newArrayList())), 2);
      +          Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2);
         }
       
         @After
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
      index fbc26167ce66..8a320afa4b13 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
      @@ -18,14 +18,13 @@
       package org.apache.spark.mllib.feature;
       
       import java.io.Serializable;
      -import java.util.ArrayList;
      +import java.util.Arrays;
       import java.util.List;
       
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
       import org.junit.Test;
      -import com.google.common.collect.Lists;
       
       import org.apache.spark.api.java.JavaRDD;
       import org.apache.spark.api.java.JavaSparkContext;
      @@ -50,10 +49,10 @@ public void tfIdf() {
           // The tests are to check Java compatibility.
           HashingTF tf = new HashingTF();
           @SuppressWarnings("unchecked")
      -    JavaRDD<ArrayList<String>> documents = sc.parallelize(Lists.newArrayList(
      -      Lists.newArrayList("this is a sentence".split(" ")),
      -      Lists.newArrayList("this is another sentence".split(" ")),
      -      Lists.newArrayList("this is still a sentence".split(" "))), 2);
      +    JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
      +      Arrays.asList("this is a sentence".split(" ")),
      +      Arrays.asList("this is another sentence".split(" ")),
      +      Arrays.asList("this is still a sentence".split(" "))), 2);
           JavaRDD<Vector> termFreqs = tf.transform(documents);
           termFreqs.collect();
           IDF idf = new IDF();
      @@ -70,10 +69,10 @@ public void tfIdfMinimumDocumentFrequency() {
           // The tests are to check Java compatibility.
           HashingTF tf = new HashingTF();
           @SuppressWarnings("unchecked")
      -    JavaRDD<ArrayList<String>> documents = sc.parallelize(Lists.newArrayList(
      -      Lists.newArrayList("this is a sentence".split(" ")),
      -      Lists.newArrayList("this is another sentence".split(" ")),
      -      Lists.newArrayList("this is still a sentence".split(" "))), 2);
      +    JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
      +      Arrays.asList("this is a sentence".split(" ")),
      +      Arrays.asList("this is another sentence".split(" ")),
      +      Arrays.asList("this is still a sentence".split(" "))), 2);
           JavaRDD<Vector> termFreqs = tf.transform(documents);
           termFreqs.collect();
           IDF idf = new IDF(2);
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
      index fb7afe8c6434..e13ed07e283d 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
      @@ -18,11 +18,11 @@
       package org.apache.spark.mllib.feature;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       import java.util.List;
       
       import scala.Tuple2;
       
      -import com.google.common.collect.Lists;
       import com.google.common.base.Strings;
       import org.junit.After;
       import org.junit.Assert;
      @@ -51,8 +51,8 @@ public void tearDown() {
         public void word2Vec() {
           // The tests are to check Java compatibility.
           String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
      -    List<String> words = Lists.newArrayList(sentence.split(" "));
      -    List<List<String>> localDoc = Lists.newArrayList(words, words);
      +    List<String> words = Arrays.asList(sentence.split(" "));
      +    List<List<String>> localDoc = Arrays.asList(words, words);
           JavaRDD<List<String>> doc = sc.parallelize(localDoc);
           Word2Vec word2vec = new Word2Vec()
             .setVectorSize(10)
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
      new file mode 100644
      index 000000000000..2bef7a860975
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
      @@ -0,0 +1,57 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +package org.apache.spark.mllib.fpm;
      +
      +import java.io.Serializable;
      +import java.util.Arrays;
      +
      +import org.junit.After;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
      +
      +public class JavaAssociationRulesSuite implements Serializable {
      +  private transient JavaSparkContext sc;
      +
      +  @Before
      +  public void setUp() {
      +    sc = new JavaSparkContext("local", "JavaFPGrowth");
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    sc.stop();
      +    sc = null;
      +  }
      +
      +  @Test
      +  public void runAssociationRules() {
      +
      +    @SuppressWarnings("unchecked")
      +    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
      +      new FreqItemset<String>(new String[] {"a"}, 15L),
      +      new FreqItemset<String>(new String[] {"b"}, 35L),
      +      new FreqItemset<String>(new String[] {"a", "b"}, 12L)
      +    ));
      +
      +    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
      +  }
      +}
      +
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
      index bd0edf2b9ea6..154f75d75e4a 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
      @@ -18,18 +18,16 @@
       package org.apache.spark.mllib.fpm;
       
       import java.io.Serializable;
      -import java.util.ArrayList;
      +import java.util.Arrays;
       import java.util.List;
       
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
      -import com.google.common.collect.Lists;
       import static org.junit.Assert.*;
       
       import org.apache.spark.api.java.JavaRDD;
       import org.apache.spark.api.java.JavaSparkContext;
      -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
       
       public class JavaFPGrowthSuite implements Serializable {
         private transient JavaSparkContext sc;
      @@ -49,23 +47,23 @@ public void tearDown() {
         public void runFPGrowth() {
       
           @SuppressWarnings("unchecked")
      -    JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
      -      Lists.newArrayList("r z h k p".split(" ")),
      -      Lists.newArrayList("z y x w v u t s".split(" ")),
      -      Lists.newArrayList("s x o n r".split(" ")),
      -      Lists.newArrayList("x z y m t s q e".split(" ")),
      -      Lists.newArrayList("z".split(" ")),
      -      Lists.newArrayList("x z y r q t p".split(" "))), 2);
      +    JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
      +      Arrays.asList("r z h k p".split(" ")),
      +      Arrays.asList("z y x w v u t s".split(" ")),
      +      Arrays.asList("s x o n r".split(" ")),
      +      Arrays.asList("x z y m t s q e".split(" ")),
      +      Arrays.asList("z".split(" ")),
      +      Arrays.asList("x z y r q t p".split(" "))), 2);
       
           FPGrowthModel<String> model = new FPGrowth()
             .setMinSupport(0.5)
             .setNumPartitions(2)
             .run(rdd);
       
      -    List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
      +    List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
           assertEquals(18, freqItemsets.size());
       
      -    for (FreqItemset<String> itemset: freqItemsets) {
      +    for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
             // Test return types.
             List<String> items = itemset.javaItems();
             long freq = itemset.freq();
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
      new file mode 100644
      index 000000000000..34daf5fbde80
      --- /dev/null
      +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
      @@ -0,0 +1,67 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.fpm;
      +
      +import java.util.Arrays;
      +import java.util.List;
      +
      +import org.junit.After;
      +import org.junit.Assert;
      +import org.junit.Before;
      +import org.junit.Test;
      +
      +import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
      +
      +public class JavaPrefixSpanSuite {
      +  private transient JavaSparkContext sc;
      +
      +  @Before
      +  public void setUp() {
      +    sc = new JavaSparkContext("local", "JavaPrefixSpan");
      +  }
      +
      +  @After
      +  public void tearDown() {
      +    sc.stop();
      +    sc = null;
      +  }
      +
      +  @Test
      +  public void runPrefixSpan() {
      +    JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
      +      Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
      +      Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
      +      Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
      +      Arrays.asList(Arrays.asList(6))
      +    ), 2);
      +    PrefixSpan prefixSpan = new PrefixSpan()
      +      .setMinSupport(0.5)
      +      .setMaxPatternLength(5);
      +    PrefixSpanModel<Integer> model = prefixSpan.run(sequences);
      +    JavaRDD<FreqSequence<Integer>> freqSeqs = model.freqSequences().toJavaRDD();
      +    List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
      +    Assert.assertEquals(5, localFreqSeqs.size());
      +    // Check that each frequent sequence could be materialized.
      +    for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
      +      List<List<Integer>> seq = freqSeq.javaSequence();
      +      long freq = freqSeq.freq();
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
      index 3349c5022423..8beea102efd0 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
      @@ -80,10 +80,10 @@ public void diagonalMatrixConstruction() {
               assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
               assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
               assertArrayEquals(s.values(), ss.values(), 0.0);
      -        assert(s.values().length == 2);
      -        assert(ss.values().length == 2);
      -        assert(s.colPtrs().length == 4);
      -        assert(ss.colPtrs().length == 4);
      +        assertEquals(2, s.values().length);
      +        assertEquals(2, ss.values().length);
      +        assertEquals(4, s.colPtrs().length);
      +        assertEquals(4, ss.colPtrs().length);
           }
       
           @Test
      @@ -137,27 +137,27 @@ public void concatenateMatrices() {
               Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
               Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
       
      -        assert(deHorz1.numRows() == 3);
      -        assert(deHorz2.numRows() == 3);
      -        assert(deHorz3.numRows() == 3);
      -        assert(spHorz.numRows() == 3);
      -        assert(deHorz1.numCols() == 5);
      -        assert(deHorz2.numCols() == 5);
      -        assert(deHorz3.numCols() == 5);
      -        assert(spHorz.numCols() == 5);
      +        assertEquals(3, deHorz1.numRows());
      +        assertEquals(3, deHorz2.numRows());
      +        assertEquals(3, deHorz3.numRows());
      +        assertEquals(3, spHorz.numRows());
      +        assertEquals(5, deHorz1.numCols());
      +        assertEquals(5, deHorz2.numCols());
      +        assertEquals(5, deHorz3.numCols());
      +        assertEquals(5, spHorz.numCols());
       
               Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
               Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
               Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
               Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
       
      -        assert(deVert1.numRows() == 5);
      -        assert(deVert2.numRows() == 5);
      -        assert(deVert3.numRows() == 5);
      -        assert(spVert.numRows() == 5);
      -        assert(deVert1.numCols() == 2);
      -        assert(deVert2.numCols() == 2);
      -        assert(deVert3.numCols() == 2);
      -        assert(spVert.numCols() == 2);
      +        assertEquals(5, deVert1.numRows());
      +        assertEquals(5, deVert2.numRows());
      +        assertEquals(5, deVert3.numRows());
      +        assertEquals(5, spVert.numRows());
      +        assertEquals(2, deVert1.numCols());
      +        assertEquals(2, deVert2.numCols());
      +        assertEquals(2, deVert3.numCols());
      +        assertEquals(2, spVert.numCols());
           }
       }
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
      index 1421067dc61e..77c8c6274f37 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
      @@ -18,11 +18,10 @@
       package org.apache.spark.mllib.linalg;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       
       import scala.Tuple2;
       
      -import com.google.common.collect.Lists;
      -
       import org.junit.Test;
       import static org.junit.Assert.*;
       
      @@ -37,7 +36,7 @@ public void denseArrayConstruction() {
         @Test
         public void sparseArrayConstruction() {
           @SuppressWarnings("unchecked")
      -    Vector v = Vectors.sparse(3, Lists.<Tuple2<Integer, Double>>newArrayList(
      +    Vector v = Vectors.sparse(3, Arrays.asList(
               new Tuple2<Integer, Double>(0, 2.0),
               new Tuple2<Integer, Double>(2, 3.0)));
           assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
      index fcc13c00cbdc..33d81b1e9592 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
      @@ -17,7 +17,8 @@
       
       package org.apache.spark.mllib.random;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.apache.spark.api.java.JavaRDD;
       import org.junit.Assert;
       import org.junit.After;
      @@ -51,7 +52,7 @@ public void testUniformRDD() {
           JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m);
           JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p);
           JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed);
      -    for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
           }
         }
      @@ -64,7 +65,7 @@ public void testNormalRDD() {
           JavaDoubleRDD rdd1 = normalJavaRDD(sc, m);
           JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p);
           JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed);
      -    for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
           }
         }
      @@ -79,7 +80,7 @@ public void testLNormalRDD() {
           JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m);
           JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p);
           JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed);
      -    for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
           }
         }
      @@ -93,7 +94,7 @@ public void testPoissonRDD() {
           JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m);
           JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p);
           JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed);
      -    for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
           }
         }
      @@ -107,7 +108,7 @@ public void testExponentialRDD() {
           JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m);
           JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p);
           JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed);
      -    for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
           }
         }
      @@ -122,7 +123,7 @@ public void testGammaRDD() {
           JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m);
           JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p);
           JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed);
      -    for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
           }
         }
      @@ -138,7 +139,7 @@ public void testUniformVectorRDD() {
           JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n);
           JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p);
           JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed);
      -    for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
             Assert.assertEquals(n, rdd.first().size());
           }
      @@ -154,7 +155,7 @@ public void testNormalVectorRDD() {
           JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n);
           JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p);
           JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed);
      -    for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
             Assert.assertEquals(n, rdd.first().size());
           }
      @@ -172,7 +173,7 @@ public void testLogNormalVectorRDD() {
           JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n);
           JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p);
           JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed);
      -    for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
             Assert.assertEquals(n, rdd.first().size());
           }
      @@ -189,7 +190,7 @@ public void testPoissonVectorRDD() {
           JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n);
           JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p);
           JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed);
      -    for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
             Assert.assertEquals(n, rdd.first().size());
           }
      @@ -206,7 +207,7 @@ public void testExponentialVectorRDD() {
           JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n);
           JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p);
           JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed);
      -    for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
             Assert.assertEquals(n, rdd.first().size());
           }
      @@ -224,7 +225,7 @@ public void testGammaVectorRDD() {
           JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n);
           JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p);
           JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed);
      -    for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
      +    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
             Assert.assertEquals(m, rdd.count());
             Assert.assertEquals(n, rdd.first().size());
           }
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
      index af688c504cf1..271dda4662e0 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
      @@ -18,12 +18,12 @@
       package org.apache.spark.mllib.recommendation;
       
       import java.io.Serializable;
      +import java.util.ArrayList;
       import java.util.List;
       
       import scala.Tuple2;
       import scala.Tuple3;
       
      -import com.google.common.collect.Lists;
       import org.jblas.DoubleMatrix;
       import org.junit.After;
       import org.junit.Assert;
      @@ -56,8 +56,7 @@ void validatePrediction(
             double matchThreshold,
             boolean implicitPrefs,
             DoubleMatrix truePrefs) {
      -    List<Tuple2<Integer, Integer>> localUsersProducts =
      -      Lists.newArrayListWithCapacity(users * products);
      +    List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
           for (int u=0; u < users; ++u) {
             for (int p=0; p < products; ++p) {
               localUsersProducts.add(new Tuple2<Integer, Integer>(u, p));
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
      index d38fc91ace3c..32c2f4f3395b 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
      @@ -18,11 +18,12 @@
       package org.apache.spark.mllib.regression;
       
       import java.io.Serializable;
      +import java.util.ArrayList;
      +import java.util.Arrays;
       import java.util.List;
       
       import scala.Tuple3;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Assert;
       import org.junit.Before;
      @@ -36,7 +37,7 @@ public class JavaIsotonicRegressionSuite implements Serializable {
         private transient JavaSparkContext sc;
       
         private List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
      -    List<Tuple3<Double, Double, Double>> input = Lists.newArrayList();
      +    ArrayList<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
       
           for (int i = 1; i <= labels.length; i++) {
             input.add(new Tuple3<Double, Double, Double>(labels[i-1], (double) i, 1d));
      @@ -77,7 +78,7 @@ public void testIsotonicRegressionPredictionsJavaRDD() {
           IsotonicRegressionModel model =
             runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
       
      -    JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0));
      +    JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
           List<Double> predictions = model.predict(testRDD).collect();
       
           Assert.assertTrue(predictions.get(0) == 1d);
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
      index 899c4ea60786..dbf6488d4108 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
      @@ -18,11 +18,11 @@
       package org.apache.spark.mllib.regression;
       
       import java.io.Serializable;
      +import java.util.Arrays;
       import java.util.List;
       
       import scala.Tuple2;
       
      -import com.google.common.collect.Lists;
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
      @@ -59,16 +59,16 @@ public void tearDown() {
         @Test
         @SuppressWarnings("unchecked")
         public void javaAPI() {
      -    List<LabeledPoint> trainingBatch = Lists.newArrayList(
      +    List<LabeledPoint> trainingBatch = Arrays.asList(
             new LabeledPoint(1.0, Vectors.dense(1.0)),
             new LabeledPoint(0.0, Vectors.dense(0.0)));
           JavaDStream<LabeledPoint> training =
      -      attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
      -    List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
      +      attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
      +    List<Tuple2<Integer, Vector>> testBatch = Arrays.asList(
             new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
             new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
           JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
      -      attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
      +      attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
           StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD()
             .setNumIterations(2)
             .setInitialWeights(Vectors.dense(0.0));
      diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
      index 62f7f26b7c98..4795809e47a4 100644
      --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
      +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
      @@ -19,7 +19,8 @@
       
       import java.io.Serializable;
       
      -import com.google.common.collect.Lists;
      +import java.util.Arrays;
      +
       import org.junit.After;
       import org.junit.Before;
       import org.junit.Test;
      @@ -27,7 +28,12 @@
       import static org.junit.Assert.assertEquals;
       
       import org.apache.spark.api.java.JavaRDD;
      +import org.apache.spark.api.java.JavaDoubleRDD;
       import org.apache.spark.api.java.JavaSparkContext;
      +import org.apache.spark.mllib.linalg.Vectors;
      +import org.apache.spark.mllib.regression.LabeledPoint;
      +import org.apache.spark.mllib.stat.test.ChiSqTestResult;
      +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
       
       public class JavaStatisticsSuite implements Serializable {
         private transient JavaSparkContext sc;
      @@ -45,12 +51,29 @@ public void tearDown() {
       
         @Test
         public void testCorr() {
      -    JavaRDD<Double> x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0));
      -    JavaRDD<Double> y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3));
      +    JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
      +    JavaRDD<Double> y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
       
           Double corr1 = Statistics.corr(x, y);
           Double corr2 = Statistics.corr(x, y, "pearson");
           // Check default method
           assertEquals(corr1, corr2);
         }
      +
      +  @Test
      +  public void kolmogorovSmirnovTest() {
      +    JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
      +    KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
      +    KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
      +      data, "norm", 0.0, 1.0);
      +  }
      +
      +  @Test
      +  public void chiSqTest() {
      +    JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList(
      +      new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
      +      new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
      +      new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
      +    ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
      +  }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
      index 63d2fa31c749..1f2c9b75b617 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
      @@ -26,6 +26,7 @@ import org.scalatest.mock.MockitoSugar.mock
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.feature.HashingTF
       import org.apache.spark.ml.param.ParamMap
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.sql.DataFrame
       
       class PipelineSuite extends SparkFunSuite {
      @@ -65,6 +66,8 @@ class PipelineSuite extends SparkFunSuite {
             .setStages(Array(estimator0, transformer1, estimator2, transformer3))
           val pipelineModel = pipeline.fit(dataset0)
       
      +    MLTestingUtils.checkCopy(pipelineModel)
      +
           assert(pipelineModel.stages.length === 4)
           assert(pipelineModel.stages(0).eq(model0))
           assert(pipelineModel.stages(1).eq(transformer1))
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
      new file mode 100644
      index 000000000000..1292e57d7c01
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
      @@ -0,0 +1,91 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.ann
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
      +
      +
      +class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  // TODO: test for weights comparison with Weka MLP
      +  test("ANN with Sigmoid learns XOR function with LBFGS optimizer") {
      +    val inputs = Array(
      +      Array(0.0, 0.0),
      +      Array(0.0, 1.0),
      +      Array(1.0, 0.0),
      +      Array(1.0, 1.0)
      +    )
      +    val outputs = Array(0.0, 1.0, 1.0, 0.0)
      +    val data = inputs.zip(outputs).map { case (features, label) =>
      +      (Vectors.dense(features), Vectors.dense(label))
      +    }
      +    val rddData = sc.parallelize(data, 1)
      +    val hiddenLayersTopology = Array(5)
      +    val dataSample = rddData.first()
      +    val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
      +    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
      +    val initialWeights = FeedForwardModel(topology, 23124).weights()
      +    val trainer = new FeedForwardTrainer(topology, 2, 1)
      +    trainer.setWeights(initialWeights)
      +    trainer.LBFGSOptimizer.setNumIterations(20)
      +    val model = trainer.train(rddData)
      +    val predictionAndLabels = rddData.map { case (input, label) =>
      +      (model.predict(input)(0), label(0))
      +    }.collect()
      +    predictionAndLabels.foreach { case (p, l) =>
      +      assert(math.round(p) === l)
      +    }
      +  }
      +
      +  test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") {
      +    val inputs = Array(
      +      Array(0.0, 0.0),
      +      Array(0.0, 1.0),
      +      Array(1.0, 0.0),
      +      Array(1.0, 1.0)
      +    )
      +    val outputs = Array(
      +      Array(1.0, 0.0),
      +      Array(0.0, 1.0),
      +      Array(0.0, 1.0),
      +      Array(1.0, 0.0)
      +    )
      +    val data = inputs.zip(outputs).map { case (features, label) =>
      +      (Vectors.dense(features), Vectors.dense(label))
      +    }
      +    val rddData = sc.parallelize(data, 1)
      +    val hiddenLayersTopology = Array(5)
      +    val dataSample = rddData.first()
      +    val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
      +    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
      +    val initialWeights = FeedForwardModel(topology, 23124).weights()
      +    val trainer = new FeedForwardTrainer(topology, 2, 2)
      +    trainer.SGDOptimizer.setNumIterations(2000)
      +    trainer.setWeights(initialWeights)
      +    val model = trainer.train(rddData)
      +    val predictionAndLabels = rddData.map { case (input, label) =>
      +      (model.predict(input), label)
      +    }.collect()
      +    predictionAndLabels.foreach { case (p, l) =>
      +      assert(p ~== l absTol 0.5)
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
      index 72b575d02254..6355e0f17949 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
      @@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite {
           assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
           val fldWithMeta = new StructField("x", DoubleType, false, metadata)
           assert(Attribute.fromStructField(fldWithMeta).isNumeric)
      +    // Attribute.fromStructField should accept any NumericType, not just DoubleType
      +    val longFldWithMeta = new StructField("x", LongType, false, metadata)
      +    assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
      +    val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
      +    assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
         }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
      index 73b4805c4c59..f680d8d3c4cc 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
      @@ -21,12 +21,14 @@ import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.impl.TreeTests
       import org.apache.spark.ml.param.ParamsSuite
       import org.apache.spark.ml.tree.LeafNode
      -import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.Row
       
       class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
       
      @@ -57,7 +59,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
       
         test("params") {
           ParamsSuite.checkParams(new DecisionTreeClassifier)
      -    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
      +    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
           ParamsSuite.checkParams(model)
         }
       
      @@ -231,6 +233,47 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
           compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
         }
       
      +  test("predictRaw and predictProbability") {
      +    val rdd = continuousDataPointsForMulticlassRDD
      +    val dt = new DecisionTreeClassifier()
      +      .setImpurity("Gini")
      +      .setMaxDepth(4)
      +      .setMaxBins(100)
      +    val categoricalFeatures = Map(0 -> 3)
      +    val numClasses = 3
      +
      +    val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
      +    val newTree = dt.fit(newData)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(newTree)
      +
      +    val predictions = newTree.transform(newData)
      +      .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
      +      .collect()
      +
      +    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
      +      assert(pred === rawPred.argmax,
      +        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
      +      val sum = rawPred.toArray.sum
      +      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
      +        "probability prediction mismatch")
      +    }
      +  }
      +
      +  test("training with 1-category categorical feature") {
      +    val data = sc.parallelize(Seq(
      +      LabeledPoint(0, Vectors.dense(0, 2, 3)),
      +      LabeledPoint(1, Vectors.dense(0, 3, 1)),
      +      LabeledPoint(0, Vectors.dense(0, 2, 2)),
      +      LabeledPoint(1, Vectors.dense(0, 3, 9)),
      +      LabeledPoint(0, Vectors.dense(0, 2, 6))
      +    ))
      +    val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
      +    val dt = new DecisionTreeClassifier().setMaxDepth(3)
      +    val model = dt.fit(df)
      +  }
      +
         /////////////////////////////////////////////////////////////////////////////
         // Tests of model save/load
         /////////////////////////////////////////////////////////////////////////////
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
      index 82c345491bb3..e3909bccaa5c 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
      @@ -22,12 +22,14 @@ import org.apache.spark.ml.impl.TreeTests
       import org.apache.spark.ml.param.ParamsSuite
       import org.apache.spark.ml.regression.DecisionTreeRegressionModel
       import org.apache.spark.ml.tree.LeafNode
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      +import org.apache.spark.util.Utils
       
       
       /**
      @@ -57,7 +59,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
         test("params") {
           ParamsSuite.checkParams(new GBTClassifier)
           val model = new GBTClassificationModel("gbtc",
      -      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
      +      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
             Array(1.0))
           ParamsSuite.checkParams(model)
         }
      @@ -76,6 +78,28 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
           }
         }
       
      +  test("Checkpointing") {
      +    val tempDir = Utils.createTempDir()
      +    val path = tempDir.toURI.toString
      +    sc.setCheckpointDir(path)
      +
      +    val categoricalFeatures = Map.empty[Int, Int]
      +    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
      +    val gbt = new GBTClassifier()
      +      .setMaxDepth(2)
      +      .setLossType("logistic")
      +      .setMaxIter(5)
      +      .setStepSize(0.1)
      +      .setCheckpointInterval(2)
      +    val model = gbt.fit(df)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +
      +    sc.checkpointDir = None
      +    Utils.deleteRecursively(tempDir)
      +  }
      +
         // TODO: Reinstate test once runWithValidation is implemented   SPARK-7132
         /*
         test("runWithValidation stops early and performs better on a validation dataset") {
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
      index 5a6265ea992c..f5219f9f574b 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
      @@ -17,10 +17,14 @@
       
       package org.apache.spark.ml.classification
       
      +import scala.util.Random
      +
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.classification.LogisticRegressionSuite._
       import org.apache.spark.mllib.linalg.{Vectors, Vector}
      +import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.mllib.util.TestingUtils._
       import org.apache.spark.sql.{DataFrame, Row}
      @@ -36,19 +40,19 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       
           dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
       
      -    /**
      -     * Here is the instruction describing how to export the test data into CSV format
      -     * so we can validate the training accuracy compared with R's glmnet package.
      -     *
      -     * import org.apache.spark.mllib.classification.LogisticRegressionSuite
      -     * val nPoints = 10000
      -     * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
      -     * val xMean = Array(5.843, 3.057, 3.758, 1.199)
      -     * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
      -     * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
      -     *   weights, xMean, xVariance, true, nPoints, 42), 1)
      -     * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
      -     *   + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
      +    /*
      +       Here is the instruction describing how to export the test data into CSV format
      +       so we can validate the training accuracy compared with R's glmnet package.
      +
      +       import org.apache.spark.mllib.classification.LogisticRegressionSuite
      +       val nPoints = 10000
      +       val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
      +       val xMean = Array(5.843, 3.057, 3.758, 1.199)
      +       val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
      +       val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
      +         weights, xMean, xVariance, true, nPoints, 42), 1)
      +       data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
      +         + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
            */
           binaryDataset = {
             val nPoints = 10000
      @@ -58,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       
             val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
       
      -      sqlContext.createDataFrame(
      -        generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42))
      +      sqlContext.createDataFrame(sc.parallelize(testData, 4))
           }
         }
       
      @@ -76,7 +79,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(lr.getPredictionCol === "prediction")
           assert(lr.getRawPredictionCol === "rawPrediction")
           assert(lr.getProbabilityCol === "probability")
      +    assert(lr.getWeightCol === "")
           assert(lr.getFitIntercept)
      +    assert(lr.getStandardization)
           val model = lr.fit(dataset)
           model.transform(dataset)
             .select("label", "probability", "prediction", "rawPrediction")
      @@ -90,11 +95,53 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(model.hasParent)
         }
       
      +  test("setThreshold, getThreshold") {
      +    val lr = new LogisticRegression
      +    // default
      +    assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5")
      +    withClue("LogisticRegression should not have thresholds set by default.") {
      +      intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future
      +        lr.getThresholds
      +      }
      +    }
      +    // Set via threshold.
      +    // Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
      +    lr.setThreshold(1.0)
      +    assert(lr.getThresholds === Array(0.0, 1.0))
      +    lr.setThreshold(0.0)
      +    assert(lr.getThresholds === Array(1.0, 0.0))
      +    lr.setThreshold(0.5)
      +    assert(lr.getThresholds === Array(0.5, 0.5))
      +    // Set via thresholds
      +    val lr2 = new LogisticRegression
      +    lr2.setThresholds(Array(0.3, 0.7))
      +    val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
      +    assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7)
      +    // thresholds and threshold must be consistent
      +    lr2.setThresholds(Array(0.1, 0.2, 0.3))
      +    withClue("getThreshold should throw error if thresholds has length != 2.") {
      +      intercept[IllegalArgumentException] {
      +        lr2.getThreshold
      +      }
      +    }
      +    // thresholds and threshold must be consistent: values
      +    withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
      +      intercept[IllegalArgumentException] {
      +        val lr2model = lr2.fit(dataset,
      +          lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
      +        lr2model.getThreshold
      +      }
      +    }
      +  }
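Editorial aside on the assertions above: in the binary case, `threshold` and `thresholds` are two encodings of the same decision rule. Below is a minimal sketch of that conversion in plain Scala (illustrative names only, not Spark API), assuming the length-2 convention `Array(1 - t, t)` that the test exercises.

```scala
// Minimal sketch of the binary threshold <-> thresholds conversion exercised above.
// A scalar threshold t corresponds to thresholds = Array(1 - t, t); a length-2
// thresholds array (a, b) maps back to threshold = b / (a + b) = 1 / (1 + a / b).
object ThresholdConversionSketch {
  def toThresholds(threshold: Double): Array[Double] = Array(1.0 - threshold, threshold)

  def toThreshold(thresholds: Array[Double]): Double = {
    require(thresholds.length == 2, "only defined for binary thresholds")
    thresholds(1) / (thresholds(0) + thresholds(1))
  }

  def main(args: Array[String]): Unit = {
    assert(toThresholds(0.5).sameElements(Array(0.5, 0.5)))
    // Matches expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7) in the test above.
    assert(math.abs(toThreshold(Array(0.3, 0.7)) - 1.0 / (1.0 + 0.3 / 0.7)) < 1e-12)
  }
}
```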
      +
         test("logistic regression doesn't fit intercept when fitIntercept is off") {
           val lr = new LogisticRegression
           lr.setFitIntercept(false)
           val model = lr.fit(dataset)
           assert(model.intercept === 0.0)
      +
       +    // The copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
         }
       
         test("logistic regression with setters") {
      @@ -122,14 +169,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
           // Call transform with params, and check that the params worked.
           val predNotAllZero =
      -      model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
      +      model.transform(dataset, model.threshold -> 0.0,
      +        model.probabilityCol -> "myProb")
               .select("prediction", "myProb")
               .collect()
               .map { case Row(pred: Double, prob: Vector) => pred }
           assert(predNotAllZero.exists(_ !== 0.0))
       
           // Call fit() with new params, and check as many params as we can.
      -    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
      +    lr.setThresholds(Array(0.6, 0.4))
      +    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
             lr.probabilityCol -> "theProb")
           val parent2 = model2.parent.asInstanceOf[LogisticRegression]
           assert(parent2.getMaxIter === 5)
      @@ -170,305 +219,503 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         test("MultiClassSummarizer") {
           val summarizer1 = (new MultiClassSummarizer)
             .add(0.0).add(3.0).add(4.0).add(3.0).add(6.0)
      -    assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2))
      +    assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1))
           assert(summarizer1.countInvalid === 0)
           assert(summarizer1.numClasses === 7)
       
           val summarizer2 = (new MultiClassSummarizer)
             .add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0)
      -    assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2))
      +    assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1))
           assert(summarizer2.countInvalid === 0)
           assert(summarizer2.numClasses === 6)
       
           val summarizer3 = (new MultiClassSummarizer)
             .add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0)
      -    assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2))
      +    assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3))
           assert(summarizer3.countInvalid === 3)
           assert(summarizer3.numClasses === 5)
       
           val summarizer4 = (new MultiClassSummarizer)
             .add(3.1).add(4.3).add(2.0).add(1.0).add(3.0)
      -    assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2))
      +    assert(summarizer4.histogram === Array[Double](0, 1, 1, 1))
           assert(summarizer4.countInvalid === 2)
           assert(summarizer4.numClasses === 4)
       
           // small map merges large one
           val summarizerA = summarizer1.merge(summarizer2)
           assert(summarizerA.hashCode() === summarizer2.hashCode())
      -    assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2))
      +    assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1))
           assert(summarizerA.countInvalid === 0)
           assert(summarizerA.numClasses === 7)
       
           // large map merges small one
           val summarizerB = summarizer3.merge(summarizer4)
           assert(summarizerB.hashCode() === summarizer3.hashCode())
      -    assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2))
      +    assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3))
           assert(summarizerB.countInvalid === 5)
           assert(summarizerB.numClasses === 5)
         }
       
      +  test("MultiClassSummarizer with weighted samples") {
      +    val summarizer1 = (new MultiClassSummarizer)
      +      .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1)
      +    assert(Vectors.dense(summarizer1.histogram) ~==
      +      Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10)
      +    assert(summarizer1.countInvalid === 0)
      +    assert(summarizer1.numClasses === 7)
      +
      +    val summarizer2 = (new MultiClassSummarizer)
      +      .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2, 0.0)
      +    assert(Vectors.dense(summarizer2.histogram) ~==
      +      Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10)
      +    assert(summarizer2.countInvalid === 0)
      +    assert(summarizer2.numClasses === 6)
      +
      +    val summarizer = summarizer1.merge(summarizer2)
      +    assert(Vectors.dense(summarizer.histogram) ~==
      +      Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10)
      +    assert(summarizer.countInvalid === 0)
      +    assert(summarizer.numClasses === 7)
      +  }
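As a companion to the weighted-summarizer assertions above, here is a small standalone sketch (illustrative only, not the MLlib `MultiClassSummarizer`) of the per-class weight accumulation those expected histograms encode.

```scala
// Accumulate per-class weight sums, mirroring what the assertions above expect:
// histogram(c) is the total weight of all samples with label c.
object WeightedHistogramSketch {
  def histogram(labelWeightPairs: Seq[(Double, Double)]): Array[Double] = {
    val numClasses = labelWeightPairs.map(_._1.toInt).max + 1
    val hist = new Array[Double](numClasses)
    labelWeightPairs.foreach { case (label, weight) => hist(label.toInt) += weight }
    hist
  }

  def main(args: Array[String]): Unit = {
    // Same label -> weight inputs as summarizer1 in the test above.
    val h = histogram(Seq(0.0 -> 0.2, 3.0 -> 0.8, 4.0 -> 3.2, 3.0 -> 1.3, 6.0 -> 3.1))
    val expected = Array(0.2, 0.0, 0.0, 2.1, 3.2, 0.0, 3.1)
    assert(h.zip(expected).forall { case (a, b) => math.abs(a - b) < 1e-10 })
  }
}
```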
      +
         test("binary logistic regression with intercept without regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(true)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                     s0
      -     * (Intercept)  2.8366423
      -     * data.V2     -0.5895848
      -     * data.V3      0.8931147
      -     * data.V4     -0.3925051
      -     * data.V5     -0.7996864
      +    val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                           s0
      +       (Intercept)  2.8366423
      +       data.V2     -0.5895848
      +       data.V3      0.8931147
      +       data.V4     -0.3925051
      +       data.V5     -0.7996864
            */
           val interceptR = 2.8366423
      -    val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
      +    val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
      +
      +    assert(model1.intercept ~== interceptR relTol 1E-3)
      +    assert(model1.weights ~= weightsR relTol 1E-3)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
      -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
       +    // Without regularization, standardization does not change the converged solution.
      +    assert(model2.intercept ~== interceptR relTol 1E-3)
      +    assert(model2.weights ~= weightsR relTol 1E-3)
         }
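The comment above asserts that the unregularized solutions with and without standardization coincide. Informally (a sketch of the argument, not taken from the Spark docs), standardization is an invertible rescaling of the features, so the unpenalized likelihood is unchanged once the weights are mapped back:

```latex
% With x_{ij} standardized to x_{ij}/\sigma_j, the linear predictor is unchanged if the
% weights are rescaled accordingly:
\ell(w, b) = \sum_i \log p\big(y_i \mid w^\top x_i + b\big),
\qquad
w^\top x_i = \sum_j (w_j \sigma_j)\,\frac{x_{ij}}{\sigma_j}.
% Hence any maximizer in the standardized space maps back to the same maximizer in the
% original space. A penalty term \lambda\,R(w) breaks this equivalence, which is why the
% regularized tests below compare against glmnet with standardize=TRUE and standardize=FALSE.
```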
       
         test("binary logistic regression without intercept without regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(false)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights =
      -     *     coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                     s0
      -     * (Intercept)   .
      -     * data.V2     -0.3534996
      -     * data.V3      1.2964482
      -     * data.V4     -0.3571741
      -     * data.V5     -0.7407946
      +    val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights =
      +           coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                           s0
      +       (Intercept)   .
      +       data.V2     -0.3534996
      +       data.V3      1.2964482
      +       data.V4     -0.3571741
      +       data.V5     -0.7407946
            */
           val interceptR = 0.0
      -    val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
      +    val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
      -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
      -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
      +    assert(model1.intercept ~== interceptR relTol 1E-3)
      +    assert(model1.weights ~= weightsR relTol 1E-2)
      +
       +    // Without regularization, standardization should not change the converged solution.
      +    assert(model2.intercept ~== interceptR relTol 1E-3)
      +    assert(model2.weights ~= weightsR relTol 1E-2)
         }
       
         test("binary logistic regression with intercept with L1 regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(true)
      -      .setElasticNetParam(1.0).setRegParam(0.12)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                      s0
      -     * (Intercept) -0.05627428
      -     * data.V2       .
      -     * data.V3       .
      -     * data.V4     -0.04325749
      -     * data.V5     -0.02481551
      +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept) -0.05627428
      +       data.V2       .
      +       data.V3       .
      +       data.V4     -0.04325749
      +       data.V5     -0.02481551
      +     */
      +    val interceptR1 = -0.05627428
      +    val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551)
      +
      +    assert(model1.intercept ~== interceptR1 relTol 1E-2)
      +    assert(model1.weights ~= weightsR1 absTol 2E-2)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
      +           standardize=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                           s0
      +       (Intercept)  0.3722152
      +       data.V2       .
      +       data.V3       .
      +       data.V4     -0.1665453
      +       data.V5       .
            */
      -    val interceptR = -0.05627428
      -    val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551)
      -
      -    assert(model.intercept ~== interceptR relTol 1E-2)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      -    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
      -    assert(model.weights(3) ~== weightsR(3) relTol 2E-2)
      +    val interceptR2 = 0.3722152
      +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0)
      +
      +    assert(model2.intercept ~== interceptR2 relTol 1E-2)
      +    assert(model2.weights ~= weightsR2 absTol 1E-3)
         }
       
         test("binary logistic regression without intercept with L1 regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(false)
      -      .setElasticNetParam(1.0).setRegParam(0.12)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
      -     *     intercept=FALSE))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                      s0
      -     * (Intercept)   .
      -     * data.V2       .
      -     * data.V3       .
      -     * data.V4     -0.05189203
      -     * data.V5     -0.03891782
      +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
      +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
      +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
      +           intercept=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)   .
      +       data.V2       .
      +       data.V3       .
      +       data.V4     -0.05189203
      +       data.V5     -0.03891782
            */
      -    val interceptR = 0.0
      -    val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782)
      +    val interceptR1 = 0.0
      +    val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782)
      +
      +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
      +    assert(model1.weights ~= weightsR1 absTol 1E-3)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
      +           intercept=FALSE, standardize=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)   .
      +       data.V2       .
      +       data.V3       .
      +       data.V4     -0.08420782
      +       data.V5       .
      +     */
      +    val interceptR2 = 0.0
      +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      -    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
      -    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
      +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
      +    assert(model2.weights ~= weightsR2 absTol 1E-3)
         }
       
         test("binary logistic regression with intercept with L2 regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(true)
      -      .setElasticNetParam(0.0).setRegParam(1.37)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                      s0
      -     * (Intercept)  0.15021751
      -     * data.V2     -0.07251837
      -     * data.V3      0.10724191
      -     * data.V4     -0.04865309
      -     * data.V5     -0.10062872
      +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)  0.15021751
      +       data.V2     -0.07251837
      +       data.V3      0.10724191
      +       data.V4     -0.04865309
      +       data.V5     -0.10062872
            */
      -    val interceptR = 0.15021751
      -    val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
      -
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
      -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
      +    val interceptR1 = 0.15021751
      +    val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
      +
      +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-3)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
      +           standardize=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)  0.48657516
      +       data.V2     -0.05155371
      +       data.V3      0.02301057
      +       data.V4     -0.11482896
      +       data.V5     -0.06266838
      +     */
      +    val interceptR2 = 0.48657516
      +    val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838)
      +
      +    assert(model2.intercept ~== interceptR2 relTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-3)
         }
       
         test("binary logistic regression without intercept with L2 regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(false)
      -      .setElasticNetParam(0.0).setRegParam(1.37)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
      -     *     intercept=FALSE))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                      s0
      -     * (Intercept)   .
      -     * data.V2     -0.06099165
      -     * data.V3      0.12857058
      -     * data.V4     -0.04708770
      -     * data.V5     -0.09799775
      +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
      +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
      +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
      +           intercept=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)   .
      +       data.V2     -0.06099165
      +       data.V3      0.12857058
      +       data.V4     -0.04708770
      +       data.V5     -0.09799775
            */
      -    val interceptR = 0.0
      -    val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
      +    val interceptR1 = 0.0
      +    val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
      +
      +    assert(model1.intercept ~== interceptR1 absTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-2)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
      +           intercept=FALSE, standardize=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                             s0
      +       (Intercept)   .
      +       data.V2     -0.005679651
      +       data.V3      0.048967094
      +       data.V4     -0.093714016
      +       data.V5     -0.053314311
      +     */
      +    val interceptR2 = 0.0
      +    val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
      -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
      -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
      +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-2)
         }
       
         test("binary logistic regression with intercept with ElasticNet regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(true)
      -      .setElasticNetParam(0.38).setRegParam(0.21)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                      s0
      -     * (Intercept)  0.57734851
      -     * data.V2     -0.05310287
      -     * data.V3       .
      -     * data.V4     -0.08849250
      -     * data.V5     -0.15458796
      +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)  0.57734851
      +       data.V2     -0.05310287
      +       data.V3       .
      +       data.V4     -0.08849250
      +       data.V5     -0.15458796
      +     */
      +    val interceptR1 = 0.57734851
      +    val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796)
      +
      +    assert(model1.intercept ~== interceptR1 relTol 6E-3)
      +    assert(model1.weights ~== weightsR1 absTol 5E-3)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
      +           standardize=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)  0.51555993
      +       data.V2       .
      +       data.V3       .
      +       data.V4     -0.18807395
      +       data.V5     -0.05350074
            */
      -    val interceptR = 0.57734851
      -    val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)
      -
      -    assert(model.intercept ~== interceptR relTol 6E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 5E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      -    assert(model.weights(2) ~== weightsR(2) relTol 5E-3)
      -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
      +    val interceptR2 = 0.51555993
      +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074)
      +
      +    assert(model2.intercept ~== interceptR2 relTol 6E-3)
      +    assert(model2.weights ~= weightsR2 absTol 1E-3)
         }
       
         test("binary logistic regression without intercept with ElasticNet regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(false)
      -      .setElasticNetParam(0.38).setRegParam(0.21)
      -    val model = trainer.fit(binaryDataset)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
      -     *     intercept=FALSE))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                      s0
      -     * (Intercept)   .
      -     * data.V2     -0.001005743
      -     * data.V3      0.072577857
      -     * data.V4     -0.081203769
      -     * data.V5     -0.142534158
      +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
      +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
      +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
      +           intercept=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)   .
      +       data.V2     -0.001005743
      +       data.V3      0.072577857
      +       data.V4     -0.081203769
      +       data.V5     -0.142534158
            */
      -    val interceptR = 0.0
      -    val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
      +    val interceptR1 = 0.0
      +    val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
      +
      +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
      +    assert(model1.weights ~= weightsR1 absTol 1E-2)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
      +           intercept=FALSE, standardize=FALSE))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept)   .
      +       data.V2       .
      +       data.V3      0.03345223
      +       data.V4     -0.11304532
      +       data.V5       .
      +     */
      +    val interceptR2 = 0.0
      +    val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) absTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) absTol 1E-2)
      -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
      -    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
      +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
      +    assert(model2.weights ~= weightsR2 absTol 1E-3)
         }
       
         test("binary logistic regression with intercept with strong L1 regularization") {
      -    val trainer = (new LogisticRegression).setFitIntercept(true)
      -      .setElasticNetParam(1.0).setRegParam(6.0)
      -    val model = trainer.fit(binaryDataset)
      +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true)
      +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
      +      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false)
      +
      +    val model1 = trainer1.fit(binaryDataset)
      +    val model2 = trainer2.fit(binaryDataset)
       
           val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label }
             .treeAggregate(new MultiClassSummarizer)(
      @@ -480,50 +727,142 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
                   classSummarizer1.merge(classSummarizer2)
               }).histogram
       
      -    /**
      -     * For binary logistic regression with strong L1 regularization, all the weights will be zeros.
      -     * As a result,
      -     * {{{
      -     * P(0) = 1 / (1 + \exp(b)), and
      -     * P(1) = \exp(b) / (1 + \exp(b))
      -     * }}}, hence
      -     * {{{
      -     * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
      -     * }}}
      +    /*
       +       For binary logistic regression with strong L1 regularization, all the weights will be zero.
      +       As a result,
      +       {{{
      +       P(0) = 1 / (1 + \exp(b)), and
      +       P(1) = \exp(b) / (1 + \exp(b))
      +       }}}, hence
      +       {{{
      +       b = \log{P(1) / P(0)} = \log{count_1 / count_0}
      +       }}}
            */
      -    val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
      -    val weightsTheory = Array(0.0, 0.0, 0.0, 0.0)
      -
      -    assert(model.intercept ~== interceptTheory relTol 1E-5)
      -    assert(model.weights(0) ~== weightsTheory(0) absTol 1E-6)
      -    assert(model.weights(1) ~== weightsTheory(1) absTol 1E-6)
      -    assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6)
      -    assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6)
      -
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * > library("glmnet")
      -     * > data <- read.csv("path", header=FALSE)
      -     * > label = factor(data$V1)
      -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
      -     * > weights
      -     * 5 x 1 sparse Matrix of class "dgCMatrix"
      -     *                      s0
      -     * (Intercept) -0.2480643
      -     * data.V2      0.0000000
      -     * data.V3       .
      -     * data.V4       .
      -     * data.V5       .
      +    val interceptTheory = math.log(histogram(1) / histogram(0))
      +    val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0)
      +
      +    assert(model1.intercept ~== interceptTheory relTol 1E-5)
      +    assert(model1.weights ~= weightsTheory absTol 1E-6)
      +
      +    assert(model2.intercept ~== interceptTheory relTol 1E-5)
      +    assert(model2.weights ~= weightsTheory absTol 1E-6)
      +
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE)
      +       label = factor(data$V1)
      +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
      +       weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
      +       weights
      +
      +       5 x 1 sparse Matrix of class "dgCMatrix"
      +                            s0
      +       (Intercept) -0.2480643
      +       data.V2      0.0000000
      +       data.V3       .
      +       data.V4       .
      +       data.V5       .
            */
           val interceptR = -0.248065
      -    val weightsR = Array(0.0, 0.0, 0.0, 0.0)
      +    val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
      +
      +    assert(model1.intercept ~== interceptR relTol 1E-5)
      +    assert(model1.weights ~== weightsR absTol 1E-6)
      +  }
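A minimal standalone sketch of the closed-form intercept the strong-L1 test above checks (illustrative code, not Spark API; the 400/600 class split in `main` is a made-up example):

```scala
// With all weights driven to zero by a strong L1 penalty, the fitted intercept reduces
// to the log odds of the (possibly weighted) class counts, b = log(count_1 / count_0),
// exactly as derived in the comment inside the test above.
object StrongL1InterceptSketch {
  def interceptTheory(histogram: Array[Double]): Double =
    math.log(histogram(1) / histogram(0))

  def main(args: Array[String]): Unit = {
    // Hypothetical histogram: 400 negatives, 600 positives => b = log(1.5).
    val b = interceptTheory(Array(400.0, 600.0))
    assert(math.abs(b - math.log(1.5)) < 1e-12)
  }
}
```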
      +
      +  test("evaluate on test set") {
       +    // Evaluating on the test set should match the summary of the transformed training data.
      +    val lr = new LogisticRegression()
      +      .setMaxIter(10)
      +      .setRegParam(1.0)
      +      .setThreshold(0.6)
      +    val model = lr.fit(dataset)
      +    val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
      +
      +    val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
      +    assert(summary.areaUnderROC === sameSummary.areaUnderROC)
      +    assert(summary.roc.collect() === sameSummary.roc.collect())
       +    assert(summary.pr.collect() === sameSummary.pr.collect())
      +    assert(
      +      summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
      +    assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
      +    assert(
      +      summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
      +  }
      +
      +  test("statistics on training data") {
       +    // Test that the objective history is monotonically non-increasing.
      +    val lr = new LogisticRegression()
      +      .setMaxIter(10)
      +      .setRegParam(1.0)
      +      .setThreshold(0.6)
      +    val model = lr.fit(dataset)
      +    assert(
      +      model.summary
      +        .objectiveHistory
      +        .sliding(2)
      +        .forall(x => x(0) >= x(1)))
      +
      +  }
      +
      +  test("binary logistic regression with weighted samples") {
      +    val (dataset, weightedDataset) = {
      +      val nPoints = 1000
      +      val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
      +      val xMean = Array(5.843, 3.057, 3.758, 1.199)
      +      val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
      +      val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
      +
       +      // Over-sample the positive class by duplicating each positive example.
      +      val data1 = testData.flatMap { case labeledPoint: LabeledPoint =>
      +        if (labeledPoint.label == 1.0) {
      +          Iterator(labeledPoint, labeledPoint)
      +        } else {
      +          Iterator(labeledPoint)
      +        }
      +      }
      +
      +      val rnd = new Random(8392)
      +      val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) =>
      +        if (rnd.nextGaussian() > 0.0) {
      +          if (label == 1.0) {
      +            Iterator(
      +              Instance(label, 1.2, features),
      +              Instance(label, 0.8, features),
      +              Instance(0.0, 0.0, features))
      +          } else {
      +            Iterator(
      +              Instance(label, 0.3, features),
      +              Instance(1.0, 0.0, features),
      +              Instance(label, 0.1, features),
      +              Instance(label, 0.6, features))
      +          }
      +        } else {
      +          if (label == 1.0) {
      +            Iterator(Instance(label, 2.0, features))
      +          } else {
      +            Iterator(Instance(label, 1.0, features))
      +          }
      +        }
      +      }
      +
      +      (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
      +        sqlContext.createDataFrame(sc.parallelize(data2, 4)))
      +    }
      +
      +    val trainer1a = (new LogisticRegression).setFitIntercept(true)
      +      .setRegParam(0.0).setStandardization(true)
      +    val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight")
      +      .setRegParam(0.0).setStandardization(true)
      +    val model1a0 = trainer1a.fit(dataset)
      +    val model1a1 = trainer1a.fit(weightedDataset)
      +    val model1b = trainer1b.fit(weightedDataset)
      +    assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
      +    assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
      +    assert(model1a0.weights ~== model1b.weights absTol 1E-3)
      +    assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
       
      -    assert(model.intercept ~== interceptR relTol 1E-5)
      -    assert(model.weights(0) ~== weightsR(0) absTol 1E-6)
      -    assert(model.weights(1) ~== weightsR(1) absTol 1E-6)
      -    assert(model.weights(2) ~== weightsR(2) absTol 1E-6)
      -    assert(model.weights(3) ~== weightsR(3) absTol 1E-6)
         }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
      new file mode 100644
      index 000000000000..ddc948f65df4
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
      @@ -0,0 +1,91 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.classification
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.classification.LogisticRegressionSuite._
      +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
      +import org.apache.spark.mllib.evaluation.MulticlassMetrics
      +import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
      +import org.apache.spark.sql.Row
      +
      +class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("XOR function learning as binary classification problem with two outputs.") {
      +    val dataFrame = sqlContext.createDataFrame(Seq(
      +        (Vectors.dense(0.0, 0.0), 0.0),
      +        (Vectors.dense(0.0, 1.0), 1.0),
      +        (Vectors.dense(1.0, 0.0), 1.0),
      +        (Vectors.dense(1.0, 1.0), 0.0))
      +    ).toDF("features", "label")
      +    val layers = Array[Int](2, 5, 2)
      +    val trainer = new MultilayerPerceptronClassifier()
      +      .setLayers(layers)
      +      .setBlockSize(1)
      +      .setSeed(11L)
      +      .setMaxIter(100)
      +    val model = trainer.fit(dataFrame)
      +    val result = model.transform(dataFrame)
      +    val predictionAndLabels = result.select("prediction", "label").collect()
      +    predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
      +      assert(p == l)
      +    }
      +  }
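For readers following the new suite: the `layers` array is [input size, hidden layer sizes..., number of classes]. A tiny helper sketch (illustrative only, not part of the Spark API) reproduces the shapes used in these two tests.

```scala
// Build a layers array as [numFeatures, hidden sizes..., numClasses], matching the
// Array(2, 5, 2) and Array(4, 5, 4, 3) configurations used in this suite.
object MlpLayersSketch {
  def layers(numFeatures: Int, hidden: Seq[Int], numClasses: Int): Array[Int] =
    ((numFeatures +: hidden) :+ numClasses).toArray

  def main(args: Array[String]): Unit = {
    assert(layers(2, Seq(5), 2).sameElements(Array(2, 5, 2)))       // XOR test above
    assert(layers(4, Seq(5, 4), 3).sameElements(Array(4, 5, 4, 3))) // 3-class test below
  }
}
```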
      +
      +  // TODO: implement a more rigorous test
      +  test("3 class classification with 2 hidden layers") {
      +    val nPoints = 1000
      +
       +    // The following weights are taken from OneVsRestSuite.scala;
       +    // they represent the 3-class iris dataset.
      +    val weights = Array(
      +      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
      +      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
      +
      +    val xMean = Array(5.843, 3.057, 3.758, 1.199)
      +    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
      +    val rdd = sc.parallelize(generateMultinomialLogisticInput(
      +      weights, xMean, xVariance, true, nPoints, 42), 2)
      +    val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features")
      +    val numClasses = 3
      +    val numIterations = 100
      +    val layers = Array[Int](4, 5, 4, numClasses)
      +    val trainer = new MultilayerPerceptronClassifier()
      +      .setLayers(layers)
      +      .setBlockSize(1)
      +      .setSeed(11L)
      +      .setMaxIter(numIterations)
      +    val model = trainer.fit(dataFrame)
      +    val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
      +      .map { case Row(p: Double, l: Double) => (p, l) }
      +    // train multinomial logistic regression
      +    val lr = new LogisticRegressionWithLBFGS()
      +      .setIntercept(true)
      +      .setNumClasses(numClasses)
      +    lr.optimizer.setRegParam(0.0)
      +      .setNumIterations(numIterations)
      +    val lrModel = lr.run(rdd)
      +    val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label))
       +    // MLP's predictions should not differ much from LR's.
      +    val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels)
      +    val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
      +    assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
      new file mode 100644
      index 000000000000..98bc9511163e
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
      @@ -0,0 +1,164 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.classification
      +
      +import breeze.linalg.{Vector => BV}
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
      +import org.apache.spark.mllib.linalg._
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
      +import org.apache.spark.mllib.classification.NaiveBayesSuite._
      +import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.Row
      +
      +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  def validatePrediction(predictionAndLabels: DataFrame): Unit = {
      +    val numOfErrorPredictions = predictionAndLabels.collect().count {
      +      case Row(prediction: Double, label: Double) =>
      +        prediction != label
      +    }
       +    // At least 80% of the predictions should be correct.
      +    assert(numOfErrorPredictions < predictionAndLabels.count() / 5)
      +  }
      +
      +  def validateModelFit(
      +      piData: Vector,
      +      thetaData: Matrix,
      +      model: NaiveBayesModel): Unit = {
      +    assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
      +      Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
      +    assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
      +  }
      +
      +  def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
      +    val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
      +    val classProbs = logClassProbs.toArray.map(math.exp)
      +    val classProbsSum = classProbs.sum
      +    Vectors.dense(classProbs.map(_ / classProbsSum))
      +  }
      +
      +  def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
      +    val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v)))
      +    val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v))
      +    val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
      +    val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze
      +    val classProbs = logClassProbs.toArray.map(math.exp)
      +    val classProbsSum = classProbs.sum
      +    Vectors.dense(classProbs.map(_ / classProbsSum))
      +  }
      +
      +  def validateProbabilities(
      +      featureAndProbabilities: DataFrame,
      +      model: NaiveBayesModel,
      +      modelType: String): Unit = {
      +    featureAndProbabilities.collect().foreach {
      +      case Row(features: Vector, probability: Vector) => {
      +        assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
      +        val expected = modelType match {
      +          case Multinomial =>
      +            expectedMultinomialProbabilities(model, features)
      +          case Bernoulli =>
      +            expectedBernoulliProbabilities(model, features)
      +          case _ =>
      +            throw new UnknownError(s"Invalid modelType: $modelType.")
      +        }
      +        assert(probability ~== expected relTol 1.0e-10)
      +      }
      +    }
      +  }
      +
      +  test("params") {
      +    ParamsSuite.checkParams(new NaiveBayes)
      +    val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
      +      theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)))
      +    ParamsSuite.checkParams(model)
      +  }
      +
      +  test("naive bayes: default params") {
      +    val nb = new NaiveBayes
      +    assert(nb.getLabelCol === "label")
      +    assert(nb.getFeaturesCol === "features")
      +    assert(nb.getPredictionCol === "prediction")
      +    assert(nb.getSmoothing === 1.0)
      +    assert(nb.getModelType === "multinomial")
      +  }
      +
      +  test("Naive Bayes Multinomial") {
      +    val nPoints = 1000
      +    val piArray = Array(0.5, 0.1, 0.4).map(math.log)
      +    val thetaArray = Array(
      +      Array(0.70, 0.10, 0.10, 0.10), // label 0
      +      Array(0.10, 0.70, 0.10, 0.10), // label 1
      +      Array(0.10, 0.10, 0.70, 0.10)  // label 2
      +    ).map(_.map(math.log))
      +    val pi = Vectors.dense(piArray)
      +    val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
      +
      +    val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
      +      piArray, thetaArray, nPoints, 42, "multinomial"))
      +    val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
      +    val model = nb.fit(testDataset)
      +
      +    validateModelFit(pi, theta, model)
      +    assert(model.hasParent)
      +
      +    val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
      +      piArray, thetaArray, nPoints, 17, "multinomial"))
      +
      +    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
      +    validatePrediction(predictionAndLabels)
      +
      +    val featureAndProbabilities = model.transform(validationDataset)
      +      .select("features", "probability")
      +    validateProbabilities(featureAndProbabilities, model, "multinomial")
      +  }
      +
      +  test("Naive Bayes Bernoulli") {
      +    val nPoints = 10000
      +    val piArray = Array(0.5, 0.3, 0.2).map(math.log)
      +    val thetaArray = Array(
      +      Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
      +      Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
      +      Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30)  // label 2
      +    ).map(_.map(math.log))
      +    val pi = Vectors.dense(piArray)
      +    val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
      +
      +    val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
      +      piArray, thetaArray, nPoints, 45, "bernoulli"))
      +    val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
      +    val model = nb.fit(testDataset)
      +
      +    validateModelFit(pi, theta, model)
      +    assert(model.hasParent)
      +
      +    val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
      +      piArray, thetaArray, nPoints, 20, "bernoulli"))
      +
      +    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
      +    validatePrediction(predictionAndLabels)
      +
      +    val featureAndProbabilities = model.transform(validationDataset)
      +      .select("features", "probability")
      +    validateProbabilities(featureAndProbabilities, model, "bernoulli")
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
      index 75cf5bd4ead4..977f0e0b70c1 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
      @@ -19,8 +19,9 @@ package org.apache.spark.ml.classification
       
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.attribute.NominalAttribute
      +import org.apache.spark.ml.feature.StringIndexer
       import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
      -import org.apache.spark.ml.util.MetadataUtils
      +import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils}
       import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
       import org.apache.spark.mllib.classification.LogisticRegressionSuite._
       import org.apache.spark.mllib.evaluation.MulticlassMetrics
      @@ -69,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(ova.getLabelCol === "label")
           assert(ova.getPredictionCol === "prediction")
           val ovaModel = ova.fit(dataset)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(ovaModel)
      +
           assert(ovaModel.models.size === numClasses)
           val transformedDataset = ovaModel.transform(dataset)
       
      @@ -104,6 +109,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
           ova.fit(datasetWithLabelMetadata)
         }
       
      +  test("SPARK-8092: ensure label features and prediction cols are configurable") {
      +    val labelIndexer = new StringIndexer()
      +      .setInputCol("label")
      +      .setOutputCol("indexed")
      +
      +    val indexedDataset = labelIndexer
      +      .fit(dataset)
      +      .transform(dataset)
      +      .drop("label")
      +      .withColumnRenamed("features", "f")
      +
      +    val ova = new OneVsRest()
      +    ova.setClassifier(new LogisticRegression())
      +      .setLabelCol(labelIndexer.getOutputCol)
      +      .setFeaturesCol("f")
      +      .setPredictionCol("p")
      +
      +    val ovaModel = ova.fit(indexedDataset)
      +    val transformedDataset = ovaModel.transform(indexedDataset)
      +    val outputFields = transformedDataset.schema.fieldNames.toSet
      +    assert(outputFields.contains("p"))
      +  }
      +
         test("SPARK-8049: OneVsRest shouldn't output temp columns") {
           val logReg = new LogisticRegression()
             .setMaxIter(1)
      @@ -127,7 +155,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
           require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
             "copy should handle extra classifier params")
       
      -    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
      +    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.9, 0.1)))
           ovrModel.models.foreach { case m: LogisticRegressionModel =>
             require(m.getThreshold === 0.1, "copy should handle extra model params")
           }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
      new file mode 100644
      index 000000000000..8f50cb924e64
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
      @@ -0,0 +1,57 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.classification
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +
      +final class TestProbabilisticClassificationModel(
      +    override val uid: String,
      +    override val numClasses: Int)
      +  extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
      +
      +  override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra)
      +
      +  override protected def predictRaw(input: Vector): Vector = {
      +    input
      +  }
      +
      +  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
      +    rawPrediction
      +  }
      +
      +  def friendlyPredict(input: Vector): Double = {
      +    predict(input)
      +  }
      +}
      +
      +
      +class ProbabilisticClassifierSuite extends SparkFunSuite {
      +
      +  test("test thresholding") {
      +    val thresholds = Array(0.5, 0.2)
      +    val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
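+    // With thresholds set, the predicted class is the argmax of probability(i) / threshold(i),
+    // so (1.0, 1.0) maps to class 1 and (1.0, 0.2) maps to class 0.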
      +    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
      +    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
      +  }
      +
      +  test("test thresholding not required") {
      +    val testModel = new TestProbabilisticClassificationModel("myuid", 2)
      +    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
      index 1b6b69c7dc71..b4403ec30049 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
      @@ -21,13 +21,15 @@ import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.impl.TreeTests
       import org.apache.spark.ml.param.ParamsSuite
       import org.apache.spark.ml.tree.LeafNode
      -import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
       import org.apache.spark.rdd.RDD
      -import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.{DataFrame, Row}
       
       /**
        * Test suite for [[RandomForestClassifier]].
      @@ -66,7 +68,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
         test("params") {
           ParamsSuite.checkParams(new RandomForestClassifier)
           val model = new RandomForestClassificationModel("rfc",
      -      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
      +      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
           ParamsSuite.checkParams(model)
         }
       
      @@ -121,6 +123,65 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
           compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
         }
       
      +  test("predictRaw and predictProbability") {
      +    val rdd = orderedLabeledPoints5_20
      +    val rf = new RandomForestClassifier()
      +      .setImpurity("Gini")
      +      .setMaxDepth(3)
      +      .setNumTrees(3)
      +      .setSeed(123)
      +    val categoricalFeatures = Map.empty[Int, Int]
      +    val numClasses = 2
      +
      +    val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
      +    val model = rf.fit(df)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +
      +    val predictions = model.transform(df)
      +      .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
      +      .collect()
      +
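+    // The transformed output must satisfy prediction == argmax(rawPrediction) and
+    // probability == rawPrediction rescaled to sum to 1.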
      +    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
      +      assert(pred === rawPred.argmax,
      +        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
      +      val sum = rawPred.toArray.sum
      +      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
      +        "probability prediction mismatch")
      +      assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
      +    }
      +  }
      +
      +  /////////////////////////////////////////////////////////////////////////////
      +  // Tests of feature importance
      +  /////////////////////////////////////////////////////////////////////////////
      +  test("Feature importance with toy data") {
      +    val numClasses = 2
      +    val rf = new RandomForestClassifier()
      +      .setImpurity("Gini")
      +      .setMaxDepth(3)
      +      .setNumTrees(3)
      +      .setFeatureSubsetStrategy("all")
      +      .setSubsamplingRate(1.0)
      +      .setSeed(123)
      +
      +    // In this data, feature 1 is very important.
      +    val data: RDD[LabeledPoint] = sc.parallelize(Seq(
      +      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
      +      new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
      +      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
      +      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
      +      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
      +    ))
      +    val categoricalFeatures = Map.empty[Int, Int]
      +    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
      +
      +    val importances = rf.fit(df).featureImportances
      +    val mostImportantFeature = importances.argmax
      +    assert(mostImportantFeature === 1)
      +  }
      +
         /////////////////////////////////////////////////////////////////////////////
         // Tests of model save/load
         /////////////////////////////////////////////////////////////////////////////
      @@ -167,9 +228,11 @@ private object RandomForestClassifierSuite {
           val newModel = rf.fit(newData)
           // Use parent from newTree since this is not checked anyways.
           val oldModelAsNew = RandomForestClassificationModel.fromOld(
      -      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
      +      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures,
      +      numClasses)
           TreeTests.checkEqual(oldModelAsNew, newModel)
           assert(newModel.hasParent)
           assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
      +    assert(newModel.numClasses == numClasses)
         }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
      new file mode 100644
      index 000000000000..688b0e31f91d
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
      @@ -0,0 +1,108 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.clustering
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.{DataFrame, SQLContext}
      +
      +private[clustering] case class TestRow(features: Vector)
      +
      +object KMeansSuite {
      +  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
      +    val sc = sql.sparkContext
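+    // Each row is a dim-length vector filled with (i % k), so the data contains exactly k distinct points.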
      +    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
      +      .map(v => new TestRow(v))
      +    sql.createDataFrame(rdd)
      +  }
      +}
      +
      +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  final val k = 5
      +  @transient var dataset: DataFrame = _
      +
      +  override def beforeAll(): Unit = {
      +    super.beforeAll()
      +
      +    dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
      +  }
      +
      +  test("default parameters") {
      +    val kmeans = new KMeans()
      +
      +    assert(kmeans.getK === 2)
      +    assert(kmeans.getFeaturesCol === "features")
      +    assert(kmeans.getPredictionCol === "prediction")
      +    assert(kmeans.getMaxIter === 20)
      +    assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
      +    assert(kmeans.getInitSteps === 5)
      +    assert(kmeans.getTol === 1e-4)
      +  }
      +
      +  test("set parameters") {
      +    val kmeans = new KMeans()
      +      .setK(9)
      +      .setFeaturesCol("test_feature")
      +      .setPredictionCol("test_prediction")
      +      .setMaxIter(33)
      +      .setInitMode(MLlibKMeans.RANDOM)
      +      .setInitSteps(3)
      +      .setSeed(123)
      +      .setTol(1e-3)
      +
      +    assert(kmeans.getK === 9)
      +    assert(kmeans.getFeaturesCol === "test_feature")
      +    assert(kmeans.getPredictionCol === "test_prediction")
      +    assert(kmeans.getMaxIter === 33)
      +    assert(kmeans.getInitMode === MLlibKMeans.RANDOM)
      +    assert(kmeans.getInitSteps === 3)
      +    assert(kmeans.getSeed === 123)
      +    assert(kmeans.getTol === 1e-3)
      +  }
      +
      +  test("parameters validation") {
      +    intercept[IllegalArgumentException] {
      +      new KMeans().setK(1)
      +    }
      +    intercept[IllegalArgumentException] {
      +      new KMeans().setInitMode("no_such_a_mode")
      +    }
      +    intercept[IllegalArgumentException] {
      +      new KMeans().setInitSteps(0)
      +    }
      +  }
      +
      +  test("fit & transform") {
      +    val predictionColName = "kmeans_prediction"
      +    val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
      +    val model = kmeans.fit(dataset)
      +    assert(model.clusterCenters.length === k)
      +
      +    val transformed = model.transform(dataset)
      +    val expectedColumns = Array("features", predictionColName)
      +    expectedColumns.foreach { column =>
      +      assert(transformed.columns.contains(column))
      +    }
      +    val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
      +    assert(clusters.size === k)
      +    assert(clusters === Set(0, 1, 2, 3, 4))
      +  }
      +}
      diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
      similarity index 65%
      rename from core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
      rename to mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
      index 8df4f3b554c4..6d8412b0b370 100644
      --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
      @@ -15,17 +15,14 @@
        * limitations under the License.
        */
       
      -package org.apache.spark.scheduler.cluster.mesos
      +package org.apache.spark.ml.evaluation
       
      -import org.apache.spark.SparkContext
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.param.ParamsSuite
       
      -private[spark] object MemoryUtils {
      -  // These defaults copied from YARN
      -  val OVERHEAD_FRACTION = 0.10
      -  val OVERHEAD_MINIMUM = 384
      +class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
       
      -  def calculateTotalMemory(sc: SparkContext): Int = {
      -    sc.conf.getInt("spark.mesos.executor.memoryOverhead",
      -      math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory
      +  test("params") {
      +    ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
         }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
      index 5b203784559e..aa722da32393 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
      @@ -63,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
       
           // default = rmse
           val evaluator = new RegressionEvaluator()
      -    assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001)
      +    assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
       
           // r2 score
           evaluator.setMetricName("r2")
      @@ -71,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
       
           // mae
           evaluator.setMetricName("mae")
      -    assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001)
      +    assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
         }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
      index ec85e0d151e0..0eba34fda622 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
      @@ -21,6 +21,7 @@ import scala.util.Random
       
       import org.apache.spark.{SparkException, SparkFunSuite}
       import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.mllib.util.TestingUtils._
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
      new file mode 100644
      index 000000000000..e192fa4850af
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
      @@ -0,0 +1,167 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
      +import org.apache.spark.sql.Row
      +
      +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("params") {
      +    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
      +  }
      +
      +  private def split(s: String): Seq[String] = s.split("\\s+")
      +
      +  test("CountVectorizerModel common cases") {
      +    val df = sqlContext.createDataFrame(Seq(
      +      (0, split("a b c d"),
      +        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
      +      (1, split("a b b c d  a"),
      +        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
      +      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
      +      (3, split(""), Vectors.sparse(4, Seq())), // empty string
      +      (4, split("a notInDict d"),
      +        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))  // with words not in vocabulary
      +    )).toDF("id", "words", "expected")
      +    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
      +      .setInputCol("words")
      +      .setOutputCol("features")
      +    cv.transform(df).select("features", "expected").collect().foreach {
      +      case Row(features: Vector, expected: Vector) =>
      +        assert(features ~== expected absTol 1e-14)
      +    }
      +  }
      +
      +  test("CountVectorizer common cases") {
      +    val df = sqlContext.createDataFrame(Seq(
      +      (0, split("a b c d e"),
      +        Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
      +      (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
      +      (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
      +      (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
      +    ).toDF("id", "words", "expected")
      +    val cv = new CountVectorizer()
      +      .setInputCol("words")
      +      .setOutputCol("features")
      +      .fit(df)
      +    assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
      +
      +    cv.transform(df).select("features", "expected").collect().foreach {
      +      case Row(features: Vector, expected: Vector) =>
      +        assert(features ~== expected absTol 1e-14)
      +    }
      +  }
      +
      +  test("CountVectorizer vocabSize and minDF") {
      +    val df = sqlContext.createDataFrame(Seq(
      +      (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
      +      (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
      +      (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
      +      (3, split("a"), Vectors.sparse(3, Seq((0, 1.0)))))
      +    ).toDF("id", "words", "expected")
      +    val cvModel = new CountVectorizer()
      +      .setInputCol("words")
      +      .setOutputCol("features")
      +      .setVocabSize(3)  // limit vocab size to 3
      +      .fit(df)
      +    assert(cvModel.vocabulary === Array("a", "b", "c"))
      +
      +    // minDF: ignore terms with count less than 3
      +    val cvModel2 = new CountVectorizer()
      +      .setInputCol("words")
      +      .setOutputCol("features")
      +      .setMinDF(3)
      +      .fit(df)
      +    assert(cvModel2.vocabulary === Array("a", "b"))
      +
      +    cvModel2.transform(df).select("features", "expected").collect().foreach {
      +      case Row(features: Vector, expected: Vector) =>
      +        assert(features ~== expected absTol 1e-14)
      +    }
      +
      +    // minDF: ignore terms with freq < 0.75
      +    val cvModel3 = new CountVectorizer()
      +      .setInputCol("words")
      +      .setOutputCol("features")
      +      .setMinDF(3.0 / df.count())
      +      .fit(df)
      +    assert(cvModel3.vocabulary === Array("a", "b"))
      +
      +    cvModel3.transform(df).select("features", "expected").collect().foreach {
      +      case Row(features: Vector, expected: Vector) =>
      +        assert(features ~== expected absTol 1e-14)
      +    }
      +  }
      +
      +  test("CountVectorizer throws exception when vocab is empty") {
      +    intercept[IllegalArgumentException] {
      +      val df = sqlContext.createDataFrame(Seq(
      +        (0, split("a a b b c c")),
      +        (1, split("aa bb cc")))
      +      ).toDF("id", "words")
      +      val cvModel = new CountVectorizer()
      +        .setInputCol("words")
      +        .setOutputCol("features")
      +        .setVocabSize(3) // limit vocab size to 3
      +        .setMinDF(3)
      +        .fit(df)
      +    }
      +  }
      +
      +  test("CountVectorizerModel with minTF count") {
      +    val df = sqlContext.createDataFrame(Seq(
      +      (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
      +      (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
      +      (2, split("a"), Vectors.sparse(4, Seq())),
      +      (3, split("e e e e e"), Vectors.sparse(4, Seq())))
      +    ).toDF("id", "words", "expected")
      +
      +    // minTF: count
      +    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
      +      .setInputCol("words")
      +      .setOutputCol("features")
      +      .setMinTF(3)
      +    cv.transform(df).select("features", "expected").collect().foreach {
      +      case Row(features: Vector, expected: Vector) =>
      +        assert(features ~== expected absTol 1e-14)
      +    }
      +  }
      +
      +  test("CountVectorizerModel with minTF freq") {
      +    val df = sqlContext.createDataFrame(Seq(
      +      (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
      +      (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
      +      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
      +      (3, split("e e e e e"), Vectors.sparse(4, Seq())))
      +    ).toDF("id", "words", "expected")
      +
+    // minTF: fraction of each document's token count
      +    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
      +      .setInputCol("words")
      +      .setOutputCol("features")
      +      .setMinTF(0.3)
      +    cv.transform(df).select("features", "expected").collect().foreach {
      +      case Row(features: Vector, expected: Vector) =>
      +        assert(features ~== expected absTol 1e-14)
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
      new file mode 100644
      index 000000000000..37ed2367c33f
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
      @@ -0,0 +1,73 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import scala.beans.BeanInfo
      +
      +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.{DataFrame, Row}
      +
      +@BeanInfo
      +case class DCTTestData(vec: Vector, wantedVec: Vector)
      +
      +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("forward transform of discrete cosine matches jTransforms result") {
      +    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
      +    val inverse = false
      +
      +    testDCT(data, inverse)
      +  }
      +
      +  test("inverse transform of discrete cosine matches jTransforms result") {
      +    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
      +    val inverse = true
      +
      +    testDCT(data, inverse)
      +  }
      +
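+  // Compares the DCT transformer's output against jTransforms' scaled DCT as the reference result.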
      +  private def testDCT(data: Vector, inverse: Boolean): Unit = {
      +    val expectedResultBuffer = data.toArray.clone()
      +    if (inverse) {
      +      (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true)
      +    } else {
      +      (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true)
      +    }
      +    val expectedResult = Vectors.dense(expectedResultBuffer)
      +
      +    val dataset = sqlContext.createDataFrame(Seq(
      +      DCTTestData(data, expectedResult)
      +    ))
      +
      +    val transformer = new DCT()
      +      .setInputCol("vec")
      +      .setOutputCol("resultVec")
      +      .setInverse(inverse)
      +
      +    transformer.transform(dataset)
      +      .select("resultVec", "wantedVec")
      +      .collect()
+      .foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
+        assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
+      }
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
      new file mode 100644
      index 000000000000..c04dda41eea3
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
      @@ -0,0 +1,72 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.{Row, SQLContext}
      +
      +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("MinMaxScaler fit basic case") {
      +    val sqlContext = new SQLContext(sc)
      +
      +    val data = Array(
      +      Vectors.dense(1, 0, Long.MinValue),
      +      Vectors.dense(2, 0, 0),
      +      Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)),
      +      Vectors.sparse(3, Array(0), Array(1.5)))
      +
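+    // Expected values follow x' = (x - Emin) / (Emax - Emin) * (max - min) + min with
+    // [min, max] = [-5, 5]; constant columns map to the midpoint of the target range.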
      +    val expected: Array[Vector] = Array(
      +      Vectors.dense(-5, 0, -5),
      +      Vectors.dense(0, 0, 0),
      +      Vectors.sparse(3, Array(0, 2), Array(5, 5)),
      +      Vectors.sparse(3, Array(0), Array(-2.5)))
      +
      +    val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
      +    val scaler = new MinMaxScaler()
      +      .setInputCol("features")
      +      .setOutputCol("scaled")
      +      .setMin(-5)
      +      .setMax(5)
      +
      +    val model = scaler.fit(df)
      +    model.transform(df).select("expected", "scaled").collect()
      +      .foreach { case Row(vector1: Vector, vector2: Vector) =>
+        assert(vector1.equals(vector2), "Transformed vector differs from the expected vector.")
+      }
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +  }
      +
      +  test("MinMaxScaler arguments max must be larger than min") {
      +    withClue("arguments max must be larger than min") {
      +      intercept[IllegalArgumentException] {
      +        val scaler = new MinMaxScaler().setMin(10).setMax(0)
      +        scaler.validateParams()
      +      }
      +      intercept[IllegalArgumentException] {
      +        val scaler = new MinMaxScaler().setMin(0).setMax(0)
      +        scaler.validateParams()
      +      }
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
      new file mode 100644
      index 000000000000..ab97e3dbc6ee
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
      @@ -0,0 +1,94 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import scala.beans.BeanInfo
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.{DataFrame, Row}
      +
      +@BeanInfo
      +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
      +
      +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
      +  import org.apache.spark.ml.feature.NGramSuite._
      +
      +  test("default behavior yields bigram features") {
      +    val nGram = new NGram()
      +      .setInputCol("inputTokens")
      +      .setOutputCol("nGrams")
      +    val dataset = sqlContext.createDataFrame(Seq(
      +      NGramTestData(
      +        Array("Test", "for", "ngram", "."),
      +        Array("Test for", "for ngram", "ngram .")
      +    )))
      +    testNGram(nGram, dataset)
      +  }
      +
      +  test("NGramLength=4 yields length 4 n-grams") {
      +    val nGram = new NGram()
      +      .setInputCol("inputTokens")
      +      .setOutputCol("nGrams")
      +      .setN(4)
      +    val dataset = sqlContext.createDataFrame(Seq(
      +      NGramTestData(
      +        Array("a", "b", "c", "d", "e"),
      +        Array("a b c d", "b c d e")
      +      )))
      +    testNGram(nGram, dataset)
      +  }
      +
      +  test("empty input yields empty output") {
      +    val nGram = new NGram()
      +      .setInputCol("inputTokens")
      +      .setOutputCol("nGrams")
      +      .setN(4)
      +    val dataset = sqlContext.createDataFrame(Seq(
      +      NGramTestData(
      +        Array(),
      +        Array()
      +      )))
      +    testNGram(nGram, dataset)
      +  }
      +
      +  test("input array < n yields empty output") {
      +    val nGram = new NGram()
      +      .setInputCol("inputTokens")
      +      .setOutputCol("nGrams")
      +      .setN(6)
      +    val dataset = sqlContext.createDataFrame(Seq(
      +      NGramTestData(
      +        Array("a", "b", "c", "d", "e"),
      +        Array()
      +      )))
      +    testNGram(nGram, dataset)
      +  }
      +}
      +
      +object NGramSuite extends SparkFunSuite {
      +
      +  def testNGram(t: NGram, dataset: DataFrame): Unit = {
      +    t.transform(dataset)
      +      .select("nGrams", "wantedNGrams")
      +      .collect()
      +      .foreach { case Row(actualNGrams, wantedNGrams) =>
      +        assert(actualNGrams === wantedNGrams)
      +      }
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
      index 65846a846b7b..321eeb843941 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
      @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
           val output = encoder.transform(df)
           val group = AttributeGroup.fromStructField(output.schema("encoded"))
           assert(group.size === 2)
      -    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
      -    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
      +    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
      +    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
         }
       
         test("input column without ML attribute") {
      @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
           val output = encoder.transform(df)
           val group = AttributeGroup.fromStructField(output.schema("encoded"))
           assert(group.size === 2)
      -    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
      -    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
      +    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
      +    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
         }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
      new file mode 100644
      index 000000000000..30c500f87a76
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
      @@ -0,0 +1,68 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.distributed.RowMatrix
      +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
      +import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
      +import org.apache.spark.sql.Row
      +
      +class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("params") {
      +    ParamsSuite.checkParams(new PCA)
      +    val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
      +    val model = new PCAModel("pca", new OldPCAModel(2, mat))
      +    ParamsSuite.checkParams(model)
      +  }
      +
      +  test("pca") {
      +    val data = Array(
      +      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
      +      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
      +      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
      +    )
      +
      +    val dataRDD = sc.parallelize(data, 2)
      +
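+    // Compute the expected projection with mllib's RowMatrix as the reference implementation.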
      +    val mat = new RowMatrix(dataRDD)
      +    val pc = mat.computePrincipalComponents(3)
      +    val expected = mat.multiply(pc).rows
      +
      +    val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
      +
      +    val pca = new PCA()
      +      .setInputCol("features")
      +      .setOutputCol("pca_features")
      +      .setK(3)
      +      .fit(df)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(pca)
      +
      +    pca.transform(df).select("pca_features", "expected").collect().foreach {
      +      case Row(x: Vector, y: Vector) =>
+        assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
      new file mode 100644
      index 000000000000..436e66bab09b
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
      @@ -0,0 +1,82 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.sql.types._
      +
      +class RFormulaParserSuite extends SparkFunSuite {
      +  private def checkParse(
      +      formula: String,
      +      label: String,
      +      terms: Seq[String],
      +      schema: StructType = null) {
      +    val resolved = RFormulaParser.parse(formula).resolve(schema)
      +    assert(resolved.label == label)
      +    assert(resolved.terms == terms)
      +  }
      +
      +  test("parse simple formulas") {
      +    checkParse("y ~ x", "y", Seq("x"))
      +    checkParse("y ~ x + x", "y", Seq("x"))
      +    checkParse("y ~   ._foo  ", "y", Seq("._foo"))
      +    checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
      +  }
      +
      +  test("parse dot") {
      +    val schema = (new StructType)
      +      .add("a", "int", true)
      +      .add("b", "long", false)
      +      .add("c", "string", true)
      +    checkParse("a ~ .", "a", Seq("b", "c"), schema)
      +  }
      +
      +  test("parse deletion") {
      +    val schema = (new StructType)
      +      .add("a", "int", true)
      +      .add("b", "long", false)
      +      .add("c", "string", true)
      +    checkParse("a ~ c - b", "a", Seq("c"), schema)
      +  }
      +
      +  test("parse additions and deletions in order") {
      +    val schema = (new StructType)
      +      .add("a", "int", true)
      +      .add("b", "long", false)
      +      .add("c", "string", true)
      +    checkParse("a ~ . - b + . - c", "a", Seq("b"), schema)
      +  }
      +
      +  test("dot ignores complex column types") {
      +    val schema = (new StructType)
      +      .add("a", "int", true)
      +      .add("b", "tinyint", false)
      +      .add("c", "map", true)
      +    checkParse("a ~ .", "a", Seq("b"), schema)
      +  }
      +
      +  test("parse intercept") {
      +    assert(RFormulaParser.parse("a ~ b").hasIntercept)
      +    assert(RFormulaParser.parse("a ~ b + 1").hasIntercept)
      +    assert(RFormulaParser.parse("a ~ b - 0").hasIntercept)
      +    assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept)
      +    assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept)
      +    assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept)
      +    assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept)
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
      new file mode 100644
      index 000000000000..6aed3243afce
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
      @@ -0,0 +1,126 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.attribute._
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +
      +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
      +  test("params") {
      +    ParamsSuite.checkParams(new RFormula())
      +  }
      +
      +  test("transform numeric data") {
      +    val formula = new RFormula().setFormula("id ~ v1 + v2")
      +    val original = sqlContext.createDataFrame(
      +      Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
      +    val model = formula.fit(original)
      +    val result = model.transform(original)
      +    val resultSchema = model.transformSchema(original.schema)
      +    val expected = sqlContext.createDataFrame(
      +      Seq(
      +        (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
      +        (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
      +      ).toDF("id", "v1", "v2", "features", "label")
      +    // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
      +    assert(result.schema.toString == resultSchema.toString)
      +    assert(resultSchema == expected.schema)
      +    assert(result.collect() === expected.collect())
      +  }
      +
      +  test("features column already exists") {
      +    val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
      +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
      +    intercept[IllegalArgumentException] {
      +      formula.fit(original)
      +    }
      +    intercept[IllegalArgumentException] {
      +      formula.fit(original)
      +    }
      +  }
      +
      +  test("label column already exists") {
      +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
      +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
      +    val model = formula.fit(original)
      +    val resultSchema = model.transformSchema(original.schema)
      +    assert(resultSchema.length == 3)
      +    assert(resultSchema.toString == model.transform(original).schema.toString)
      +  }
      +
      +  test("label column already exists but is not double type") {
      +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
      +    val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
      +    val model = formula.fit(original)
      +    intercept[IllegalArgumentException] {
      +      model.transformSchema(original.schema)
      +    }
      +    intercept[IllegalArgumentException] {
      +      model.transform(original)
      +    }
      +  }
      +
      +  test("allow missing label column for test datasets") {
      +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
      +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
      +    val model = formula.fit(original)
      +    val resultSchema = model.transformSchema(original.schema)
      +    assert(resultSchema.length == 3)
      +    assert(!resultSchema.exists(_.name == "label"))
      +    assert(resultSchema.toString == model.transform(original).schema.toString)
      +  }
      +
      +  test("encodes string terms") {
      +    val formula = new RFormula().setFormula("id ~ a + b")
      +    val original = sqlContext.createDataFrame(
      +      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
      +    ).toDF("id", "a", "b")
      +    val model = formula.fit(original)
      +    val result = model.transform(original)
      +    val resultSchema = model.transformSchema(original.schema)
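+    // The string column "a" is dummy-coded: one indicator column per category except the
+    // reference level ("baz"), which maps to all zeros.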
      +    val expected = sqlContext.createDataFrame(
      +      Seq(
      +        (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
      +        (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
      +        (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
      +        (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
      +      ).toDF("id", "a", "b", "features", "label")
      +    assert(result.schema.toString == resultSchema.toString)
      +    assert(result.collect() === expected.collect())
      +  }
      +
      +  test("attribute generation") {
      +    val formula = new RFormula().setFormula("id ~ a + b")
      +    val original = sqlContext.createDataFrame(
      +      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
      +    ).toDF("id", "a", "b")
      +    val model = formula.fit(original)
      +    val result = model.transform(original)
      +    val attrs = AttributeGroup.fromStructField(result.schema("features"))
      +    val expectedAttrs = new AttributeGroup(
      +      "features",
      +      Array(
      +        new BinaryAttribute(Some("a__bar"), Some(1)),
      +        new BinaryAttribute(Some("a__foo"), Some(2)),
      +        new NumericAttribute(Some("b"), Some(3))))
      +    assert(attrs === expectedAttrs)
      +  }
      +}
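
The RFormulaSuite above covers fitting, schema checks, string-term encoding, and attribute generation. For orientation, here is a minimal usage sketch of the transformer under test, assuming an existing `sqlContext` such as the one provided by `MLlibTestSparkContext`:

```
import org.apache.spark.ml.feature.RFormula

// Assumes `sqlContext` is in scope, as in the suites above.
val df = sqlContext.createDataFrame(
  Seq((1, "foo", 4.0), (2, "bar", 5.0))
).toDF("id", "a", "b")

// "id ~ a + b" one-hot encodes the string column `a`, appends the numeric
// column `b` into a `features` vector, and copies `id` into a double `label`.
val output = new RFormula().setFormula("id ~ a + b").fit(df).transform(df)
output.select("features", "label").show()
```
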
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
      new file mode 100644
      index 000000000000..d19052881ae4
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
      @@ -0,0 +1,44 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +
      +class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("params") {
      +    ParamsSuite.checkParams(new SQLTransformer())
      +  }
      +
      +  test("transform numeric data") {
      +    val original = sqlContext.createDataFrame(
      +      Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
      +    val sqlTrans = new SQLTransformer().setStatement(
      +      "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
      +    val result = sqlTrans.transform(original)
      +    val resultSchema = sqlTrans.transformSchema(original.schema)
      +    val expected = sqlContext.createDataFrame(
      +      Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
      +      .toDF("id", "v1", "v2", "v3", "v4")
      +    assert(result.schema.toString == resultSchema.toString)
      +    assert(resultSchema == expected.schema)
      +    assert(result.collect().toSeq == expected.collect().toSeq)
      +  }
      +}
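
The SQLTransformer exercised above rewrites a SQL statement against the input DataFrame, with `__THIS__` standing in for the input table. A minimal sketch, assuming an existing `sqlContext`:

```
import org.apache.spark.ml.feature.SQLTransformer

// Assumes `sqlContext` is in scope.
val df = sqlContext.createDataFrame(
  Seq((0, 1.0, 3.0), (2, 2.0, 5.0))
).toDF("id", "v1", "v2")

// `__THIS__` is replaced by the input DataFrame when the statement is run.
val sqlTrans = new SQLTransformer().setStatement(
  "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
sqlTrans.transform(df).show()
```
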
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
      new file mode 100644
      index 000000000000..e0d433f566c2
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
      @@ -0,0 +1,80 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.{DataFrame, Row}
      +
      +object StopWordsRemoverSuite extends SparkFunSuite {
      +  def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
      +    t.transform(dataset)
      +      .select("filtered", "expected")
      +      .collect()
      +      .foreach { case Row(tokens, wantedTokens) =>
      +        assert(tokens === wantedTokens)
      +    }
      +  }
      +}
      +
      +class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
      +  import StopWordsRemoverSuite._
      +
      +  test("StopWordsRemover default") {
      +    val remover = new StopWordsRemover()
      +      .setInputCol("raw")
      +      .setOutputCol("filtered")
      +    val dataSet = sqlContext.createDataFrame(Seq(
      +      (Seq("test", "test"), Seq("test", "test")),
      +      (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
      +      (Seq("a", "the", "an"), Seq()),
      +      (Seq("A", "The", "AN"), Seq()),
      +      (Seq(null), Seq(null)),
      +      (Seq(), Seq())
      +    )).toDF("raw", "expected")
      +
      +    testStopWordsRemover(remover, dataSet)
      +  }
      +
      +  test("StopWordsRemover case sensitive") {
      +    val remover = new StopWordsRemover()
      +      .setInputCol("raw")
      +      .setOutputCol("filtered")
      +      .setCaseSensitive(true)
      +    val dataSet = sqlContext.createDataFrame(Seq(
      +      (Seq("A"), Seq("A")),
      +      (Seq("The", "the"), Seq("The"))
      +    )).toDF("raw", "expected")
      +
      +    testStopWordsRemover(remover, dataSet)
      +  }
      +
      +  test("StopWordsRemover with additional words") {
      +    val stopWords = StopWords.English ++ Array("python", "scala")
      +    val remover = new StopWordsRemover()
      +      .setInputCol("raw")
      +      .setOutputCol("filtered")
      +      .setStopWords(stopWords)
      +    val dataSet = sqlContext.createDataFrame(Seq(
      +      (Seq("python", "scala", "a"), Seq()),
      +      (Seq("Python", "Scala", "swift"), Seq("swift"))
      +    )).toDF("raw", "expected")
      +
      +    testStopWordsRemover(remover, dataSet)
      +  }
      +}
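
The suite above checks the default English list, case sensitivity, and user-supplied stop words. A minimal usage sketch, assuming an existing `sqlContext` and the `StopWords.English` list referenced in the tests:

```
import org.apache.spark.ml.feature.{StopWords, StopWordsRemover}

// Assumes `sqlContext` is in scope.
val df = sqlContext.createDataFrame(Seq(
  (0, Seq("The", "quick", "brown", "fox"))
)).toDF("id", "raw")

val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setCaseSensitive(false)                          // default: "The" is filtered too
  .setStopWords(StopWords.English ++ Array("fox"))  // extend the built-in list, as in the suite
remover.transform(df).select("filtered").show()
```
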
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
      index 99f82bea4268..ddcdb5f4212b 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
      @@ -17,17 +17,23 @@
       
       package org.apache.spark.ml.feature
       
      -import org.apache.spark.SparkFunSuite
      +import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
      +import org.apache.spark.{SparkException, SparkFunSuite}
       import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
       import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.Row
      +import org.apache.spark.sql.functions.col
       
       class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
       
         test("params") {
           ParamsSuite.checkParams(new StringIndexer)
           val model = new StringIndexerModel("indexer", Array("a", "b"))
      +    val modelWithoutUid = new StringIndexerModel(Array("a", "b"))
           ParamsSuite.checkParams(model)
      +    ParamsSuite.checkParams(modelWithoutUid)
         }
       
         test("StringIndexer") {
      @@ -37,6 +43,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setInputCol("label")
             .setOutputCol("labelIndex")
             .fit(df)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(indexer)
      +
           val transformed = indexer.transform(df)
           val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
             .asInstanceOf[NominalAttribute]
      @@ -49,6 +59,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(output === expected)
         }
       
      +  test("StringIndexerUnseen") {
      +    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
      +    val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
      +    val df = sqlContext.createDataFrame(data).toDF("id", "label")
      +    val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
      +    val indexer = new StringIndexer()
      +      .setInputCol("label")
      +      .setOutputCol("labelIndex")
      +      .fit(df)
      +    // Verify we throw by default with unseen values
      +    intercept[SparkException] {
      +      indexer.transform(df2).collect()
      +    }
      +    val indexerSkipInvalid = new StringIndexer()
      +      .setInputCol("label")
      +      .setOutputCol("labelIndex")
      +      .setHandleInvalid("skip")
      +      .fit(df)
      +    // Verify that we skip the c record
      +    val transformed = indexerSkipInvalid.transform(df2)
      +    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
      +      .asInstanceOf[NominalAttribute]
      +    assert(attr.values.get === Array("b", "a"))
      +    val output = transformed.select("id", "labelIndex").map { r =>
      +      (r.getInt(0), r.getDouble(1))
      +    }.collect().toSet
      +    // a -> 1, b -> 0
      +    val expected = Set((0, 1.0), (1, 0.0))
      +    assert(output === expected)
      +  }
      +
         test("StringIndexer with a numeric input column") {
           val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
           val df = sqlContext.createDataFrame(data).toDF("id", "label")
      @@ -75,4 +116,61 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
           val df = sqlContext.range(0L, 10L)
           assert(indexerModel.transform(df).eq(df))
         }
      +
      +  test("IndexToString params") {
      +    val idxToStr = new IndexToString()
      +    ParamsSuite.checkParams(idxToStr)
      +  }
      +
      +  test("IndexToString.transform") {
      +    val labels = Array("a", "b", "c")
      +    val df0 = sqlContext.createDataFrame(Seq(
      +      (0, "a"), (1, "b"), (2, "c"), (0, "a")
      +    )).toDF("index", "expected")
      +
      +    val idxToStr0 = new IndexToString()
      +      .setInputCol("index")
      +      .setOutputCol("actual")
      +      .setLabels(labels)
      +    idxToStr0.transform(df0).select("actual", "expected").collect().foreach {
      +      case Row(actual, expected) =>
      +        assert(actual === expected)
      +    }
      +
      +    val attr = NominalAttribute.defaultAttr.withValues(labels)
      +    val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected"))
      +
      +    val idxToStr1 = new IndexToString()
      +      .setInputCol("indexWithAttr")
      +      .setOutputCol("actual")
      +    idxToStr1.transform(df1).select("actual", "expected").collect().foreach {
      +      case Row(actual, expected) =>
      +        assert(actual === expected)
      +    }
      +  }
      +
      +  test("StringIndexer, IndexToString are inverses") {
      +    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
      +    val df = sqlContext.createDataFrame(data).toDF("id", "label")
      +    val indexer = new StringIndexer()
      +      .setInputCol("label")
      +      .setOutputCol("labelIndex")
      +      .fit(df)
      +    val transformed = indexer.transform(df)
      +    val idx2str = new IndexToString()
      +      .setInputCol("labelIndex")
      +      .setOutputCol("sameLabel")
      +      .setLabels(indexer.labels)
      +    idx2str.transform(transformed).select("label", "sameLabel").collect().foreach {
      +      case Row(a: String, b: String) =>
      +        assert(a === b)
      +    }
      +  }
      +
      +  test("IndexToString.transformSchema (SPARK-10573)") {
      +    val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
      +    val inSchema = StructType(Seq(StructField("input", DoubleType)))
      +    val outSchema = idxToStr.transformSchema(inSchema)
      +    assert(outSchema("output").dataType === StringType)
      +  }
       }
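
Taken together, the new tests cover StringIndexer's unseen-label handling and the new IndexToString inverse. A minimal round-trip sketch, assuming an existing `sqlContext`:

```
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

// Assumes `sqlContext` is in scope.
val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))
).toDF("id", "label")

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("labelIndex")
  .setHandleInvalid("skip")   // skip unseen labels instead of throwing a SparkException
  .fit(df)
val indexed = indexer.transform(df)

// IndexToString maps the indices back to the original labels.
val converter = new IndexToString()
  .setInputCol("labelIndex")
  .setOutputCol("originalLabel")
  .setLabels(indexer.labels)
converter.transform(indexed).select("label", "originalLabel").show()
```
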
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
      index 8c85c96d5c6d..8cb0a2cf14d3 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
      @@ -19,15 +19,16 @@ package org.apache.spark.ml.feature
       
       import scala.beans.{BeanInfo, BeanProperty}
       
      -import org.apache.spark.{SparkException, SparkFunSuite}
      +import org.apache.spark.{Logging, SparkException, SparkFunSuite}
       import org.apache.spark.ml.attribute._
       import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
       
      -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
      +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
       
         import VectorIndexerSuite.FeatureData
       
      @@ -109,15 +110,19 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
         test("Throws error when given RDDs with different size vectors") {
           val vectorIndexer = getIndexer
           val model = vectorIndexer.fit(densePoints1) // vectors of length 3
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +
           model.transform(densePoints1) // should work
           model.transform(sparsePoints1) // should work
           intercept[SparkException] {
             model.transform(densePoints2).collect()
      -      println("Did not throw error when fit, transform were called on vectors of different lengths")
      +      logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
           }
           intercept[SparkException] {
             vectorIndexer.fit(badPoints)
      -      println("Did not throw error when fitting vectors of different lengths in same RDD.")
      +      logInfo("Did not throw error when fitting vectors of different lengths in same RDD.")
           }
         }
       
      @@ -196,7 +201,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
               }
             } catch {
               case e: org.scalatest.exceptions.TestFailedException =>
      -          println(errMsg)
      +          logError(errMsg)
                 throw e
             }
           }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
      new file mode 100644
      index 000000000000..a6c2fba8360d
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
      @@ -0,0 +1,109 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.feature
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.types.StructType
      +import org.apache.spark.sql.{DataFrame, Row, SQLContext}
      +
      +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("params") {
      +    val slicer = new VectorSlicer
      +    ParamsSuite.checkParams(slicer)
      +    assert(slicer.getIndices.length === 0)
      +    assert(slicer.getNames.length === 0)
      +    withClue("VectorSlicer should not have any features selected by default") {
      +      intercept[IllegalArgumentException] {
      +        slicer.validateParams()
      +      }
      +    }
      +  }
      +
      +  test("feature validity checks") {
      +    import VectorSlicer._
      +    assert(validIndices(Array(0, 1, 8, 2)))
      +    assert(validIndices(Array.empty[Int]))
      +    assert(!validIndices(Array(-1)))
      +    assert(!validIndices(Array(1, 2, 1)))
      +
      +    assert(validNames(Array("a", "b")))
      +    assert(validNames(Array.empty[String]))
      +    assert(!validNames(Array("", "b")))
      +    assert(!validNames(Array("a", "b", "a")))
      +  }
      +
      +  test("Test vector slicer") {
      +    val sqlContext = new SQLContext(sc)
      +
      +    val data = Array(
      +      Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
      +      Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
      +      Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0),
      +      Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3),
      +      Vectors.sparse(5, Seq())
      +    )
      +
      +    // Expected after selecting indices 1, 4
      +    val expected = Array(
      +      Vectors.sparse(2, Seq((0, 2.3))),
      +      Vectors.dense(2.3, 1.0),
      +      Vectors.dense(0.0, 0.0),
      +      Vectors.dense(-1.1, 3.3),
      +      Vectors.sparse(2, Seq())
      +    )
      +
      +    val defaultAttr = NumericAttribute.defaultAttr
      +    val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName)
      +    val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]])
      +
      +    val resultAttrs = Array("f1", "f4").map(defaultAttr.withName)
      +    val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
      +
      +    val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
      +    val df = sqlContext.createDataFrame(rdd,
      +      StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
      +
      +    val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
      +
      +    def validateResults(df: DataFrame): Unit = {
      +      df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
      +        assert(vec1 === vec2)
      +      }
      +      val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
      +      val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
      +      assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
      +      resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
      +        assert(a === b)
      +      }
      +    }
      +
      +    vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
      +    validateResults(vectorSlicer.transform(df))
      +
      +    vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
      +    validateResults(vectorSlicer.transform(df))
      +
      +    vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
      +    validateResults(vectorSlicer.transform(df))
      +  }
      +}
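
The VectorSlicer test builds its input schema from ML attributes so that slicing by name can be verified. A condensed sketch of the same setup, assuming `sc` and `sqlContext` are in scope:

```
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

// Assumes `sc` and `sqlContext` are in scope.
val attrs = Array("f0", "f1", "f2").map(NumericAttribute.defaultAttr.withName)
val schema = StructType(Array(
  new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]]).toStructField()))
val df = sqlContext.createDataFrame(
  sc.parallelize(Seq(Row(Vectors.dense(1.0, 2.0, 3.0)))), schema)

// Features can be selected by position, by attribute name, or both.
val slicer = new VectorSlicer()
  .setInputCol("features")
  .setOutputCol("sliced")
  .setIndices(Array(0))
  .setNames(Array("f2"))
slicer.transform(df).show()
```
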
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
      index aa6ce533fd88..a2e46f202995 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
      @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
       
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.mllib.util.TestingUtils._
      @@ -62,10 +63,75 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setSeed(42L)
             .fit(docDF)
       
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +
           model.transform(docDF).select("result", "expected").collect().foreach {
             case Row(vector1: Vector, vector2: Vector) =>
               assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
           }
         }
      +
      +  test("getVectors") {
      +
      +    val sqlContext = new SQLContext(sc)
      +    import sqlContext.implicits._
      +
      +    val sentence = "a b " * 100 + "a c " * 10
      +    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
      +
      +    val codes = Map(
      +      "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
      +      "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
      +      "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
      +    )
      +    val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) }
      +
      +    val docDF = doc.zip(doc).toDF("text", "alsotext")
      +
      +    val model = new Word2Vec()
      +      .setVectorSize(3)
      +      .setInputCol("text")
      +      .setOutputCol("result")
      +      .setSeed(42L)
      +      .fit(docDF)
      +
      +    val realVectors = model.getVectors.sort("word").select("vector").map {
      +      case Row(v: Vector) => v
      +    }.collect()
      +
      +    realVectors.zip(expectedVectors).foreach {
      +      case (real, expected) =>
      +        assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.")
      +    }
      +  }
      +
      +  test("findSynonyms") {
      +
      +    val sqlContext = new SQLContext(sc)
      +    import sqlContext.implicits._
      +
      +    val sentence = "a b " * 100 + "a c " * 10
      +    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
      +    val docDF = doc.zip(doc).toDF("text", "alsotext")
      +
      +    val model = new Word2Vec()
      +      .setVectorSize(3)
      +      .setInputCol("text")
      +      .setOutputCol("result")
      +      .setSeed(42L)
      +      .fit(docDF)
      +
      +    val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644)
      +    val (synonyms, similarity) = model.findSynonyms("a", 2).map {
      +      case Row(w: String, sim: Double) => (w, sim)
      +    }.collect().unzip
      +
      +    assert(synonyms.toArray === Array("b", "c"))
      +    expectedSimilarity.zip(similarity).map {
      +      case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
      +    }
      +
      +  }
       }
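
The two new Word2Vec tests exercise `getVectors` and `findSynonyms`, both of which return DataFrames. A minimal sketch mirroring the test data, assuming `sc` is in scope:

```
import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.sql.SQLContext

// Assumes `sc` is in scope.
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val sentence = "a b " * 100 + "a c " * 10
val docDF = sc.parallelize(Seq(sentence, sentence))
  .map(line => line.split(" "))
  .map(Tuple1.apply)
  .toDF("text")

val model = new Word2Vec()
  .setVectorSize(3)
  .setInputCol("text")
  .setOutputCol("result")
  .setSeed(42L)
  .fit(docDF)

model.getVectors.show()           // one row per word: (word, vector)
model.findSynonyms("a", 2).show() // the two nearest words and their similarity scores
```
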
       
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
      index 778abcba22c1..460849c79f04 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
      @@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite {
               "checkEqual failed since the two tree ensembles were not identical")
           }
         }
      +
      +  /**
      +   * Helper method for constructing a tree for testing.
      +   * Given left and right children, construct a parent node.
      +   * @param split  Split for parent node
      +   * @return  Parent node with children attached
      +   */
      +  def buildParentNode(left: Node, right: Node, split: Split): Node = {
      +    val leftImp = left.impurityStats
      +    val rightImp = right.impurityStats
      +    val parentImp = leftImp.copy.add(rightImp)
      +    val leftWeight = leftImp.count / parentImp.count.toDouble
      +    val rightWeight = rightImp.count / parentImp.count.toDouble
      +    val gain = parentImp.calculate() -
      +      (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
      +    val pred = parentImp.predict
      +    new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
      +  }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
      new file mode 100644
      index 000000000000..652f3adb984d
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
      @@ -0,0 +1,133 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.optim
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
      +import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
      +import org.apache.spark.rdd.RDD
      +
      +class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  private var instances: RDD[Instance] = _
      +
      +  override def beforeAll(): Unit = {
      +    super.beforeAll()
      +    /*
      +       R code:
      +
      +       A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
      +       b <- c(17, 19, 23, 29)
      +       w <- c(1, 2, 3, 4)
      +     */
      +    instances = sc.parallelize(Seq(
      +      Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
      +      Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
      +      Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
      +      Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
      +    ), 2)
      +  }
      +
      +  test("WLS against lm") {
      +    /*
      +       R code:
      +
      +       df <- as.data.frame(cbind(A, b))
      +       for (formula in c(b ~ . -1, b ~ .)) {
      +         model <- lm(formula, data=df, weights=w)
      +         print(as.vector(coef(model)))
      +       }
      +
      +       [1] -3.727121  3.009983
      +       [1] 18.08  6.08 -0.60
      +     */
      +
      +    val expected = Seq(
      +      Vectors.dense(0.0, -3.727121, 3.009983),
      +      Vectors.dense(18.08, 6.08, -0.60))
      +
      +    var idx = 0
      +    for (fitIntercept <- Seq(false, true)) {
      +      val wls = new WeightedLeastSquares(
      +        fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
      +        .fit(instances)
      +      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
      +      assert(actual ~== expected(idx) absTol 1e-4)
      +      idx += 1
      +    }
      +  }
      +
      +  test("WLS against glmnet") {
      +    /*
      +       R code:
      +
      +       library(glmnet)
      +
      +       for (intercept in c(FALSE, TRUE)) {
      +         for (lambda in c(0.0, 0.1, 1.0)) {
      +           for (standardize in c(FALSE, TRUE)) {
      +             model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda,
      +                             standardize=standardize, alpha=0, thresh=1E-14)
      +             print(as.vector(coef(model)))
      +           }
      +         }
      +       }
      +
      +       [1]  0.000000 -3.727117  3.009982
      +       [1]  0.000000 -3.727117  3.009982
      +       [1]  0.000000 -3.307532  2.924206
      +       [1]  0.000000 -2.914790  2.840627
      +       [1]  0.000000 -1.526575  2.558158
      +       [1] 0.00000000 0.06984238 2.20488344
      +       [1] 18.0799727  6.0799832 -0.5999941
      +       [1] 18.0799727  6.0799832 -0.5999941
      +       [1] 13.5356178  3.2714044  0.3770744
      +       [1] 14.064629  3.565802  0.269593
      +       [1] 10.1238013  0.9708569  1.1475466
      +       [1] 13.1860638  2.1761382  0.6213134
      +     */
      +
      +    val expected = Seq(
      +      Vectors.dense(0.0, -3.727117, 3.009982),
      +      Vectors.dense(0.0, -3.727117, 3.009982),
      +      Vectors.dense(0.0, -3.307532, 2.924206),
      +      Vectors.dense(0.0, -2.914790, 2.840627),
      +      Vectors.dense(0.0, -1.526575, 2.558158),
      +      Vectors.dense(0.0, 0.06984238, 2.20488344),
      +      Vectors.dense(18.0799727, 6.0799832, -0.5999941),
      +      Vectors.dense(18.0799727, 6.0799832, -0.5999941),
      +      Vectors.dense(13.5356178, 3.2714044, 0.3770744),
      +      Vectors.dense(14.064629, 3.565802, 0.269593),
      +      Vectors.dense(10.1238013, 0.9708569, 1.1475466),
      +      Vectors.dense(13.1860638, 2.1761382, 0.6213134))
      +
      +    var idx = 0
      +    for (fitIntercept <- Seq(false, true);
      +         regParam <- Seq(0.0, 0.1, 1.0);
      +         standardizeFeatures <- Seq(false, true)) {
      +      val wls = new WeightedLeastSquares(
      +        fitIntercept, regParam, standardizeFeatures, standardizeLabel = true)
      +        .fit(instances)
      +      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
      +      assert(actual ~== expected(idx) absTol 1e-4)
      +      idx += 1
      +    }
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
      index 050d4170ea01..dfab82c8b67a 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
      @@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite {
       
           assert(inputCol.toString === s"${uid}__inputCol")
       
      +    intercept[java.util.NoSuchElementException] {
      +      solver.getOrDefault(solver.handleInvalid)
      +    }
      +
           intercept[IllegalArgumentException] {
             solver.setMaxIter(-1)
           }
      @@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite {
       
         test("params") {
           val solver = new TestParams()
      -    import solver.{maxIter, inputCol}
      +    import solver.{handleInvalid, maxIter, inputCol}
       
           val params = solver.params
      -    assert(params.length === 2)
      -    assert(params(0).eq(inputCol), "params must be ordered by name")
      -    assert(params(1).eq(maxIter))
      +    assert(params.length === 3)
      +    assert(params(0).eq(handleInvalid), "params must be ordered by name")
      +    assert(params(1).eq(inputCol), "params must be ordered by name")
      +    assert(params(2).eq(maxIter))
       
           assert(!solver.isSet(maxIter))
           assert(solver.isDefined(maxIter))
      @@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite {
           assert(solver.explainParam(maxIter) ===
             "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
           assert(solver.explainParams() ===
      -      Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
      +      Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n"))
       
           assert(solver.getParam("inputCol").eq(inputCol))
           assert(solver.getParam("maxIter").eq(maxIter))
      @@ -199,6 +204,17 @@ class ParamsSuite extends SparkFunSuite {
       
           val inArray = ParamValidators.inArray[Int](Array(1, 2))
           assert(inArray(1) && inArray(2) && !inArray(0))
      +
      +    val arrayLengthGt = ParamValidators.arrayLengthGt[Int](2.0)
      +    assert(arrayLengthGt(Array(0, 1, 2)) && !arrayLengthGt(Array(0, 1)))
      +  }
      +
      +  test("Params.copyValues") {
      +    val t = new TestParams()
      +    val t2 = t.copy(ParamMap.empty)
      +    assert(!t2.isSet(t2.maxIter))
      +    val t3 = t.copy(ParamMap(t.maxIter -> 20))
      +    assert(t3.isSet(t3.maxIter))
         }
       }
       
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
      index 275924834453..9d23547f2844 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
      @@ -17,11 +17,12 @@
       
       package org.apache.spark.ml.param
       
      -import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
      +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter}
       import org.apache.spark.ml.util.Identifiable
       
       /** A subclass of Params for testing. */
      -class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol {
      +class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter
      +    with HasInputCol {
       
         def this() = this(Identifiable.randomUID("testParams"))
       
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
      index 2e5cfe7027eb..eadc80e0e62b 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
      @@ -28,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
       
       import org.apache.spark.{Logging, SparkException, SparkFunSuite}
       import org.apache.spark.ml.recommendation.ALS._
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.mllib.util.TestingUtils._
      @@ -374,6 +375,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
             }
           logInfo(s"Test RMSE is $rmse.")
           assert(rmse < targetRMSE)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
         }
       
         test("exact rank-1 matrix") {
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
      index 33aa9d0d6234..b092bcd6a7e8 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
      @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
       
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.impl.TreeTests
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
         DecisionTreeSuite => OldDecisionTreeSuite}
      @@ -61,6 +62,16 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
           compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
         }
       
      +  test("copied model must have the same parent") {
      +    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
      +    val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
      +    val model = new DecisionTreeRegressor()
      +      .setImpurity("variance")
      +      .setMaxDepth(2)
      +      .setMaxBins(8).fit(df)
      +    MLTestingUtils.checkCopy(model)
      +  }
      +
         /////////////////////////////////////////////////////////////////////////////
         // Tests of model save/load
         /////////////////////////////////////////////////////////////////////////////
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
      index 98fb3d3f5f22..a68197b59193 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
      @@ -19,12 +19,15 @@ package org.apache.spark.ml.regression
       
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.impl.TreeTests
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
      +import org.apache.spark.util.Utils
       
       
       /**
      @@ -67,6 +70,47 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
           }
         }
       
      +  test("GBTRegressor behaves reasonably on toy data") {
      +    val df = sqlContext.createDataFrame(Seq(
      +      LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
      +      LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
      +      LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
      +      LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)),
      +      LabeledPoint(9, Vectors.dense(1, 2, 6, 4)),
      +      LabeledPoint(-4, Vectors.dense(6, 3, 2, 2))
      +    ))
      +    val gbt = new GBTRegressor()
      +      .setMaxDepth(2)
      +      .setMaxIter(2)
      +    val model = gbt.fit(df)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +    val preds = model.transform(df)
      +    val predictions = preds.select("prediction").map(_.getDouble(0))
      +    // Checks based on SPARK-8736 (to ensure it is not doing classification)
      +    assert(predictions.max() > 2)
      +    assert(predictions.min() < -1)
      +  }
      +
      +  test("Checkpointing") {
      +    val tempDir = Utils.createTempDir()
      +    val path = tempDir.toURI.toString
      +    sc.setCheckpointDir(path)
      +
      +    val df = sqlContext.createDataFrame(data)
      +    val gbt = new GBTRegressor()
      +      .setMaxDepth(2)
      +      .setMaxIter(5)
      +      .setStepSize(0.1)
      +      .setCheckpointInterval(2)
      +    val model = gbt.fit(df)
      +
      +    sc.checkpointDir = None
      +    Utils.deleteRecursively(tempDir)
      +
      +  }
      +
         // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132
         /*
         test("runWithValidation stops early and performs better on a validation dataset") {
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
      new file mode 100644
      index 000000000000..59f4193abc8f
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
      @@ -0,0 +1,167 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.regression
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.sql.{DataFrame, Row}
      +
      +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
      +  private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
      +    sqlContext.createDataFrame(
      +      labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }
      +    ).toDF("label", "features", "weight")
      +  }
      +
      +  private def generatePredictionInput(features: Seq[Double]): DataFrame = {
      +    sqlContext.createDataFrame(features.map(Tuple1.apply))
      +      .toDF("features")
      +  }
      +
      +  test("isotonic regression predictions") {
      +    val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18))
      +    val ir = new IsotonicRegression().setIsotonic(true)
      +
      +    val model = ir.fit(dataset)
      +
      +    val predictions = model
      +      .transform(dataset)
      +      .select("prediction").map { case Row(pred) =>
      +        pred
      +      }.collect()
      +
      +    assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
      +
      +    assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 7, 8))
      +    assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
      +    assert(model.getIsotonic)
      +  }
      +
      +  test("antitonic regression predictions") {
      +    val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1))
      +    val ir = new IsotonicRegression().setIsotonic(false)
      +
      +    val model = ir.fit(dataset)
      +    val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0))
      +
      +    val predictions = model
      +      .transform(features)
      +      .select("prediction").map {
      +        case Row(pred) => pred
      +      }.collect()
      +
      +    assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
      +  }
      +
      +  test("params validation") {
      +    val dataset = generateIsotonicInput(Seq(1, 2, 3))
      +    val ir = new IsotonicRegression
      +    ParamsSuite.checkParams(ir)
      +    val model = ir.fit(dataset)
      +    ParamsSuite.checkParams(model)
      +  }
      +
      +  test("default params") {
      +    val dataset = generateIsotonicInput(Seq(1, 2, 3))
      +    val ir = new IsotonicRegression()
      +    assert(ir.getLabelCol === "label")
      +    assert(ir.getFeaturesCol === "features")
      +    assert(ir.getPredictionCol === "prediction")
      +    assert(!ir.isDefined(ir.weightCol))
      +    assert(ir.getIsotonic)
      +    assert(ir.getFeatureIndex === 0)
      +
      +    val model = ir.fit(dataset)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +
      +    model.transform(dataset)
      +      .select("label", "features", "prediction", "weight")
      +      .collect()
      +
      +    assert(model.getLabelCol === "label")
      +    assert(model.getFeaturesCol === "features")
      +    assert(model.getPredictionCol === "prediction")
      +    assert(!model.isDefined(model.weightCol))
      +    assert(model.getIsotonic)
      +    assert(model.getFeatureIndex === 0)
      +    assert(model.hasParent)
      +  }
      +
      +  test("set parameters") {
      +    val isotonicRegression = new IsotonicRegression()
      +      .setIsotonic(false)
      +      .setWeightCol("w")
      +      .setFeaturesCol("f")
      +      .setLabelCol("l")
      +      .setPredictionCol("p")
      +
      +    assert(!isotonicRegression.getIsotonic)
      +    assert(isotonicRegression.getWeightCol === "w")
      +    assert(isotonicRegression.getFeaturesCol === "f")
      +    assert(isotonicRegression.getLabelCol === "l")
      +    assert(isotonicRegression.getPredictionCol === "p")
      +  }
      +
      +  test("missing column") {
      +    val dataset = generateIsotonicInput(Seq(1, 2, 3))
      +
      +    intercept[IllegalArgumentException] {
      +      new IsotonicRegression().setWeightCol("w").fit(dataset)
      +    }
      +
      +    intercept[IllegalArgumentException] {
      +      new IsotonicRegression().setFeaturesCol("f").fit(dataset)
      +    }
      +
      +    intercept[IllegalArgumentException] {
      +      new IsotonicRegression().setLabelCol("l").fit(dataset)
      +    }
      +
      +    intercept[IllegalArgumentException] {
      +      new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset)
      +    }
      +  }
      +
      +  test("vector features column with feature index") {
      +    val dataset = sqlContext.createDataFrame(Seq(
      +      (4.0, Vectors.dense(0.0, 1.0)),
      +      (3.0, Vectors.dense(0.0, 2.0)),
      +      (5.0, Vectors.sparse(2, Array(1), Array(3.0))))
      +    ).toDF("label", "features")
      +
      +    val ir = new IsotonicRegression()
      +      .setFeatureIndex(1)
      +
      +    val model = ir.fit(dataset)
      +
      +    val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0))
      +
      +    val predictions = model
      +      .transform(features)
      +      .select("prediction").map {
      +      case Row(pred) => pred
      +    }.collect()
      +
      +    assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
      +  }
      +}
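
The IsotonicRegressionSuite introduces the DataFrame-based isotonic regression API, including the `boundaries`/`predictions` accessors and a feature index for vector columns. A minimal sketch, assuming an existing `sqlContext`:

```
import org.apache.spark.ml.regression.IsotonicRegression

// Assumes `sqlContext` is in scope.
val train = sqlContext.createDataFrame(
  Seq(1.0, 2.0, 3.0, 1.0, 6.0, 17.0, 16.0, 17.0, 18.0).zipWithIndex.map {
    case (label, i) => (label, i.toDouble, 1.0)
  }
).toDF("label", "features", "weight")

val model = new IsotonicRegression()
  .setIsotonic(true)        // monotonically increasing fit; false gives an antitonic fit
  .setWeightCol("weight")
  .fit(train)

// Piecewise-linear model: boundary points and the fitted value at each boundary.
println(model.boundaries)
println(model.predictions)
model.transform(train).select("features", "prediction").show()
```
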
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
      index 732e2c42be14..2aaee71ecc73 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
      @@ -18,7 +18,9 @@
       package org.apache.spark.ml.regression
       
       import org.apache.spark.SparkFunSuite
      -import org.apache.spark.mllib.linalg.DenseVector
      +import org.apache.spark.ml.param.ParamsSuite
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
       import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
       import org.apache.spark.mllib.util.TestingUtils._
       import org.apache.spark.sql.{DataFrame, Row}
      @@ -26,139 +28,486 @@ import org.apache.spark.sql.{DataFrame, Row}
       class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       
         @transient var dataset: DataFrame = _
      +  @transient var datasetWithoutIntercept: DataFrame = _
       
      -  /**
      -   * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
      -   * is the same as the one trained by R's glmnet package. The following instruction
      -   * describes how to reproduce the data in R.
      -   *
      -   * import org.apache.spark.mllib.util.LinearDataGenerator
      -   * val data =
      -   *   sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
      -   * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
      +  /*
      +     In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
      +     is the same as the one trained by R's glmnet package. The following instruction
      +     describes how to reproduce the data in R.
      +
      +     import org.apache.spark.mllib.util.LinearDataGenerator
      +     val data =
      +       sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
      +         Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
      +     data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
      +       .saveAsTextFile("path")
          */
         override def beforeAll(): Unit = {
           super.beforeAll()
           dataset = sqlContext.createDataFrame(
             sc.parallelize(LinearDataGenerator.generateLinearInput(
               6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
      +    /*
      +       datasetWithoutIntercept is not needed for correctness testing, but it is useful for
      +       illustrating how to train a model without an intercept
      +     */
      +    datasetWithoutIntercept = sqlContext.createDataFrame(
      +      sc.parallelize(LinearDataGenerator.generateLinearInput(
      +        0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
      +
      +  }
      +
      +  test("params") {
      +    ParamsSuite.checkParams(new LinearRegression)
      +    val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
      +    ParamsSuite.checkParams(model)
      +  }
      +
      +  test("linear regression: default params") {
      +    val lir = new LinearRegression
      +    assert(lir.getLabelCol === "label")
      +    assert(lir.getFeaturesCol === "features")
      +    assert(lir.getPredictionCol === "prediction")
      +    assert(lir.getRegParam === 0.0)
      +    assert(lir.getElasticNetParam === 0.0)
      +    assert(lir.getFitIntercept)
      +    assert(lir.getStandardization)
      +    val model = lir.fit(dataset)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +
      +    model.transform(dataset)
      +      .select("label", "prediction")
      +      .collect()
      +    assert(model.getFeaturesCol === "features")
      +    assert(model.getPredictionCol === "prediction")
      +    assert(model.intercept !== 0.0)
      +    assert(model.hasParent)
         }
       
         test("linear regression with intercept without regularization") {
      -    val trainer = new LinearRegression
      -    val model = trainer.fit(dataset)
      +    val trainer1 = new LinearRegression
      +    // Without regularization, the result should be the same with or without standardization
      +    val trainer2 = (new LinearRegression).setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val model2 = trainer2.fit(dataset)
       
      -    /**
      -     * Using the following R code to load the data and train the model using glmnet package.
      -     *
      -     * library("glmnet")
      -     * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
      -     * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
      -     * label <- as.numeric(data$V1)
      -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
      -     * > weights
      -     *  3 x 1 sparse Matrix of class "dgCMatrix"
      -     *                           s0
      -     * (Intercept)         6.300528
      -     * as.numeric.data.V2. 4.701024
      -     * as.numeric.data.V3. 7.198257
      +    /*
      +       Using the following R code to load the data and train the model using glmnet package.
      +
      +       library("glmnet")
      +       data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
      +       features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
      +       label <- as.numeric(data$V1)
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
      +       > weights
      +        3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)         6.298698
      +       as.numeric.data.V2. 4.700706
      +       as.numeric.data.V3. 7.199082
            */
           val interceptR = 6.298698
      -    val weightsR = Array(4.700706, 7.199082)
      +    val weightsR = Vectors.dense(4.700706, 7.199082)
      +
      +    assert(model1.intercept ~== interceptR relTol 1E-3)
      +    assert(model1.weights ~= weightsR relTol 1E-3)
      +    assert(model2.intercept ~== interceptR relTol 1E-3)
      +    assert(model2.weights ~= weightsR relTol 1E-3)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
       
      -    model.transform(dataset).select("features", "prediction").collect().foreach {
      +    model1.transform(dataset).select("features", "prediction").collect().foreach {
             case Row(features: DenseVector, prediction1: Double) =>
               val prediction2 =
      -          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      +          features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
               assert(prediction1 ~== prediction2 relTol 1E-5)
           }
         }
       
      +  test("linear regression without intercept without regularization") {
      +    val trainer1 = (new LinearRegression).setFitIntercept(false)
      +    // Without regularization the results should be the same
      +    val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept)
      +    val model2 = trainer2.fit(dataset)
      +    val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept)
      +
      +
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
      +         intercept = FALSE))
      +       > weights
      +        3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)         .
      +       as.numeric.data.V2. 6.995908
      +       as.numeric.data.V3. 5.275131
      +     */
      +    val weightsR = Vectors.dense(6.995908, 5.275131)
      +
      +    assert(model1.intercept ~== 0 absTol 1E-3)
      +    assert(model1.weights ~= weightsR relTol 1E-3)
      +    assert(model2.intercept ~== 0 absTol 1E-3)
      +    assert(model2.weights ~= weightsR relTol 1E-3)
      +
      +    /*
       +       Then again with the data generated without an intercept:
      +       > weightsWithoutIntercept
      +       3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)           .
      +       as.numeric.data3.V2. 4.70011
      +       as.numeric.data3.V3. 7.19943
      +     */
      +    val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
      +
      +    assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3)
      +    assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3)
      +    assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3)
      +    assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3)
      +  }
      +
         test("linear regression with intercept with L1 regularization") {
      -    val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
      -    val model = trainer.fit(dataset)
      +    val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
      +    val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
      +      .setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val model2 = trainer2.fit(dataset)
       
      -    /**
      -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
      -     * > weights
      -     *  3 x 1 sparse Matrix of class "dgCMatrix"
      -     *                           s0
      -     * (Intercept)         6.311546
      -     * as.numeric.data.V2. 2.123522
      -     * as.numeric.data.V3. 4.605651
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
      +       > weights
      +        3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)         6.24300
      +       as.numeric.data.V2. 4.024821
      +       as.numeric.data.V3. 6.679841
            */
      -    val interceptR = 6.243000
      -    val weightsR = Array(4.024821, 6.679841)
      +    val interceptR1 = 6.24300
      +    val weightsR1 = Vectors.dense(4.024821, 6.679841)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-3)
       
      -    model.transform(dataset).select("features", "prediction").collect().foreach {
      +    /*
      +      weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
      +        standardize=FALSE))
      +      > weights
      +       3 x 1 sparse Matrix of class "dgCMatrix"
      +                                s0
      +      (Intercept)         6.416948
      +      as.numeric.data.V2. 3.893869
      +      as.numeric.data.V3. 6.724286
      +     */
      +    val interceptR2 = 6.416948
      +    val weightsR2 = Vectors.dense(3.893869, 6.724286)
      +
      +    assert(model2.intercept ~== interceptR2 relTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-3)
       +
      +    model1.transform(dataset).select("features", "prediction").collect().foreach {
             case Row(features: DenseVector, prediction1: Double) =>
               val prediction2 =
      -          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      +          features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
      +        assert(prediction1 ~== prediction2 relTol 1E-5)
      +    }
      +  }
      +
      +  test("linear regression without intercept with L1 regularization") {
      +    val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
      +      .setFitIntercept(false)
      +    val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
      +      .setFitIntercept(false).setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val model2 = trainer2.fit(dataset)
      +
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
      +         intercept=FALSE))
      +       > weights
      +        3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)          .
      +       as.numeric.data.V2. 6.299752
      +       as.numeric.data.V3. 4.772913
      +     */
      +    val interceptR1 = 0.0
      +    val weightsR1 = Vectors.dense(6.299752, 4.772913)
      +
      +    assert(model1.intercept ~== interceptR1 absTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-3)
      +
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
      +         intercept=FALSE, standardize=FALSE))
      +       > weights
      +       3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)         .
      +       as.numeric.data.V2. 6.232193
      +       as.numeric.data.V3. 4.764229
      +     */
      +    val interceptR2 = 0.0
      +    val weightsR2 = Vectors.dense(6.232193, 4.764229)
      +
      +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-3)
       +
      +    model1.transform(dataset).select("features", "prediction").collect().foreach {
      +      case Row(features: DenseVector, prediction1: Double) =>
      +        val prediction2 =
      +          features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
               assert(prediction1 ~== prediction2 relTol 1E-5)
           }
         }
       
         test("linear regression with intercept with L2 regularization") {
      -    val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
      -    val model = trainer.fit(dataset)
      +    val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
      +    val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
      +      .setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val model2 = trainer2.fit(dataset)
       
      -    /**
      -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
      -     * > weights
      -     *  3 x 1 sparse Matrix of class "dgCMatrix"
      -     *                           s0
      -     * (Intercept)         6.328062
      -     * as.numeric.data.V2. 3.222034
      -     * as.numeric.data.V3. 4.926260
      +    /*
      +      weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
      +      > weights
      +       3 x 1 sparse Matrix of class "dgCMatrix"
      +                                s0
      +      (Intercept)         5.269376
      +      as.numeric.data.V2. 3.736216
       +      as.numeric.data.V3. 5.712356
            */
      -    val interceptR = 5.269376
      -    val weightsR = Array(3.736216, 5.712356)
      +    val interceptR1 = 5.269376
      +    val weightsR1 = Vectors.dense(3.736216, 5.712356)
      +
      +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-3)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      +    /*
      +      weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
      +        standardize=FALSE))
      +      > weights
      +       3 x 1 sparse Matrix of class "dgCMatrix"
      +                                s0
      +      (Intercept)         5.791109
      +      as.numeric.data.V2. 3.435466
      +      as.numeric.data.V3. 5.910406
      +     */
      +    val interceptR2 = 5.791109
      +    val weightsR2 = Vectors.dense(3.435466, 5.910406)
       
      -    model.transform(dataset).select("features", "prediction").collect().foreach {
      +    assert(model2.intercept ~== interceptR2 relTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-3)
      +
      +    model1.transform(dataset).select("features", "prediction").collect().foreach {
             case Row(features: DenseVector, prediction1: Double) =>
               val prediction2 =
      -          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      +          features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
      +        assert(prediction1 ~== prediction2 relTol 1E-5)
      +    }
      +  }
      +
      +  test("linear regression without intercept with L2 regularization") {
      +    val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
      +      .setFitIntercept(false)
      +    val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
      +      .setFitIntercept(false).setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val model2 = trainer2.fit(dataset)
      +
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
      +         intercept = FALSE))
      +       > weights
      +        3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)         .
      +       as.numeric.data.V2. 5.522875
      +       as.numeric.data.V3. 4.214502
      +     */
      +    val interceptR1 = 0.0
      +    val weightsR1 = Vectors.dense(5.522875, 4.214502)
      +
      +    assert(model1.intercept ~== interceptR1 absTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-3)
      +
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
      +         intercept = FALSE, standardize=FALSE))
      +       > weights
      +        3 x 1 sparse Matrix of class "dgCMatrix"
      +                                 s0
      +       (Intercept)         .
      +       as.numeric.data.V2. 5.263704
      +       as.numeric.data.V3. 4.187419
      +     */
      +    val interceptR2 = 0.0
      +    val weightsR2 = Vectors.dense(5.263704, 4.187419)
      +
      +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-3)
      +
      +    model1.transform(dataset).select("features", "prediction").collect().foreach {
      +      case Row(features: DenseVector, prediction1: Double) =>
      +        val prediction2 =
      +          features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
               assert(prediction1 ~== prediction2 relTol 1E-5)
           }
         }
       
         test("linear regression with intercept with ElasticNet regularization") {
      -    val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
      -    val model = trainer.fit(dataset)
      +    val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
      +    val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
      +      .setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val model2 = trainer2.fit(dataset)
      +
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
      +       > weights
      +       3 x 1 sparse Matrix of class "dgCMatrix"
       +                                s0
      +       (Intercept)         6.324108
      +       as.numeric.data.V2. 3.168435
      +       as.numeric.data.V3. 5.200403
      +     */
      +    val interceptR1 = 5.696056
      +    val weightsR1 = Vectors.dense(3.670489, 6.001122)
      +
      +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-3)
      +
      +    /*
       +      weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
      +       standardize=FALSE))
      +      > weights
      +      3 x 1 sparse Matrix of class "dgCMatrix"
       +                               s0
      +      (Intercept)         6.114723
      +      as.numeric.data.V2. 3.409937
      +      as.numeric.data.V3. 6.146531
      +     */
      +    val interceptR2 = 6.114723
      +    val weightsR2 = Vectors.dense(3.409937, 6.146531)
      +
      +    assert(model2.intercept ~== interceptR2 relTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-3)
      +
      +    model1.transform(dataset).select("features", "prediction").collect().foreach {
      +      case Row(features: DenseVector, prediction1: Double) =>
      +        val prediction2 =
      +          features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
      +        assert(prediction1 ~== prediction2 relTol 1E-5)
      +    }
      +  }
      +
      +  test("linear regression without intercept with ElasticNet regularization") {
      +    val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
      +      .setFitIntercept(false)
      +    val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
      +      .setFitIntercept(false).setStandardization(false)
      +    val model1 = trainer1.fit(dataset)
      +    val model2 = trainer2.fit(dataset)
      +
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
      +         intercept=FALSE))
      +       > weights
      +       3 x 1 sparse Matrix of class "dgCMatrix"
       +                                s0
      +       (Intercept)         .
      +       as.numeric.dataM.V2. 5.673348
      +       as.numeric.dataM.V3. 4.322251
      +     */
      +    val interceptR1 = 0.0
      +    val weightsR1 = Vectors.dense(5.673348, 4.322251)
      +
      +    assert(model1.intercept ~== interceptR1 absTol 1E-3)
      +    assert(model1.weights ~= weightsR1 relTol 1E-3)
       
      -    /**
      -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
      -     * > weights
      -     * 3 x 1 sparse Matrix of class "dgCMatrix"
      -     * s0
      -     * (Intercept)         6.324108
      -     * as.numeric.data.V2. 3.168435
      -     * as.numeric.data.V3. 5.200403
      +    /*
      +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
      +         intercept=FALSE, standardize=FALSE))
      +       > weights
      +       3 x 1 sparse Matrix of class "dgCMatrix"
       +                                s0
      +       (Intercept)         .
      +       as.numeric.data.V2. 5.477988
      +       as.numeric.data.V3. 4.297622
            */
      -    val interceptR = 5.696056
      -    val weightsR = Array(3.670489, 6.001122)
      +    val interceptR2 = 0.0
      +    val weightsR2 = Vectors.dense(5.477988, 4.297622)
       
      -    assert(model.intercept ~== interceptR relTol 1E-3)
      -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
      -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
      +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
      +    assert(model2.weights ~= weightsR2 relTol 1E-3)
       
      -    model.transform(dataset).select("features", "prediction").collect().foreach {
      +    model1.transform(dataset).select("features", "prediction").collect().foreach {
             case Row(features: DenseVector, prediction1: Double) =>
               val prediction2 =
      -          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
      +          features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
               assert(prediction1 ~== prediction2 relTol 1E-5)
           }
         }
      +
      +  test("linear regression model training summary") {
      +    val trainer = new LinearRegression
      +    val model = trainer.fit(dataset)
      +
      +    // Training results for the model should be available
      +    assert(model.hasSummary)
      +
       +    // Residuals in [[LinearRegressionTrainingSummary]] should equal those manually computed
       +    dataset.select("features", "label")
       +      .map { case Row(features: DenseVector, label: Double) =>
       +        val prediction =
       +          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
       +        label - prediction
       +      }
       +      .zip(model.summary.residuals.map(_.getDouble(0)))
       +      .collect()
       +      .foreach { case (manualResidual: Double, resultResidual: Double) =>
       +        assert(manualResidual ~== resultResidual relTol 1E-5)
       +      }
      +
      +    /*
      +       Use the following R code to generate model training results.
      +
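       +       # assuming fit is the unregularized glmnet model from the first test:
       +       fit <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)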
      +       predictions <- predict(fit, newx=features)
      +       residuals <- label - predictions
      +       > mean(residuals^2) # MSE
      +       [1] 0.009720325
       +       > mean(abs(residuals)) # MAE
       +       [1] 0.07863206
       +       > cor(predictions, label)^2 # r^2
      +               [,1]
      +       s0 0.9998749
      +     */
      +    assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
       +    assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5)
      +    assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)
      +
      +    // Objective function should be monotonically decreasing for linear regression
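       +    // (The least-squares objective is convex and the optimizer uses a line search, so each
       +    // recorded value is expected to be no larger than the previous one.)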
      +    assert(
      +      model.summary
      +        .objectiveHistory
      +        .sliding(2)
      +        .forall(x => x(0) >= x(1)))
      +  }
      +
      +  test("linear regression model testset evaluation summary") {
      +    val trainer = new LinearRegression
      +    val model = trainer.fit(dataset)
      +
       +    // Evaluating on the training dataset should give a summary equal to the training summary
      +    val testSummary = model.evaluate(dataset)
      +    assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
      +    assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
      +    model.summary.residuals.select("residuals").collect()
      +      .zip(testSummary.residuals.select("residuals").collect())
      +      .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
      +  }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
      index b24ecaa57c89..7b1b3f11481d 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
      @@ -19,6 +19,8 @@ package org.apache.spark.ml.regression
       
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.ml.impl.TreeTests
      +import org.apache.spark.ml.util.MLTestingUtils
      +import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
       import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
      @@ -26,7 +28,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.rdd.RDD
       import org.apache.spark.sql.DataFrame
       
      -
       /**
        * Test suite for [[RandomForestRegressor]].
        */
      @@ -71,6 +72,35 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
           regressionTestWithContinuousFeatures(rf)
         }
       
      +  test("Feature importance with toy data") {
      +    val rf = new RandomForestRegressor()
      +      .setImpurity("variance")
      +      .setMaxDepth(3)
      +      .setNumTrees(3)
      +      .setFeatureSubsetStrategy("all")
      +      .setSubsamplingRate(1.0)
      +      .setSeed(123)
      +
      +    // In this data, feature 1 is very important.
      +    val data: RDD[LabeledPoint] = sc.parallelize(Seq(
      +      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
      +      new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
      +      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
      +      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
      +      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
      +    ))
      +    val categoricalFeatures = Map.empty[Int, Int]
      +    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
      +
      +    val model = rf.fit(df)
      +
      +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(model)
      +    val importances = model.featureImportances
      +    val mostImportantFeature = importances.argmax
      +    assert(mostImportantFeature === 1)
      +  }
      +
         /////////////////////////////////////////////////////////////////////////////
         // Tests of model save/load
         /////////////////////////////////////////////////////////////////////////////
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
      new file mode 100644
      index 000000000000..997f574e51f6
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
      @@ -0,0 +1,82 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.source.libsvm
      +
      +import java.io.File
      +
      +import com.google.common.base.Charsets
      +import com.google.common.io.Files
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.util.Utils
      +
      +class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
      +  var tempDir: File = _
      +  var path: String = _
      +
      +  override def beforeAll(): Unit = {
      +    super.beforeAll()
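       +    // LIBSVM text format: "<label> <index>:<value> ...", with 1-based feature indices; a bare
       +    // label with no index:value pairs (the second record below) is an all-zero feature vector.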
      +    val lines =
      +      """
      +        |1 1:1.0 3:2.0 5:3.0
      +        |0
      +        |0 2:4.0 4:5.0 6:6.0
      +      """.stripMargin
      +    tempDir = Utils.createTempDir()
      +    val file = new File(tempDir, "part-00000")
      +    Files.write(lines, file, Charsets.US_ASCII)
      +    path = tempDir.toURI.toString
      +  }
      +
      +  override def afterAll(): Unit = {
      +    Utils.deleteRecursively(tempDir)
      +    super.afterAll()
      +  }
      +
      +  test("select as sparse vector") {
      +    val df = sqlContext.read.format("libsvm").load(path)
      +    assert(df.columns(0) == "label")
      +    assert(df.columns(1) == "features")
      +    val row1 = df.first()
      +    assert(row1.getDouble(0) == 1.0)
      +    val v = row1.getAs[SparseVector](1)
      +    assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
      +  }
      +
      +  test("select as dense vector") {
      +    val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense"))
      +      .load(path)
      +    assert(df.columns(0) == "label")
      +    assert(df.columns(1) == "features")
      +    assert(df.count() == 3)
      +    val row1 = df.first()
      +    assert(row1.getDouble(0) == 1.0)
      +    val v = row1.getAs[DenseVector](1)
      +    assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
      +  }
      +
      +  test("select a vector with specifying the longer dimension") {
      +    val df = sqlContext.read.option("numFeatures", "100").format("libsvm")
      +      .load(path)
      +    val row1 = df.first()
      +    val v = row1.getAs[SparseVector](1)
      +    assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
      new file mode 100644
      index 000000000000..dc852795c7f6
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
      @@ -0,0 +1,107 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.tree.impl
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.classification.DecisionTreeClassificationModel
      +import org.apache.spark.ml.impl.TreeTests
      +import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
      +import org.apache.spark.mllib.tree.impurity.GiniCalculator
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.mllib.util.TestingUtils._
      +import org.apache.spark.util.collection.OpenHashMap
      +
      +/**
      + * Test suite for [[RandomForest]].
      + */
      +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  import RandomForestSuite.mapToVec
      +
      +  test("computeFeatureImportance, featureImportances") {
      +    /* Build tree for testing, with this structure:
      +          grandParent
      +      left2       parent
      +                left  right
      +     */
      +    val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
      +    val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
      +
      +    val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
      +    val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
      +
      +    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
      +    val parentImp = parent.impurityStats
      +
      +    val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
      +    val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
      +
      +    val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
      +    val grandImp = grandParent.impurityStats
      +
      +    // Test feature importance computed at different subtrees.
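       +    // Each split contributes to its feature's importance the impurity decrease it achieves,
       +    // weighted by instance counts:
       +    //   impurity(node) * count(node)
       +    //     - (impurity(left) * count(left) + impurity(right) * count(right))
       +    // featureImportances then normalizes each tree's totals and averages across trees.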
      +    def testNode(node: Node, expected: Map[Int, Double]): Unit = {
      +      val map = new OpenHashMap[Int, Double]()
      +      RandomForest.computeFeatureImportance(node, map)
      +      assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
      +    }
      +
      +    // Leaf node
      +    testNode(left, Map.empty[Int, Double])
      +
      +    // Internal node with 2 leaf children
      +    val feature0importance = parentImp.calculate() * parentImp.count -
      +      (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count)
      +    testNode(parent, Map(0 -> feature0importance))
      +
      +    // Full tree
      +    val feature1importance = grandImp.calculate() * grandImp.count -
      +      (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count)
      +    testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance))
      +
      +    // Forest consisting of (full tree) + (internal node with 2 leafs)
      +    val trees = Array(parent, grandParent).map { root =>
      +      new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
      +    }
      +    val importances: Vector = RandomForest.featureImportances(trees, 2)
      +    val tree2norm = feature0importance + feature1importance
      +    val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
      +      (feature1importance / tree2norm) / 2.0)
      +    assert(importances ~== expected relTol 0.01)
      +  }
      +
      +  test("normalizeMapValues") {
      +    val map = new OpenHashMap[Int, Double]()
      +    map(0) = 1.0
      +    map(2) = 2.0
      +    RandomForest.normalizeMapValues(map)
      +    val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
      +    assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
      +  }
      +
      +}
      +
      +private object RandomForestSuite {
      +
      +  def mapToVec(map: Map[Int, Double]): Vector = {
      +    val size = (map.keys.toSeq :+ 0).max + 1
      +    val (indices, values) = map.toSeq.sortBy(_._1).unzip
      +    Vectors.sparse(size, indices.toArray, values.toArray)
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
      index db64511a7605..fde02e0c84bc 100644
      --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
      @@ -18,6 +18,7 @@
       package org.apache.spark.ml.tuning
       
       import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.util.MLTestingUtils
       import org.apache.spark.ml.{Estimator, Model}
       import org.apache.spark.ml.classification.LogisticRegression
       import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
      @@ -53,6 +54,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setEvaluator(eval)
             .setNumFolds(3)
           val cvModel = cv.fit(dataset)
      +
       +    // copied model must have the same parent.
      +    MLTestingUtils.checkCopy(cvModel)
      +
           val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
           assert(parent.getRegParam === 0.001)
           assert(parent.getMaxIter === 10)
      @@ -138,6 +143,8 @@ object CrossValidatorSuite {
             throw new UnsupportedOperationException
           }
       
      +    override def isLargerBetter: Boolean = true
      +
           override val uid: String = "eval"
       
           override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
      new file mode 100644
      index 000000000000..ef24e6fb6b80
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
      @@ -0,0 +1,141 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.tuning
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.ml.{Estimator, Model}
      +import org.apache.spark.ml.classification.LogisticRegression
      +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
      +import org.apache.spark.ml.param.ParamMap
      +import org.apache.spark.ml.param.shared.HasInputCol
      +import org.apache.spark.ml.regression.LinearRegression
      +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
      +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
      +import org.apache.spark.sql.DataFrame
      +import org.apache.spark.sql.types.StructType
      +
      +class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
      +  test("train validation with logistic regression") {
      +    val dataset = sqlContext.createDataFrame(
      +      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
      +
      +    val lr = new LogisticRegression
      +    val lrParamMaps = new ParamGridBuilder()
      +      .addGrid(lr.regParam, Array(0.001, 1000.0))
      +      .addGrid(lr.maxIter, Array(0, 10))
      +      .build()
      +    val eval = new BinaryClassificationEvaluator
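       +    // TrainValidationSplit evaluates each ParamMap once on a single random train/validation
       +    // split (sized by trainRatio), rather than averaging over folds like CrossValidator.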
      +    val cv = new TrainValidationSplit()
      +      .setEstimator(lr)
      +      .setEstimatorParamMaps(lrParamMaps)
      +      .setEvaluator(eval)
      +      .setTrainRatio(0.5)
      +    val cvModel = cv.fit(dataset)
      +    val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
      +    assert(cv.getTrainRatio === 0.5)
      +    assert(parent.getRegParam === 0.001)
      +    assert(parent.getMaxIter === 10)
      +    assert(cvModel.validationMetrics.length === lrParamMaps.length)
      +  }
      +
      +  test("train validation with linear regression") {
      +    val dataset = sqlContext.createDataFrame(
      +        sc.parallelize(LinearDataGenerator.generateLinearInput(
      +            6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
      +
      +    val trainer = new LinearRegression
      +    val lrParamMaps = new ParamGridBuilder()
      +      .addGrid(trainer.regParam, Array(1000.0, 0.001))
      +      .addGrid(trainer.maxIter, Array(0, 10))
      +      .build()
      +    val eval = new RegressionEvaluator()
      +    val cv = new TrainValidationSplit()
      +      .setEstimator(trainer)
      +      .setEstimatorParamMaps(lrParamMaps)
      +      .setEvaluator(eval)
      +      .setTrainRatio(0.5)
      +    val cvModel = cv.fit(dataset)
      +    val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
      +    assert(parent.getRegParam === 0.001)
      +    assert(parent.getMaxIter === 10)
      +    assert(cvModel.validationMetrics.length === lrParamMaps.length)
      +
       +    eval.setMetricName("r2")
      +    val cvModel2 = cv.fit(dataset)
      +    val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
      +    assert(parent2.getRegParam === 0.001)
      +    assert(parent2.getMaxIter === 10)
      +    assert(cvModel2.validationMetrics.length === lrParamMaps.length)
      +  }
      +
      +  test("validateParams should check estimatorParamMaps") {
      +    import TrainValidationSplitSuite._
      +
      +    val est = new MyEstimator("est")
      +    val eval = new MyEvaluator
      +    val paramMaps = new ParamGridBuilder()
      +      .addGrid(est.inputCol, Array("input1", "input2"))
      +      .build()
      +
      +    val cv = new TrainValidationSplit()
      +      .setEstimator(est)
      +      .setEstimatorParamMaps(paramMaps)
      +      .setEvaluator(eval)
      +      .setTrainRatio(0.5)
      +    cv.validateParams() // This should pass.
      +
      +    val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
      +    cv.setEstimatorParamMaps(invalidParamMaps)
      +    intercept[IllegalArgumentException] {
      +      cv.validateParams()
      +    }
      +  }
      +}
      +
      +object TrainValidationSplitSuite {
      +
      +  abstract class MyModel extends Model[MyModel]
      +
      +  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
      +
      +    override def validateParams(): Unit = require($(inputCol).nonEmpty)
      +
      +    override def fit(dataset: DataFrame): MyModel = {
      +      throw new UnsupportedOperationException
      +    }
      +
      +    override def transformSchema(schema: StructType): StructType = {
      +      throw new UnsupportedOperationException
      +    }
      +
      +    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
      +  }
      +
      +  class MyEvaluator extends Evaluator {
      +
      +    override def evaluate(dataset: DataFrame): Double = {
      +      throw new UnsupportedOperationException
      +    }
      +
      +    override def isLargerBetter: Boolean = true
      +
      +    override val uid: String = "eval"
      +
      +    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
      new file mode 100644
      index 000000000000..d290cc9b06e7
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
      @@ -0,0 +1,30 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.util
      +
      +import org.apache.spark.ml.Model
      +import org.apache.spark.ml.param.ParamMap
      +
      +object MLTestingUtils {
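       +  /**
       +   * Checks that copying a fitted model with an empty ParamMap preserves its parent estimator
       +   * (both the reference and its uid).
       +   */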
      +  def checkCopy(model: Model[_]): Unit = {
      +    val copied = model.copy(ParamMap.empty)
      +      .asInstanceOf[Model[_]]
      +    assert(copied.parent.uid == model.parent.uid)
      +    assert(copied.parent == model.parent)
      +  }
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
      new file mode 100644
      index 000000000000..9e6bc7193c13
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
      @@ -0,0 +1,125 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.ml.util
      +
      +import java.util.Random
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +
      +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  import StopwatchSuite._
      +
      +  private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
      +    assert(sw.name === "sw")
      +    assert(sw.elapsed() === 0L)
      +    assert(!sw.isRunning)
      +    intercept[AssertionError] {
      +      sw.stop()
      +    }
      +    val duration = checkStopwatch(sw)
      +    val elapsed = sw.elapsed()
      +    assert(elapsed === duration)
      +    val duration2 = checkStopwatch(sw)
      +    val elapsed2 = sw.elapsed()
      +    assert(elapsed2 === duration + duration2)
      +    assert(sw.toString === s"sw: ${elapsed2}ms")
      +    sw.start()
      +    assert(sw.isRunning)
      +    intercept[AssertionError] {
      +      sw.start()
      +    }
      +  }
      +
      +  test("LocalStopwatch") {
      +    val sw = new LocalStopwatch("sw")
      +    testStopwatchOnDriver(sw)
      +  }
      +
      +  test("DistributedStopwatch on driver") {
      +    val sw = new DistributedStopwatch(sc, "sw")
      +    testStopwatchOnDriver(sw)
      +  }
      +
      +  test("DistributedStopwatch on executors") {
      +    val sw = new DistributedStopwatch(sc, "sw")
      +    val rdd = sc.parallelize(0 until 4, 4)
      +    val acc = sc.accumulator(0L)
      +    rdd.foreach { i =>
      +      acc += checkStopwatch(sw)
      +    }
      +    assert(!sw.isRunning)
      +    val elapsed = sw.elapsed()
      +    assert(elapsed === acc.value)
      +  }
      +
      +  test("MultiStopwatch") {
      +    val sw = new MultiStopwatch(sc)
      +      .addLocal("local")
      +      .addDistributed("spark")
      +    assert(sw("local").name === "local")
      +    assert(sw("spark").name === "spark")
      +    intercept[NoSuchElementException] {
      +      sw("some")
      +    }
      +    assert(sw.toString === "{\n  local: 0ms,\n  spark: 0ms\n}")
      +    val localDuration = checkStopwatch(sw("local"))
      +    val sparkDuration = checkStopwatch(sw("spark"))
      +    val localElapsed = sw("local").elapsed()
      +    val sparkElapsed = sw("spark").elapsed()
      +    assert(localElapsed === localDuration)
      +    assert(sparkElapsed === sparkDuration)
      +    assert(sw.toString ===
      +      s"{\n  local: ${localElapsed}ms,\n  spark: ${sparkElapsed}ms\n}")
      +    val rdd = sc.parallelize(0 until 4, 4)
      +    val acc = sc.accumulator(0L)
      +    rdd.foreach { i =>
      +      sw("local").start()
      +      val duration = checkStopwatch(sw("spark"))
      +      sw("local").stop()
      +      acc += duration
      +    }
      +    val localElapsed2 = sw("local").elapsed()
      +    assert(localElapsed2 === localElapsed)
      +    val sparkElapsed2 = sw("spark").elapsed()
      +    assert(sparkElapsed2 === sparkElapsed + acc.value)
      +  }
      +}
      +
      +private object StopwatchSuite extends SparkFunSuite {
      +
      +  /**
      +   * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and
      +   * returns the duration reported by the stopwatch.
      +   */
      +  def checkStopwatch(sw: Stopwatch): Long = {
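       +    // lb is measured strictly inside the stopwatch's start/stop interval and ub strictly
       +    // outside it, so the reported duration must satisfy lb <= duration <= ub.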
      +    val ubStart = now
      +    sw.start()
      +    val lbStart = now
      +    Thread.sleep(new Random().nextInt(10))
      +    val lb = now - lbStart
      +    val duration = sw.stop()
      +    val ub = now - ubStart
      +    assert(duration >= lb && duration <= ub)
      +    duration
      +  }
      +
      +  /** The current time in milliseconds. */
      +  private def now: Long = System.currentTimeMillis()
      +}
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
      index e8f3d0c4db20..8d14bb657215 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.classification
       
      -import scala.collection.JavaConversions._
      +import scala.collection.JavaConverters._
       import scala.util.Random
       import scala.util.control.Breaks._
       
      @@ -38,7 +38,7 @@ object LogisticRegressionSuite {
           scale: Double,
           nPoints: Int,
           seed: Int): java.util.List[LabeledPoint] = {
      -    seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed))
      +    generateLogisticInput(offset, scale, nPoints, seed).asJava
         }
       
         // Generate input of the form Y = logistic(offset + scale*X)
      @@ -196,6 +196,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
             .setStepSize(10.0)
             .setRegParam(0.0)
             .setNumIterations(20)
      +      .setConvergenceTol(0.0005)
       
           val model = lr.run(testRDD)
       
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
      index f7fc8730606a..cffa1ab700f8 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
      @@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification
       
       import scala.util.Random
       
      -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
      +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
       import breeze.stats.distributions.{Multinomial => BrzMultinomial}
       
       import org.apache.spark.{SparkException, SparkFunSuite}
      -import org.apache.spark.mllib.linalg.Vectors
      +import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
      +import org.apache.spark.mllib.util.TestingUtils._
       import org.apache.spark.util.Utils
       
       object NaiveBayesSuite {
      @@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
       
           // Test prediction on Array.
           validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
      +
      +    // Test posteriors
      +    validationData.map(_.features).foreach { features =>
      +      val predicted = model.predictProbabilities(features).toArray
      +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
      +      val expected = expectedMultinomialProbabilities(model, features)
      +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
      +    }
      +  }
      +
      +  /**
      +   * @param model Multinomial Naive Bayes model
      +   * @param testData input to compute posterior probabilities for
      +   * @return posterior class probabilities (in order of labels) for input
      +   */
      +  private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = {
      +    val piVector = new BDV(model.pi)
      +    // model.theta is row-major; treat it as col-major representation of transpose, and transpose:
      +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
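       +    // model.pi and model.theta hold log priors and log conditional probabilities, so
       +    // pi_c + sum_j x_j * theta_cj is the unnormalized log-posterior for class c.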
      +    val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze)
      +    val classProbs = logClassProbs.toArray.map(math.exp)
      +    val classProbsSum = classProbs.sum
      +    classProbs.map(_ / classProbsSum)
         }
       
         test("Naive Bayes Bernoulli") {
      @@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
       
           // Test prediction on Array.
           validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
      +
      +    // Test posteriors
      +    validationData.map(_.features).foreach { features =>
      +      val predicted = model.predictProbabilities(features).toArray
      +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
      +      val expected = expectedBernoulliProbabilities(model, features)
      +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
      +    }
      +  }
      +
      +  /**
      +   * @param model Bernoulli Naive Bayes model
      +   * @param testData input to compute posterior probabilities for
      +   * @return posterior class probabilities (in order of labels) for input
      +   */
      +  private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = {
      +    val piVector = new BDV(model.pi)
      +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
      +    val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length,
      +      model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t
      +    val testBreeze = testData.toBreeze
      +    val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze
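       +    // Bernoulli NB log-posterior for class c:
       +    //   pi_c + sum_j (x_j * theta_cj + (1 - x_j) * log(1 - exp(theta_cj)))
       +    // where pi and theta hold log priors and log conditional probabilities.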
      +    val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze)
      +    val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze)
      +    val classProbs = logClassProbs.toArray.map(math.exp)
      +    val classProbsSum = classProbs.sum
      +    classProbs.map(_ / classProbsSum)
         }
       
         test("detect negative values") {
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
      index b1d78cba9e3d..ee3c85d09a46 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.classification
       
      -import scala.collection.JavaConversions._
      +import scala.collection.JavaConverters._
       import scala.util.Random
       
       import org.jblas.DoubleMatrix
      @@ -35,7 +35,7 @@ object SVMSuite {
           weights: Array[Double],
           nPoints: Int,
           seed: Int): java.util.List[LabeledPoint] = {
      -    seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed))
      +    generateSVMInput(intercept, weights, nPoints, seed).asJava
         }
       
         // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
      index fd653296c9d9..d7b291d5a633 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
      @@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.util.TestingUtils._
       import org.apache.spark.streaming.dstream.DStream
      -import org.apache.spark.streaming.TestSuiteBase
      +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
       
       class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
       
         // use longer wait time to ensure job completion
         override def maxWaitTimeMillis: Int = 30000
       
      +  var ssc: StreamingContext = _
      +
      +  override def afterFunction() {
      +    super.afterFunction()
      +    if (ssc != null) {
      +      ssc.stop()
      +    }
      +  }
      +
         // Test if we can accurately learn B for Y = logistic(BX) on streaming data
         test("parameter accuracy") {
       
      @@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
           }
       
           // apply model training to input stream
      -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             inputDStream.count()
           })
      @@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
       
           // apply model training to input stream, storing the intermediate results
           // (we add a count to ensure the result is a DStream)
      -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B)))
             inputDStream.count()
      @@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
           }
       
           // apply model predictions to test stream
      -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
             model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
           })
       
      @@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
           }
       
           // train and predict
      -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
           })
      @@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
             .setNumIterations(10)
           val numBatches = 10
           val emptyInput = Seq.empty[Seq[LabeledPoint]]
      -    val ssc = setupStreams(emptyInput,
      +    ssc = setupStreams(emptyInput,
             (inputDStream: DStream[LabeledPoint]) => {
               model.trainOn(inputDStream)
               model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
      index b218d72f1268..a72723eb00da 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
      @@ -18,7 +18,7 @@
       package org.apache.spark.mllib.clustering
       
       import org.apache.spark.SparkFunSuite
      -import org.apache.spark.mllib.linalg.{Vectors, Matrices}
      +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrices}
       import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.mllib.util.TestingUtils._
      @@ -76,6 +76,20 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
         }
       
      +  test("two clusters with distributed decompositions") {
      +    val data = sc.parallelize(GaussianTestData.data2, 2)
      +
      +    val k = 5
      +    val d = data.first().size
      +    assert(GaussianMixture.shouldDistributeGaussians(k, d))
      +
      +    val gmm = new GaussianMixture()
      +      .setK(k)
      +      .run(data)
      +
      +    assert(gmm.k === k)
      +  }
      +
         test("single cluster with sparse data") {
           val data = sc.parallelize(Array(
             Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)),
      @@ -116,7 +130,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
           val sparseGMM = new GaussianMixture()
             .setK(2)
             .setInitialModel(initialGmm)
      -      .run(data)
      +      .run(sparseData)
       
           assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3)
           assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3)
      @@ -148,6 +162,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
           }
         }
       
      +  test("model prediction, parallel and local") {
      +    val data = sc.parallelize(GaussianTestData.data)
      +    val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
      +
      +    val batchPredictions = gmm.predict(data)
      +    batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
      +      assert(batchPred === gmm.predict(datum))
      +    }
      +  }
      +
         object GaussianTestData {
       
           val data = Array(
      @@ -158,5 +182,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
             Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
           )
       
      +    val data2: Array[Vector] = Array.tabulate(25){ i: Int =>
      +      Vectors.dense(Array.tabulate(50)(i + _.toDouble))
      +    }
      +
         }
       }
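The new "model prediction, parallel and local" test checks that RDD-level and per-vector predictions agree. A minimal standalone sketch of that check, assuming an existing SparkContext named sc and an illustrative one-dimensional dataset:

import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

val points = sc.parallelize(Seq(
  Vectors.dense(-5.0), Vectors.dense(-4.9), Vectors.dense(5.0), Vectors.dense(5.1)))
val gmm = new GaussianMixture().setK(2).setSeed(0).run(points)

// Predict cluster memberships for the whole RDD, then compare against
// predicting each point locally through the same model.
gmm.predict(points).zip(points).collect().foreach { case (clusterId, point) =>
  assert(clusterId == gmm.predict(point))
}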
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
      index 0dbbd7127444..3003c62d9876 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
      @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
             }
           }
         }
      +
      +  test("Initialize using given cluster centers") {
      +    val points = Seq(
      +      Vectors.dense(0.0, 0.0),
      +      Vectors.dense(1.0, 0.0),
      +      Vectors.dense(0.0, 1.0),
      +      Vectors.dense(1.0, 1.0)
      +    )
      +    val rdd = sc.parallelize(points, 3)
      +    // creating an initial model
      +    val initialModel = new KMeansModel(Array(points(0), points(2)))
      +
      +    val returnModel = new KMeans()
      +      .setK(2)
      +      .setMaxIterations(0)
      +      .setInitialModel(initialModel)
      +      .run(rdd)
+    // comparing the returned model and the initial model
      +    assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
      +    assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
      +  }
      +
       }
       
       object KMeansSuite extends SparkFunSuite {
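A minimal sketch of the initial-model feature the new k-means test exercises, assuming an existing SparkContext named sc: with setMaxIterations(0) the returned centers are the supplied initial centers, unchanged.

import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 0.0),
  Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 1.0)))

// Seed k-means with explicit centers instead of random / k-means|| initialization.
val initialModel = new KMeansModel(Array(Vectors.dense(0.0, 0.0), Vectors.dense(0.0, 1.0)))
val model = new KMeans()
  .setK(2)
  .setMaxIterations(0)
  .setInitialModel(initialModel)
  .run(data)

assert(model.clusterCenters.toSeq == initialModel.clusterCenters.toSeq)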
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
      index 406affa25539..37fb69d68f6b 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
      @@ -17,19 +17,24 @@
       
       package org.apache.spark.mllib.clustering
       
      -import breeze.linalg.{DenseMatrix => BDM}
      +import java.util.{ArrayList => JArrayList}
      +
      +import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax}
       
       import org.apache.spark.SparkFunSuite
      -import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
      +import org.apache.spark.graphx.Edge
      +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       import org.apache.spark.mllib.util.TestingUtils._
      +import org.apache.spark.util.Utils
       
       class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       
         import LDASuite._
       
         test("LocalLDAModel") {
      -    val model = new LocalLDAModel(tinyTopics)
      +    val model = new LocalLDAModel(tinyTopics,
      +      Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
       
           // Check: basic parameters
           assert(model.k === tinyK)
      @@ -63,6 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
           // Train a model
           val lda = new LDA()
           lda.setK(k)
      +      .setOptimizer(new EMLDAOptimizer)
             .setDocConcentration(topicSmoothing)
             .setTopicConcentration(termSmoothing)
             .setMaxIterations(5)
      @@ -80,37 +86,84 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(model.topicsMatrix === localModel.topicsMatrix)
       
           // Check: topic summaries
      -    //  The odd decimal formatting and sorting is a hack to do a robust comparison.
      -    val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
      -      // cut values to 3 digits after the decimal place
      -      terms.zip(termWeights).map { case (term, weight) =>
      -        ("%.3f".format(weight).toDouble, term.toInt)
      -      }
      -    }.sortBy(_.mkString(""))
      -    val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
      -      // cut values to 3 digits after the decimal place
      -      terms.zip(termWeights).map { case (term, weight) =>
      -        ("%.3f".format(weight).toDouble, term.toInt)
      -      }
      -    }.sortBy(_.mkString(""))
      -    roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
      -      assert(t1 === t2)
      +    val topicSummary = model.describeTopics().map { case (terms, termWeights) =>
      +      Vectors.sparse(tinyVocabSize, terms, termWeights)
      +    }.sortBy(_.toString)
      +    val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
      +      Vectors.sparse(tinyVocabSize, terms, termWeights)
      +    }.sortBy(_.toString)
      +    topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) =>
      +      assert(topics ~== topicsLocal absTol 0.01)
           }
       
           // Check: per-doc topic distributions
           val topicDistributions = model.topicDistributions.collect()
      +
           //  Ensure all documents are covered.
      -    assert(topicDistributions.length === tinyCorpus.length)
      -    assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
+    // SPARK-5562: topicDistributions returns the distribution over topics only for the
+    // non-empty docs, so compare it against nonEmptyTinyCorpus instead of tinyCorpus.
      +    val nonEmptyTinyCorpus = getNonEmptyDoc(tinyCorpus)
      +    assert(topicDistributions.length === nonEmptyTinyCorpus.length)
      +    assert(nonEmptyTinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
           //  Ensure we have proper distributions
           topicDistributions.foreach { case (docId, topicDistribution) =>
             assert(topicDistribution.size === tinyK)
             assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
           }
       
      +    val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3)))
      +    model.topicDistributions.join(top2TopicsPerDoc).collect().foreach {
      +      case (docId, (topicDistribution, (indices, weights))) =>
      +        assert(indices.length == 2)
      +        assert(weights.length == 2)
      +        val bdvTopicDist = topicDistribution.toBreeze
      +        val top2Indices = argtopk(bdvTopicDist, 2)
      +        assert(top2Indices.toArray === indices)
      +        assert(bdvTopicDist(top2Indices).toArray === weights)
      +    }
      +
           // Check: log probabilities
           assert(model.logLikelihood < 0.0)
           assert(model.logPrior < 0.0)
      +
      +    // Check: topDocumentsPerTopic
      +    // Compare it with top documents per topic derived from topicDistributions
      +    val topDocsByTopicDistributions = { n: Int =>
      +      Range(0, k).map { topic =>
      +        val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip
      +        (doc.toArray, docWeights.map(_(topic)).toArray)
      +      }.toArray
      +    }
      +
      +    // Top 3 documents per topic
      +    model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) =>
      +      assert(t1._1 === t2._1)
      +      assert(t1._2 === t2._2)
      +    }
      +
      +    // All documents per topic
      +    val q = tinyCorpus.length
      +    model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) =>
      +      assert(t1._1 === t2._1)
      +      assert(t1._2 === t2._2)
      +    }
      +
      +    // Check: topTopicAssignments
      +    // Make sure it assigns a topic to each term appearing in each doc.
      +    val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] =
      +      model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap
      +    assert(topTopicAssignments.keys.max < tinyCorpus.length)
      +    tinyCorpus.foreach { case (docID: Long, doc: Vector) =>
      +      if (topTopicAssignments.contains(docID)) {
      +        val (inds, vals) = topTopicAssignments(docID)
      +        assert(inds.length === doc.numNonzeros)
+        // For each term that appears in the actual doc,
+        // check that it has a topic assigned.
      +        doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term)))
      +      } else {
      +        assert(doc.numNonzeros === 0)
      +      }
      +    }
         }
       
         test("vertex indexing") {
      @@ -127,22 +180,38 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       
         test("setter alias") {
           val lda = new LDA().setAlpha(2.0).setBeta(3.0)
      -    assert(lda.getAlpha === 2.0)
      -    assert(lda.getDocConcentration === 2.0)
      +    assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0))
      +    assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0))
           assert(lda.getBeta === 3.0)
           assert(lda.getTopicConcentration === 3.0)
         }
       
      +  test("initializing with alpha length != k or 1 fails") {
      +    intercept[IllegalArgumentException] {
      +      val lda = new LDA().setK(2).setAlpha(Vectors.dense(1, 2, 3, 4))
      +      val corpus = sc.parallelize(tinyCorpus, 2)
      +      lda.run(corpus)
      +    }
      +  }
      +
      +  test("initializing with elements in alpha < 0 fails") {
      +    intercept[IllegalArgumentException] {
      +      val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4))
      +      val corpus = sc.parallelize(tinyCorpus, 2)
      +      lda.run(corpus)
      +    }
      +  }
      +
         test("OnlineLDAOptimizer initialization") {
           val lda = new LDA().setK(2)
           val corpus = sc.parallelize(tinyCorpus, 2)
           val op = new OnlineLDAOptimizer().initialize(corpus, lda)
           op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567)
      -    assert(op.getAlpha == 0.5) // default 1.0 / k
      -    assert(op.getEta == 0.5)   // default 1.0 / k
      -    assert(op.getKappa == 0.9876)
      -    assert(op.getMiniBatchFraction == 0.123)
      -    assert(op.getTau0 == 567)
      +    assert(op.getAlpha.toArray.forall(_ === 0.5)) // default 1.0 / k
      +    assert(op.getEta === 0.5)   // default 1.0 / k
      +    assert(op.getKappa === 0.9876)
      +    assert(op.getMiniBatchFraction === 0.123)
      +    assert(op.getTau0 === 567)
         }
       
         test("OnlineLDAOptimizer one iteration") {
      @@ -174,23 +243,16 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       
           // verify the result, Note this generate the identical result as
           // [[https://github.com/Blei-Lab/onlineldavb]]
      -    val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
      -    val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
      -    assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
      -    assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
      +    val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t)
      +    val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t)
      +    val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950)
      +    val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050)
      +    assert(topic1 ~== expectedTopic1 absTol 0.01)
      +    assert(topic2 ~== expectedTopic2 absTol 0.01)
         }
       
         test("OnlineLDAOptimizer with toy data") {
      -    def toydata: Array[(Long, Vector)] = Array(
      -      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
      -      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
      -      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
      -      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
      -      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
      -      Vectors.sparse(6, Array(4, 5), Array(1, 1))
      -    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
      -
      -    val docs = sc.parallelize(toydata)
      +    val docs = sc.parallelize(toyData)
           val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
             .setGammaShape(1e10)
           val lda = new LDA().setK(2)
      @@ -213,6 +275,284 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
           }
         }
       
      +  test("LocalLDAModel logLikelihood") {
      +    val ldaModel: LocalLDAModel = toyModel
      +
      +    val docsSingleWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(1)))
      +      .zipWithIndex
      +      .map { case (wordCounts, docId) => (docId.toLong, wordCounts) })
      +    val docsRepeatedWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(5)))
      +      .zipWithIndex
      +      .map { case (wordCounts, docId) => (docId.toLong, wordCounts) })
      +
      +    /* Verify results using gensim:
      +       import numpy as np
      +       from gensim import models
      +       corpus = [
      +          [(0, 1.0), (1, 1.0)],
      +          [(1, 1.0), (2, 1.0)],
      +          [(0, 1.0), (2, 1.0)],
      +          [(3, 1.0), (4, 1.0)],
      +          [(3, 1.0), (5, 1.0)],
      +          [(4, 1.0), (5, 1.0)]]
      +       np.random.seed(2345)
      +       lda = models.ldamodel.LdaModel(
      +          corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
      +          decay=0.51, offset=1024)
      +       docsSingleWord = [[(0, 1.0)]]
      +       docsRepeatedWord = [[(0, 5.0)]]
      +       print(lda.bound(docsSingleWord))
      +       > -25.9706969833
      +       print(lda.bound(docsRepeatedWord))
      +       > -31.4413908227
      +     */
      +
      +    assert(ldaModel.logLikelihood(docsSingleWord) ~== -25.971 relTol 1E-3D)
+    assert(ldaModel.logLikelihood(docsRepeatedWord) ~== -31.441 relTol 1E-3D)
      +  }
      +
      +  test("LocalLDAModel logPerplexity") {
      +    val docs = sc.parallelize(toyData)
      +    val ldaModel: LocalLDAModel = toyModel
      +
      +    /* Verify results using gensim:
      +       import numpy as np
      +       from gensim import models
      +       corpus = [
      +          [(0, 1.0), (1, 1.0)],
      +          [(1, 1.0), (2, 1.0)],
      +          [(0, 1.0), (2, 1.0)],
      +          [(3, 1.0), (4, 1.0)],
      +          [(3, 1.0), (5, 1.0)],
      +          [(4, 1.0), (5, 1.0)]]
      +       np.random.seed(2345)
      +       lda = models.ldamodel.LdaModel(
      +          corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
      +          decay=0.51, offset=1024)
      +       print(lda.log_perplexity(corpus))
      +       > -3.69051285096
      +     */
      +
+    // Gensim's definition of perplexity is the negative of our (and Stanford NLP's) definition
      +    assert(ldaModel.logPerplexity(docs) ~== 3.690D relTol 1E-3D)
      +  }
      +
      +  test("LocalLDAModel predict") {
      +    val docs = sc.parallelize(toyData)
      +    val ldaModel: LocalLDAModel = toyModel
      +
      +    /* Verify results using gensim:
      +       import numpy as np
      +       from gensim import models
      +       corpus = [
      +          [(0, 1.0), (1, 1.0)],
      +          [(1, 1.0), (2, 1.0)],
      +          [(0, 1.0), (2, 1.0)],
      +          [(3, 1.0), (4, 1.0)],
      +          [(3, 1.0), (5, 1.0)],
      +          [(4, 1.0), (5, 1.0)]]
      +       np.random.seed(2345)
      +       lda = models.ldamodel.LdaModel(
      +          corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
      +          decay=0.51, offset=1024)
      +       print(list(lda.get_document_topics(corpus)))
      +       > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)],
      +       > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)],
      +       > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]]
      +     */
      +
      +    val expectedPredictions = List(
      +      (0, 0.99504), (0, 0.99504),
      +      (0, 0.99504), (1, 0.99504),
      +      (1, 0.99504), (1, 0.99504))
      +
      +    val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) =>
+        // convert results to the expectedPredictions format, which keeps only the
+        // highest-probability topic
      +        val topicsBz = topics.toBreeze.toDenseVector
      +        (id, (argmax(topicsBz), max(topicsBz)))
      +      }.sortByKey()
      +      .values
      +      .collect()
      +
+    expectedPredictions.zip(actualPredictions).foreach { case (expected, actual) =>
+      // assert each pair (rather than silently discarding a forall result)
+      // so mismatches actually fail the test
+      assert(expected._1 === actual._1)
+      assert(expected._2 ~== actual._2 relTol 1E-3D)
+    }
      +  }
      +
      +  test("OnlineLDAOptimizer with asymmetric prior") {
      +    val docs = sc.parallelize(toyData)
      +    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
      +      .setGammaShape(1e10)
      +    val lda = new LDA().setK(2)
      +      .setDocConcentration(Vectors.dense(0.00001, 0.1))
      +      .setTopicConcentration(0.01)
      +      .setMaxIterations(100)
      +      .setOptimizer(op)
      +      .setSeed(12345)
      +
      +    val ldaModel = lda.run(docs)
      +    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
      +    val topics = topicIndices.map { case (terms, termWeights) =>
      +      terms.zip(termWeights)
      +    }
      +
      +    /* Verify results with Python:
      +
      +       import numpy as np
      +       from gensim import models
      +       corpus = [
      +           [(0, 1.0), (1, 1.0)],
      +           [(1, 1.0), (2, 1.0)],
      +           [(0, 1.0), (2, 1.0)],
      +           [(3, 1.0), (4, 1.0)],
      +           [(3, 1.0), (5, 1.0)],
      +           [(4, 1.0), (5, 1.0)]]
      +       np.random.seed(10)
      +       lda = models.ldamodel.LdaModel(
      +           corpus=corpus, alpha=np.array([0.00001, 0.1]), num_topics=2, update_every=0, passes=100)
      +       lda.print_topics()
      +
      +       > ['0.167*0 + 0.167*1 + 0.167*2 + 0.167*3 + 0.167*4 + 0.167*5',
      +          '0.167*0 + 0.167*1 + 0.167*2 + 0.167*4 + 0.167*3 + 0.167*5']
      +     */
      +    topics.foreach { topic =>
      +      assert(topic.forall { case (_, p) => p ~= 0.167 absTol 0.05 })
      +    }
      +  }
      +
      +  test("OnlineLDAOptimizer alpha hyperparameter optimization") {
      +    val k = 2
      +    val docs = sc.parallelize(toyData)
      +    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
      +      .setGammaShape(100).setOptimizeDocConcentration(true).setSampleWithReplacement(false)
      +    val lda = new LDA().setK(k)
      +      .setDocConcentration(1D / k)
      +      .setTopicConcentration(0.01)
      +      .setMaxIterations(100)
      +      .setOptimizer(op)
      +      .setSeed(12345)
      +    val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel]
      +
      +    /* Verify the results with gensim:
      +      import numpy as np
      +      from gensim import models
      +      corpus = [
      +       [(0, 1.0), (1, 1.0)],
      +       [(1, 1.0), (2, 1.0)],
      +       [(0, 1.0), (2, 1.0)],
      +       [(3, 1.0), (4, 1.0)],
      +       [(3, 1.0), (5, 1.0)],
      +       [(4, 1.0), (5, 1.0)]]
      +      np.random.seed(2345)
      +      lda = models.ldamodel.LdaModel(
      +         corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100,
      +         decay=0.51, offset=1024)
      +      print(lda.alpha)
      +      > [ 0.42582646  0.43511073]
      +     */
      +
      +    assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05)
      +  }
      +
      +  test("model save/load") {
      +    // Test for LocalLDAModel.
      +    val localModel = new LocalLDAModel(tinyTopics,
      +      Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
      +    val tempDir1 = Utils.createTempDir()
      +    val path1 = tempDir1.toURI.toString
      +
      +    // Test for DistributedLDAModel.
      +    val k = 3
      +    val docConcentration = 1.2
      +    val topicConcentration = 1.5
      +    val lda = new LDA()
      +    lda.setK(k)
      +      .setDocConcentration(docConcentration)
      +      .setTopicConcentration(topicConcentration)
      +      .setMaxIterations(5)
      +      .setSeed(12345)
      +    val corpus = sc.parallelize(tinyCorpus, 2)
      +    val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
      +    val tempDir2 = Utils.createTempDir()
      +    val path2 = tempDir2.toURI.toString
      +
      +    try {
      +      localModel.save(sc, path1)
      +      distributedModel.save(sc, path2)
      +      val samelocalModel = LocalLDAModel.load(sc, path1)
      +      assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
      +      assert(samelocalModel.k === localModel.k)
      +      assert(samelocalModel.vocabSize === localModel.vocabSize)
      +      assert(samelocalModel.docConcentration === localModel.docConcentration)
      +      assert(samelocalModel.topicConcentration === localModel.topicConcentration)
      +      assert(samelocalModel.gammaShape === localModel.gammaShape)
      +
      +      val sameDistributedModel = DistributedLDAModel.load(sc, path2)
      +      assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
      +      assert(distributedModel.k === sameDistributedModel.k)
      +      assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
      +      assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
      +      assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
      +      assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
      +      assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
      +      assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
      +
      +      val graph = distributedModel.graph
      +      val sameGraph = sameDistributedModel.graph
      +      assert(graph.vertices.sortByKey().collect() === sameGraph.vertices.sortByKey().collect())
      +      val edge = graph.edges.map {
      +        case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos)
      +      }.sortBy(x => (x._1, x._2)).collect()
      +      val sameEdge = sameGraph.edges.map {
      +        case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos)
      +      }.sortBy(x => (x._1, x._2)).collect()
      +      assert(edge === sameEdge)
      +    } finally {
      +      Utils.deleteRecursively(tempDir1)
      +      Utils.deleteRecursively(tempDir2)
      +    }
      +  }
      +
      +  test("EMLDAOptimizer with empty docs") {
      +    val vocabSize = 6
      +    val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty))
      +    val emptyDocs = emptyDocsArray
      +      .zipWithIndex.map { case (wordCounts, docId) =>
      +        (docId.toLong, wordCounts)
      +    }
      +    val distributedEmptyDocs = sc.parallelize(emptyDocs, 2)
      +
      +    val op = new EMLDAOptimizer()
      +    val lda = new LDA()
      +      .setK(3)
      +      .setMaxIterations(5)
      +      .setSeed(12345)
      +      .setOptimizer(op)
      +
      +    val model = lda.run(distributedEmptyDocs)
      +    assert(model.vocabSize === vocabSize)
      +  }
      +
      +  test("OnlineLDAOptimizer with empty docs") {
      +    val vocabSize = 6
      +    val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty))
      +    val emptyDocs = emptyDocsArray
      +      .zipWithIndex.map { case (wordCounts, docId) =>
      +        (docId.toLong, wordCounts)
      +    }
      +    val distributedEmptyDocs = sc.parallelize(emptyDocs, 2)
      +
      +    val op = new OnlineLDAOptimizer()
      +    val lda = new LDA()
      +      .setK(3)
      +      .setMaxIterations(5)
      +      .setSeed(12345)
      +      .setOptimizer(op)
      +
      +    val model = lda.run(distributedEmptyDocs)
      +    assert(model.vocabSize === vocabSize)
      +  }
      +
       }
       
       private[clustering] object LDASuite {
      @@ -232,12 +572,51 @@ private[clustering] object LDASuite {
         }
       
         def tinyCorpus: Array[(Long, Vector)] = Array(
      +    Vectors.dense(0, 0, 0, 0, 0), // empty doc
           Vectors.dense(1, 3, 0, 2, 8),
           Vectors.dense(0, 2, 1, 0, 4),
           Vectors.dense(2, 3, 12, 3, 1),
      +    Vectors.dense(0, 0, 0, 0, 0), // empty doc
           Vectors.dense(0, 3, 1, 9, 8),
           Vectors.dense(1, 1, 4, 2, 6)
         ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
         assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data
       
      +  def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter {
      +    case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0
      +  }
      +
      +  def toyData: Array[(Long, Vector)] = Array(
      +    Vectors.sparse(6, Array(0, 1), Array(1, 1)),
      +    Vectors.sparse(6, Array(1, 2), Array(1, 1)),
      +    Vectors.sparse(6, Array(0, 2), Array(1, 1)),
      +    Vectors.sparse(6, Array(3, 4), Array(1, 1)),
      +    Vectors.sparse(6, Array(3, 5), Array(1, 1)),
      +    Vectors.sparse(6, Array(4, 5), Array(1, 1))
      +  ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
      +
      +  /** Used in the Java Test Suite */
      +  def javaToyData: JArrayList[(java.lang.Long, Vector)] = {
      +    val javaData = new JArrayList[(java.lang.Long, Vector)]
      +    var i = 0
      +    while (i < toyData.length) {
      +      javaData.add((toyData(i)._1, toyData(i)._2))
      +      i += 1
      +    }
      +    javaData
      +  }
      +
      +  def toyModel: LocalLDAModel = {
      +    val k = 2
      +    val vocabSize = 6
      +    val alpha = 0.01
      +    val eta = 0.01
      +    val gammaShape = 100
      +    val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
      +      1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
      +      0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
      +    val ldaModel: LocalLDAModel = new LocalLDAModel(
      +      topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
      +    ldaModel
      +  }
       }
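Several of the new LDA tests score documents with the variational bounds on a LocalLDAModel. A minimal end-to-end sketch, assuming an existing SparkContext named sc and an illustrative (docId, termCountVector) corpus shaped like toyData above; the online optimizer returns a LocalLDAModel directly:

import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vectors

val docs = sc.parallelize(Seq(
  (0L, Vectors.sparse(6, Array(0, 1), Array(1.0, 1.0))),
  (1L, Vectors.sparse(6, Array(1, 2), Array(1.0, 1.0))),
  (2L, Vectors.sparse(6, Array(3, 4), Array(1.0, 1.0))),
  (3L, Vectors.sparse(6, Array(4, 5), Array(1.0, 1.0)))))

val model = new LDA()
  .setK(2)
  .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(1.0))
  .setMaxIterations(50)
  .setSeed(12345)
  .run(docs)
  .asInstanceOf[LocalLDAModel]

// Lower bounds on the log likelihood of the corpus and on per-token perplexity.
println(model.logLikelihood(docs))
println(model.logPerplexity(docs))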
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
      index 19e65f1b53ab..189000512155 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
      @@ -68,6 +68,54 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon
           assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
         }
       
      +  test("power iteration clustering on graph") {
      +    /*
      +     We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for
      +     edge (3, 4).
      +
      +     15-14 -13 -12
      +     |           |
      +     4 . 3 - 2  11
      +     |   | x |   |
      +     5   0 - 1  10
      +     |           |
      +     6 - 7 - 8 - 9
      +     */
      +
      +    val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0),
      +      (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge
      +      (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0),
      +      (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0))
      +
      +    val edges = similarities.flatMap { case (i, j, s) =>
      +      if (i != j) {
      +        Seq(Edge(i, j, s), Edge(j, i, s))
      +      } else {
      +        None
      +      }
      +    }
      +    val graph = Graph.fromEdges(sc.parallelize(edges, 2), 0.0)
      +
      +    val model = new PowerIterationClustering()
      +      .setK(2)
      +      .run(graph)
      +    val predictions = Array.fill(2)(mutable.Set.empty[Long])
      +    model.assignments.collect().foreach { a =>
      +      predictions(a.cluster) += a.id
      +    }
      +    assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
      +
      +    val model2 = new PowerIterationClustering()
      +      .setK(2)
      +      .setInitializationMode("degree")
      +      .run(sc.parallelize(similarities, 2))
      +    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
      +    model2.assignments.collect().foreach { a =>
      +      predictions2(a.cluster) += a.id
      +    }
      +    assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
      +  }
      +
         test("normalize and powerIter") {
           /*
            Test normalize() with the following graph:
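The new test runs PIC directly on a GraphX graph rather than on an RDD of similarity triplets. A minimal sketch of that entry point, assuming an existing SparkContext named sc and a small symmetric similarity graph (the edge list here is illustrative):

import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.clustering.PowerIterationClustering

val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0), (3, 4, 0.1))
// Mirror each edge so the similarity graph is symmetric.
val edges = similarities.flatMap { case (i, j, s) => Seq(Edge(i, j, s), Edge(j, i, s)) }
val graph = Graph.fromEdges(sc.parallelize(edges), 0.0)

val model = new PowerIterationClustering()
  .setK(2)
  .run(graph)
model.assignments.collect().foreach(a => println(s"vertex ${a.id} -> cluster ${a.cluster}"))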
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
      index ac01622b8a08..3645d29dccdb 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
      @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.mllib.linalg.{Vector, Vectors}
       import org.apache.spark.mllib.util.TestingUtils._
      -import org.apache.spark.streaming.TestSuiteBase
      +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
       import org.apache.spark.streaming.dstream.DStream
       import org.apache.spark.util.random.XORShiftRandom
       
      @@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
       
         override def maxWaitTimeMillis: Int = 30000
       
      +  var ssc: StreamingContext = _
      +
      +  override def afterFunction() {
      +    super.afterFunction()
      +    if (ssc != null) {
      +      ssc.stop()
      +    }
      +  }
      +
         test("accuracy for single center and equivalence to grand average") {
           // set parameters
           val numBatches = 10
      @@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
           val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
       
           // setup and run the model training
      -    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
      +    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
             model.trainOn(inputDStream)
             inputDStream.count()
           })
      @@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
           val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
       
           // setup and run the model training
      -    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
      +    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
             kMeans.trainOn(inputDStream)
             inputDStream.count()
           })
      @@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
             StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
       
           // setup and run the model training
      -    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
      +    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
             kMeans.trainOn(inputDStream)
             inputDStream.count()
           })
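Outside TestSuiteBase the same pieces are wired up by hand. A minimal sketch, assuming an existing SparkContext named sc, of feeding a StreamingKMeans model from a queue-backed DStream (the queue contents are illustrative):

import scala.collection.mutable
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext}

val ssc = new StreamingContext(sc, Seconds(1))
val queue = new mutable.Queue[RDD[Vector]]()
queue += sc.parallelize(Seq(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0)))
val trainingStream = ssc.queueStream(queue)

val model = new StreamingKMeans()
  .setK(2)
  .setDecayFactor(1.0)
  .setRandomCenters(2, 0.0, 42L) // dimension, initial center weight, seed

model.trainOn(trainingStream)
// ssc.start(); ssc.awaitTerminationOrTimeout(...) when running outside the test harness
println(model.latestModel().clusterCenters.mkString(", "))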
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
      index 9de2bdb6d724..4b7f1be58f99 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
      @@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._
       
       class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
       
      -  test("regression metrics") {
      +  test("regression metrics for unbiased (includes intercept term) predictor") {
      +    /* Verify results in R:
      +       preds = c(2.25, -0.25, 1.75, 7.75)
      +       obs = c(3.0, -0.5, 2.0, 7.0)
      +
      +       SStot = sum((obs - mean(obs))^2)
      +       SSreg = sum((preds - mean(obs))^2)
      +       SSerr = sum((obs - preds)^2)
      +
      +       explainedVariance = SSreg / length(obs)
      +       explainedVariance
      +       > [1] 8.796875
      +       meanAbsoluteError = mean(abs(preds - obs))
      +       meanAbsoluteError
      +       > [1] 0.5
      +       meanSquaredError = mean((preds - obs)^2)
      +       meanSquaredError
      +       > [1] 0.3125
      +       rmse = sqrt(meanSquaredError)
      +       rmse
      +       > [1] 0.559017
      +       r2 = 1 - SSerr / SStot
      +       r2
      +       > [1] 0.9571734
      +     */
      +    val predictionAndObservations = sc.parallelize(
      +      Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
      +    val metrics = new RegressionMetrics(predictionAndObservations)
      +    assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
      +      "explained variance regression score mismatch")
      +    assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
      +    assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
      +    assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
      +      "root mean squared error mismatch")
      +    assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
      +  }
      +
      +  test("regression metrics for biased (no intercept term) predictor") {
      +    /* Verify results in R:
      +       preds = c(2.5, 0.0, 2.0, 8.0)
      +       obs = c(3.0, -0.5, 2.0, 7.0)
      +
      +       SStot = sum((obs - mean(obs))^2)
      +       SSreg = sum((preds - mean(obs))^2)
      +       SSerr = sum((obs - preds)^2)
      +
      +       explainedVariance = SSreg / length(obs)
      +       explainedVariance
      +       > [1] 8.859375
      +       meanAbsoluteError = mean(abs(preds - obs))
      +       meanAbsoluteError
      +       > [1] 0.5
      +       meanSquaredError = mean((preds - obs)^2)
      +       meanSquaredError
      +       > [1] 0.375
      +       rmse = sqrt(meanSquaredError)
      +       rmse
      +       > [1] 0.6123724
      +       r2 = 1 - SSerr / SStot
      +       r2
      +       > [1] 0.9486081
      +     */
           val predictionAndObservations = sc.parallelize(
             Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
           val metrics = new RegressionMetrics(predictionAndObservations)
      -    assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
      +    assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
             "explained variance regression score mismatch")
           assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
           assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
           assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
             "root mean squared error mismatch")
      -    assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
      +    assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
         }
       
         test("regression metrics with complete fitting") {
           val predictionAndObservations = sc.parallelize(
             Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
           val metrics = new RegressionMetrics(predictionAndObservations)
      -    assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
      +    assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
             "explained variance regression score mismatch")
           assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
           assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
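The R snippets above pin down the definitions under test: explainedVariance = SSreg / n and r2 = 1 - SSerr / SStot. A minimal Scala sketch reproducing the unbiased-predictor numbers by hand, assuming an existing SparkContext named sc:

import org.apache.spark.mllib.evaluation.RegressionMetrics

val predictionAndObservations = sc.parallelize(
  Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)

val (preds, obs) = predictionAndObservations.collect().toSeq.unzip
val meanObs = obs.sum / obs.length                                        // 2.875
val ssTot = obs.map(o => math.pow(o - meanObs, 2)).sum                    // 29.1875
val ssReg = preds.map(p => math.pow(p - meanObs, 2)).sum                  // 35.1875
val ssErr = preds.zip(obs).map { case (p, o) => math.pow(o - p, 2) }.sum  // 1.25

assert(math.abs(metrics.explainedVariance - ssReg / obs.length) < 1e-5)   // 8.796875
assert(math.abs(metrics.r2 - (1.0 - ssErr / ssTot)) < 1e-5)               // ~0.9571734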
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
      index b6818369208d..a864eec460f2 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
      @@ -37,6 +37,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(syms.length == 2)
           assert(syms(0)._1 == "b")
           assert(syms(1)._1 == "c")
      +
+    // Test that a model built by Word2Vec (i.e. its wordVectors and wordIndex) and a
+    // Word2VecModel reconstructed from the word-to-vector map give the same values.
      +    val word2VecMap = model.getVectors
      +    val newModel = new Word2VecModel(word2VecMap)
      +    assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq))
      +  }
      +
      +  test("Word2Vec throws exception when vocabulary is empty") {
      +    intercept[IllegalArgumentException] {
      +      val sentence = "a b c"
      +      val localDoc = Seq(sentence, sentence)
      +      val doc = sc.parallelize(localDoc)
      +        .map(line => line.split(" ").toSeq)
      +      new Word2Vec().setMinCount(10).fit(doc)
      +    }
         }
       
         test("Word2VecModel") {
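The added check round-trips a model through its word-to-vector map. A minimal sketch of that round trip, assuming an existing SparkContext named sc and an illustrative toy corpus:

import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}

val doc = sc.parallelize(Seq("a b c", "a b c")).map(_.split(" ").toSeq)
val model = new Word2Vec().setVectorSize(10).setSeed(42L).setMinCount(1).fit(doc)

// Rebuild a model from the exported Map[String, Array[Float]] and compare vectors.
val word2VecMap = model.getVectors
val rebuilt = new Word2VecModel(word2VecMap)
assert(rebuilt.getVectors.mapValues(_.toSeq) == word2VecMap.mapValues(_.toSeq))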
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
      new file mode 100644
      index 000000000000..77a2773c36f5
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
      @@ -0,0 +1,89 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +package org.apache.spark.mllib.fpm
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +
      +class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("association rules using String type") {
      +    val freqItemsets = sc.parallelize(Seq(
      +      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      +      (Set("r"), 3L),
      +      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      +      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      +      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      +      (Set("t", "y", "x"), 3L),
      +      (Set("t", "y", "x", "z"), 3L)
      +    ).map {
      +      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
      +    })
      +
      +    val ar = new AssociationRules()
      +
      +    val results1 = ar
      +      .setMinConfidence(0.9)
      +      .run(freqItemsets)
      +      .collect()
      +
      +    /* Verify results using the `R` code:
      +       transactions = as(sapply(
      +         list("r z h k p",
      +              "z y x w v u t s",
      +              "s x o n r",
      +              "x z y m t s q e",
      +              "z",
      +              "x z y r q t p"),
      +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
      +         "transactions")
      +       ars = apriori(transactions,
      +                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
      +       arsDF = as(ars, "data.frame")
      +       arsDF$support = arsDF$support * length(transactions)
      +       names(arsDF)[names(arsDF) == "support"] = "freq"
      +       > nrow(arsDF)
      +       [1] 23
      +       > sum(arsDF$confidence == 1)
      +       [1] 23
      +     */
      +    assert(results1.size === 23)
      +    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
      +
      +    val results2 = ar
      +      .setMinConfidence(0)
      +      .run(freqItemsets)
      +      .collect()
      +
      +    /* Verify results using the `R` code:
      +       ars = apriori(transactions,
      +                  parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2))
      +       arsDF = as(ars, "data.frame")
      +       arsDF$support = arsDF$support * length(transactions)
      +       names(arsDF)[names(arsDF) == "support"] = "freq"
      +       nrow(arsDF)
      +       sum(arsDF$confidence == 1)
      +       > nrow(arsDF)
      +       [1] 30
      +       > sum(arsDF$confidence == 1)
      +       [1] 23
      +     */
      +    assert(results2.size === 30)
      +    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
      +  }
      +}
      +
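For reference, a minimal sketch of the AssociationRules API exercised by the new suite, assuming an existing SparkContext named sc and a handful of illustrative precomputed frequent itemsets:

import org.apache.spark.mllib.fpm.{AssociationRules, FPGrowth}

val freqItemsets = sc.parallelize(Seq(
  new FPGrowth.FreqItemset(Array("a"), 15L),
  new FPGrowth.FreqItemset(Array("b"), 35L),
  new FPGrowth.FreqItemset(Array("a", "b"), 12L)))

val rules = new AssociationRules()
  .setMinConfidence(0.8)
  .run(freqItemsets)
  .collect()

rules.foreach { rule =>
  // confidence(X => Y) = freq(X union Y) / freq(X); here {a} => {b} has 12 / 15 = 0.8
  val lhs = rule.antecedent.mkString(", ")
  val rhs = rule.consequent.mkString(", ")
  println(s"$lhs => $rhs @ confidence ${rule.confidence}")
}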
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
      index 66ae3543ecc4..4a9bfdb348d9 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
      @@ -39,6 +39,22 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setMinSupport(0.9)
             .setNumPartitions(1)
             .run(rdd)
      +
      +    /* Verify results using the `R` code:
      +       transactions = as(sapply(
      +         list("r z h k p",
      +              "z y x w v u t s",
      +              "s x o n r",
      +              "x z y m t s q e",
      +              "z",
      +              "x z y r q t p"),
      +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
      +         "transactions")
      +       > eclat(transactions, parameter = list(support = 0.9))
      +       ...
      +       eclat - zero frequent items
      +       set of 0 itemsets
      +     */
           assert(model6.freqItemsets.count() === 0)
       
           val model3 = fpg
      @@ -48,6 +64,33 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
           val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
             (itemset.items.toSet, itemset.freq)
           }
      +
      +    /* Verify results using the `R` code:
      +       fp = eclat(transactions, parameter = list(support = 0.5))
      +       fpDF = as(sort(fp), "data.frame")
      +       fpDF$support = fpDF$support * length(transactions)
      +       names(fpDF)[names(fpDF) == "support"] = "freq"
      +       > fpDF
      +              items freq
      +       13       {z}    5
      +       14       {x}    4
      +       1      {s,x}    3
      +       2  {t,x,y,z}    3
      +       3    {t,y,z}    3
      +       4    {t,x,y}    3
      +       5    {x,y,z}    3
      +       6      {y,z}    3
      +       7      {x,y}    3
      +       8      {t,y}    3
      +       9    {t,x,z}    3
      +       10     {t,z}    3
      +       11     {t,x}    3
      +       12     {x,z}    3
      +       15       {t}    3
      +       16       {y}    3
      +       17       {s}    3
      +       18       {r}    3
      +     */
           val expected = Set(
             (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
             (Set("r"), 3L),
      @@ -62,15 +105,75 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setMinSupport(0.3)
             .setNumPartitions(4)
             .run(rdd)
      +
      +    /* Verify results using the `R` code:
      +       fp = eclat(transactions, parameter = list(support = 0.3))
      +       fpDF = as(fp, "data.frame")
      +       fpDF$support = fpDF$support * length(transactions)
      +       names(fpDF)[names(fpDF) == "support"] = "freq"
      +       > nrow(fpDF)
      +       [1] 54
      +     */
           assert(model2.freqItemsets.count() === 54)
       
           val model1 = fpg
             .setMinSupport(0.1)
             .setNumPartitions(8)
             .run(rdd)
      +
      +    /* Verify results using the `R` code:
      +       fp = eclat(transactions, parameter = list(support = 0.1))
      +       fpDF = as(fp, "data.frame")
      +       fpDF$support = fpDF$support * length(transactions)
      +       names(fpDF)[names(fpDF) == "support"] = "freq"
      +       > nrow(fpDF)
      +       [1] 625
      +     */
           assert(model1.freqItemsets.count() === 625)
         }
       
      +  test("FP-Growth String type association rule generation") {
      +    val transactions = Seq(
      +      "r z h k p",
      +      "z y x w v u t s",
      +      "s x o n r",
      +      "x z y m t s q e",
      +      "z",
      +      "x z y r q t p")
      +      .map(_.split(" "))
      +    val rdd = sc.parallelize(transactions, 2).cache()
      +
      +    /* Verify results using the `R` code:
      +       transactions = as(sapply(
      +         list("r z h k p",
      +              "z y x w v u t s",
      +              "s x o n r",
      +              "x z y m t s q e",
      +              "z",
      +              "x z y r q t p"),
      +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
      +         "transactions")
      +       ars = apriori(transactions,
      +                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
      +       arsDF = as(ars, "data.frame")
      +       arsDF$support = arsDF$support * length(transactions)
      +       names(arsDF)[names(arsDF) == "support"] = "freq"
      +       > nrow(arsDF)
      +       [1] 23
      +       > sum(arsDF$confidence == 1)
      +       [1] 23
      +     */
      +    val rules = (new FPGrowth())
      +      .setMinSupport(0.5)
      +      .setNumPartitions(2)
      +      .run(rdd)
      +      .generateAssociationRules(0.9)
      +      .collect()
      +
      +    assert(rules.size === 23)
      +    assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
      +  }
      +
         test("FP-Growth using Int type") {
           val transactions = Seq(
             "1 2 3",
      @@ -89,6 +192,23 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setMinSupport(0.9)
             .setNumPartitions(1)
             .run(rdd)
      +
      +    /* Verify results using the `R` code:
      +       transactions = as(sapply(
      +         list("1 2 3",
      +              "1 2 3 4",
      +              "5 4 3 2 1",
      +              "6 5 4 3 2 1",
      +              "2 4",
      +              "1 3",
      +              "1 7"),
      +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
      +         "transactions")
      +       > eclat(transactions, parameter = list(support = 0.9))
      +       ...
      +       eclat - zero frequent items
      +       set of 0 itemsets
      +     */
           assert(model6.freqItemsets.count() === 0)
       
           val model3 = fpg
      @@ -100,6 +220,24 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
           val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
             (itemset.items.toSet, itemset.freq)
           }
      +
      +    /* Verify results using the `R` code:
      +       fp = eclat(transactions, parameter = list(support = 0.5))
      +       fpDF = as(sort(fp), "data.frame")
      +       fpDF$support = fpDF$support * length(transactions)
      +       names(fpDF)[names(fpDF) == "support"] = "freq"
      +       > fpDF
      +          items freq
      +      6     {1}    6
      +      3   {1,3}    5
      +      7     {2}    5
      +      8     {3}    5
      +      1   {2,4}    4
      +      2 {1,2,3}    4
      +      4   {2,3}    4
      +      5   {1,2}    4
      +      9     {4}    4
      +     */
           val expected = Set(
             (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
             (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
      @@ -110,12 +248,30 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setMinSupport(0.3)
             .setNumPartitions(4)
             .run(rdd)
      +
      +    /* Verify results using the `R` code:
      +       fp = eclat(transactions, parameter = list(support = 0.3))
      +       fpDF = as(fp, "data.frame")
      +       fpDF$support = fpDF$support * length(transactions)
      +       names(fpDF)[names(fpDF) == "support"] = "freq"
      +       > nrow(fpDF)
      +       [1] 15
      +     */
           assert(model2.freqItemsets.count() === 15)
       
           val model1 = fpg
             .setMinSupport(0.1)
             .setNumPartitions(8)
             .run(rdd)
      +
      +    /* Verify results using the `R` code:
      +       fp = eclat(transactions, parameter = list(support = 0.1))
      +       fpDF = as(fp, "data.frame")
      +       fpDF$support = fpDF$support * length(transactions)
      +       names(fpDF)[names(fpDF) == "support"] = "freq"
      +       > nrow(fpDF)
      +       [1] 65
      +     */
           assert(model1.freqItemsets.count() === 65)
         }
       }
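The new "association rule generation" test chains FP-growth with rule generation on the fitted model. A minimal sketch of that chain, assuming an existing SparkContext named sc; with 6 transactions, minSupport = 0.5 means an itemset must appear in at least 3 of them:

import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  "r z h k p", "z y x w v u t s", "s x o n r",
  "x z y m t s q e", "z", "x z y r q t p").map(_.split(" ")))

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run(transactions)

model.freqItemsets.collect().foreach { itemset =>
  val items = itemset.items.mkString("{", ", ", "}")
  println(s"$items: ${itemset.freq}")
}

// Derive rules whose confidence is at least 0.9 from the mined frequent itemsets.
val rules = model.generateAssociationRules(0.9).collect()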
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
      new file mode 100644
      index 000000000000..a83e543859b8
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
      @@ -0,0 +1,379 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +package org.apache.spark.mllib.fpm
      +
      +import org.apache.spark.SparkFunSuite
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +
      +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  test("PrefixSpan internal (integer seq, 0 delim) run, singleton itemsets") {
      +
      +    /*
      +      library("arulesSequences")
      +      prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
      +      freqItemSeq = cspade(
      +        prefixSpanSeqs,
      +        parameter = list(support =
      +          2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2 ))
      +      resSeq = as(freqItemSeq, "data.frame")
      +      resSeq
      +    */
      +
      +    val sequences = Array(
      +      Array(0, 1, 0, 3, 0, 4, 0, 5, 0),
      +      Array(0, 2, 0, 3, 0, 1, 0),
      +      Array(0, 2, 0, 4, 0, 1, 0),
      +      Array(0, 3, 0, 1, 0, 3, 0, 4, 0, 5, 0),
      +      Array(0, 3, 0, 4, 0, 4, 0, 3, 0),
      +      Array(0, 6, 0, 5, 0, 3, 0))
      +
      +    val rdd = sc.parallelize(sequences, 2).cache()
      +
      +    val result1 = PrefixSpan.genFreqPatterns(
      +      rdd, minCount = 2L, maxPatternLength = 50, maxLocalProjDBSize = 16L)
      +    val expectedValue1 = Array(
      +      (Array(0, 1, 0), 4L),
      +      (Array(0, 1, 0, 3, 0), 2L),
      +      (Array(0, 1, 0, 3, 0, 4, 0), 2L),
      +      (Array(0, 1, 0, 3, 0, 4, 0, 5, 0), 2L),
      +      (Array(0, 1, 0, 3, 0, 5, 0), 2L),
      +      (Array(0, 1, 0, 4, 0), 2L),
      +      (Array(0, 1, 0, 4, 0, 5, 0), 2L),
      +      (Array(0, 1, 0, 5, 0), 2L),
      +      (Array(0, 2, 0), 2L),
      +      (Array(0, 2, 0, 1, 0), 2L),
      +      (Array(0, 3, 0), 5L),
      +      (Array(0, 3, 0, 1, 0), 2L),
      +      (Array(0, 3, 0, 3, 0), 2L),
      +      (Array(0, 3, 0, 4, 0), 3L),
      +      (Array(0, 3, 0, 4, 0, 5, 0), 2L),
      +      (Array(0, 3, 0, 5, 0), 2L),
      +      (Array(0, 4, 0), 4L),
      +      (Array(0, 4, 0, 5, 0), 2L),
      +      (Array(0, 5, 0), 3L)
      +    )
      +    compareInternalResults(expectedValue1, result1.collect())
      +
      +    val result2 = PrefixSpan.genFreqPatterns(
      +      rdd, minCount = 3, maxPatternLength = 50, maxLocalProjDBSize = 32L)
      +    val expectedValue2 = Array(
      +      (Array(0, 1, 0), 4L),
      +      (Array(0, 3, 0), 5L),
      +      (Array(0, 3, 0, 4, 0), 3L),
      +      (Array(0, 4, 0), 4L),
      +      (Array(0, 5, 0), 3L)
      +    )
      +    compareInternalResults(expectedValue2, result2.collect())
      +
      +    val result3 = PrefixSpan.genFreqPatterns(
      +      rdd, minCount = 2, maxPatternLength = 2, maxLocalProjDBSize = 32L)
      +    val expectedValue3 = Array(
      +      (Array(0, 1, 0), 4L),
      +      (Array(0, 1, 0, 3, 0), 2L),
      +      (Array(0, 1, 0, 4, 0), 2L),
      +      (Array(0, 1, 0, 5, 0), 2L),
      +      (Array(0, 2, 0, 1, 0), 2L),
      +      (Array(0, 2, 0), 2L),
      +      (Array(0, 3, 0), 5L),
      +      (Array(0, 3, 0, 1, 0), 2L),
      +      (Array(0, 3, 0, 3, 0), 2L),
      +      (Array(0, 3, 0, 4, 0), 3L),
      +      (Array(0, 3, 0, 5, 0), 2L),
      +      (Array(0, 4, 0), 4L),
      +      (Array(0, 4, 0, 5, 0), 2L),
      +      (Array(0, 5, 0), 3L)
      +    )
      +    compareInternalResults(expectedValue3, result3.collect())
      +  }
      +
      +  test("PrefixSpan internal (integer seq, -1 delim) run, variable-size itemsets") {
      +    val sequences = Array(
      +      Array(0, 1, 0, 1, 2, 3, 0, 1, 3, 0, 4, 0, 3, 6, 0),
      +      Array(0, 1, 4, 0, 3, 0, 2, 3, 0, 1, 5, 0),
      +      Array(0, 5, 6, 0, 1, 2, 0, 4, 6, 0, 3, 0, 2, 0),
      +      Array(0, 5, 0, 7, 0, 1, 6, 0, 3, 0, 2, 0, 3, 0))
      +    val rdd = sc.parallelize(sequences, 2).cache()
      +    val result = PrefixSpan.genFreqPatterns(
      +      rdd, minCount = 2, maxPatternLength = 5, maxLocalProjDBSize = 128L)
      +
      +    /*
      +      To verify results, create file "prefixSpanSeqs" with content
+      (format = (transactionID, idxInTransaction, numItemsInItemset, itemset)):
      +        1 1 1 1
      +        1 2 3 1 2 3
      +        1 3 2 1 3
      +        1 4 1 4
      +        1 5 2 3 6
      +        2 1 2 1 4
      +        2 2 1 3
      +        2 3 2 2 3
      +        2 4 2 1 5
      +        3 1 2 5 6
      +        3 2 2 1 2
      +        3 3 2 4 6
      +        3 4 1 3
      +        3 5 1 2
      +        4 1 1 5
      +        4 2 1 7
      +        4 3 2 1 6
      +        4 4 1 3
      +        4 5 1 2
      +        4 6 1 3
      +      In R, run:
      +        library("arulesSequences")
      +        prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
      +        freqItemSeq = cspade(prefixSpanSeqs,
      +                             parameter = list(support = 0.5, maxlen = 5 ))
      +        resSeq = as(freqItemSeq, "data.frame")
      +        resSeq
      +
      +                    sequence support
      +        1              <{1}>    1.00
      +        2              <{2}>    1.00
      +        3              <{3}>    1.00
      +        4              <{4}>    0.75
      +        5              <{5}>    0.75
      +        6              <{6}>    0.75
      +        7          <{1},{6}>    0.50
      +        8          <{2},{6}>    0.50
      +        9          <{5},{6}>    0.50
      +        10       <{1,2},{6}>    0.50
      +        11         <{1},{4}>    0.50
      +        12         <{2},{4}>    0.50
      +        13       <{1,2},{4}>    0.50
      +        14         <{1},{3}>    1.00
      +        15         <{2},{3}>    0.75
      +        16           <{2,3}>    0.50
      +        17         <{3},{3}>    0.75
      +        18         <{4},{3}>    0.75
      +        19         <{5},{3}>    0.50
      +        20         <{6},{3}>    0.50
      +        21     <{5},{6},{3}>    0.50
      +        22     <{6},{2},{3}>    0.50
      +        23     <{5},{2},{3}>    0.50
      +        24     <{5},{1},{3}>    0.50
      +        25     <{2},{4},{3}>    0.50
      +        26     <{1},{4},{3}>    0.50
      +        27   <{1,2},{4},{3}>    0.50
      +        28     <{1},{3},{3}>    0.75
      +        29       <{1,2},{3}>    0.50
      +        30     <{1},{2},{3}>    0.50
      +        31       <{1},{2,3}>    0.50
      +        32         <{1},{2}>    1.00
      +        33           <{1,2}>    0.50
      +        34         <{3},{2}>    0.75
      +        35         <{4},{2}>    0.50
      +        36         <{5},{2}>    0.50
      +        37         <{6},{2}>    0.50
      +        38     <{5},{6},{2}>    0.50
      +        39     <{6},{3},{2}>    0.50
      +        40     <{5},{3},{2}>    0.50
      +        41     <{5},{1},{2}>    0.50
      +        42     <{4},{3},{2}>    0.50
      +        43     <{1},{3},{2}>    0.75
      +        44 <{5},{6},{3},{2}>    0.50
      +        45 <{5},{1},{3},{2}>    0.50
      +        46         <{1},{1}>    0.50
      +        47         <{2},{1}>    0.50
      +        48         <{3},{1}>    0.50
      +        49         <{5},{1}>    0.50
      +        50       <{2,3},{1}>    0.50
      +        51     <{1},{3},{1}>    0.50
      +        52   <{1},{2,3},{1}>    0.50
      +        53     <{1},{2},{1}>    0.50
      +     */
      +    val expectedValue = Array(
      +      (Array(0, 1, 0), 4L),
      +      (Array(0, 2, 0), 4L),
      +      (Array(0, 3, 0), 4L),
      +      (Array(0, 4, 0), 3L),
      +      (Array(0, 5, 0), 3L),
      +      (Array(0, 6, 0), 3L),
      +      (Array(0, 1, 0, 6, 0), 2L),
      +      (Array(0, 2, 0, 6, 0), 2L),
      +      (Array(0, 5, 0, 6, 0), 2L),
      +      (Array(0, 1, 2, 0, 6, 0), 2L),
      +      (Array(0, 1, 0, 4, 0), 2L),
      +      (Array(0, 2, 0, 4, 0), 2L),
      +      (Array(0, 1, 2, 0, 4, 0), 2L),
      +      (Array(0, 1, 0, 3, 0), 4L),
      +      (Array(0, 2, 0, 3, 0), 3L),
      +      (Array(0, 2, 3, 0), 2L),
      +      (Array(0, 3, 0, 3, 0), 3L),
      +      (Array(0, 4, 0, 3, 0), 3L),
      +      (Array(0, 5, 0, 3, 0), 2L),
      +      (Array(0, 6, 0, 3, 0), 2L),
      +      (Array(0, 5, 0, 6, 0, 3, 0), 2L),
      +      (Array(0, 6, 0, 2, 0, 3, 0), 2L),
      +      (Array(0, 5, 0, 2, 0, 3, 0), 2L),
      +      (Array(0, 5, 0, 1, 0, 3, 0), 2L),
      +      (Array(0, 2, 0, 4, 0, 3, 0), 2L),
      +      (Array(0, 1, 0, 4, 0, 3, 0), 2L),
      +      (Array(0, 1, 2, 0, 4, 0, 3, 0), 2L),
      +      (Array(0, 1, 0, 3, 0, 3, 0), 3L),
      +      (Array(0, 1, 2, 0, 3, 0), 2L),
      +      (Array(0, 1, 0, 2, 0, 3, 0), 2L),
      +      (Array(0, 1, 0, 2, 3, 0), 2L),
      +      (Array(0, 1, 0, 2, 0), 4L),
      +      (Array(0, 1, 2, 0), 2L),
      +      (Array(0, 3, 0, 2, 0), 3L),
      +      (Array(0, 4, 0, 2, 0), 2L),
      +      (Array(0, 5, 0, 2, 0), 2L),
      +      (Array(0, 6, 0, 2, 0), 2L),
      +      (Array(0, 5, 0, 6, 0, 2, 0), 2L),
      +      (Array(0, 6, 0, 3, 0, 2, 0), 2L),
      +      (Array(0, 5, 0, 3, 0, 2, 0), 2L),
      +      (Array(0, 5, 0, 1, 0, 2, 0), 2L),
      +      (Array(0, 4, 0, 3, 0, 2, 0), 2L),
      +      (Array(0, 1, 0, 3, 0, 2, 0), 3L),
      +      (Array(0, 5, 0, 6, 0, 3, 0, 2, 0), 2L),
      +      (Array(0, 5, 0, 1, 0, 3, 0, 2, 0), 2L),
      +      (Array(0, 1, 0, 1, 0), 2L),
      +      (Array(0, 2, 0, 1, 0), 2L),
      +      (Array(0, 3, 0, 1, 0), 2L),
      +      (Array(0, 5, 0, 1, 0), 2L),
      +      (Array(0, 2, 3, 0, 1, 0), 2L),
      +      (Array(0, 1, 0, 3, 0, 1, 0), 2L),
      +      (Array(0, 1, 0, 2, 3, 0, 1, 0), 2L),
      +      (Array(0, 1, 0, 2, 0, 1, 0), 2L))
      +
      +    compareInternalResults(expectedValue, result.collect())
      +  }
      +
      +  test("PrefixSpan projections with multiple partial starts") {
      +    val sequences = Seq(
      +      Array(Array(1, 2), Array(1, 2, 3)))
      +    val rdd = sc.parallelize(sequences, 2)
      +    val prefixSpan = new PrefixSpan()
      +      .setMinSupport(1.0)
      +      .setMaxPatternLength(2)
      +    val model = prefixSpan.run(rdd)
      +    val expected = Array(
      +      (Array(Array(1)), 1L),
      +      (Array(Array(1, 2)), 1L),
      +      (Array(Array(1), Array(1)), 1L),
      +      (Array(Array(1), Array(2)), 1L),
      +      (Array(Array(1), Array(3)), 1L),
      +      (Array(Array(1, 3)), 1L),
      +      (Array(Array(2)), 1L),
      +      (Array(Array(2, 3)), 1L),
      +      (Array(Array(2), Array(1)), 1L),
      +      (Array(Array(2), Array(2)), 1L),
      +      (Array(Array(2), Array(3)), 1L),
      +      (Array(Array(3)), 1L))
      +    compareResults(expected, model.freqSequences.collect())
      +  }
      +
      +  test("PrefixSpan Integer type, variable-size itemsets") {
      +    val sequences = Seq(
      +      Array(Array(1, 2), Array(3)),
      +      Array(Array(1), Array(3, 2), Array(1, 2)),
      +      Array(Array(1, 2), Array(5)),
      +      Array(Array(6)))
      +    val rdd = sc.parallelize(sequences, 2).cache()
      +
      +    val prefixSpan = new PrefixSpan()
      +      .setMinSupport(0.5)
      +      .setMaxPatternLength(5)
      +
      +    /*
      +      To verify results, create file "prefixSpanSeqs2" with content
+      (format = (transactionID, idxInTransaction, numItemsInItemset, itemset)):
      +        1 1 2 1 2
      +        1 2 1 3
      +        2 1 1 1
      +        2 2 2 3 2
      +        2 3 2 1 2
      +        3 1 2 1 2
      +        3 2 1 5
      +        4 1 1 6
      +      In R, run:
      +        library("arulesSequences")
+        prefixSpanSeqs = read_baskets("prefixSpanSeqs2", info = c("sequenceID","eventID","SIZE"))
+        freqItemSeq = cspade(prefixSpanSeqs,
+                             parameter = list(support = 0.5, maxlen = 5 ))
      +        resSeq = as(freqItemSeq, "data.frame")
      +        resSeq
      +
      +           sequence support
      +        1     <{1}>    0.75
      +        2     <{2}>    0.75
      +        3     <{3}>    0.50
      +        4 <{1},{3}>    0.50
      +        5   <{1,2}>    0.75
      +     */
      +
      +    val model = prefixSpan.run(rdd)
      +    val expected = Array(
      +      (Array(Array(1)), 3L),
      +      (Array(Array(2)), 3L),
      +      (Array(Array(3)), 2L),
      +      (Array(Array(1), Array(3)), 2L),
      +      (Array(Array(1, 2)), 3L)
      +    )
      +    compareResults(expected, model.freqSequences.collect())
      +  }
      +
      +  test("PrefixSpan String type, variable-size itemsets") {
+    // This is the same test as "PrefixSpan Integer type, variable-size itemsets" except
+    // that the items are mapped to Strings
      +    val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap
      +    val sequences = Seq(
      +      Array(Array(1, 2), Array(3)),
      +      Array(Array(1), Array(3, 2), Array(1, 2)),
      +      Array(Array(1, 2), Array(5)),
      +      Array(Array(6))).map(seq => seq.map(itemSet => itemSet.map(intToString)))
      +    val rdd = sc.parallelize(sequences, 2).cache()
      +
      +    val prefixSpan = new PrefixSpan()
      +      .setMinSupport(0.5)
      +      .setMaxPatternLength(5)
      +
      +    val model = prefixSpan.run(rdd)
      +    val expected = Array(
      +      (Array(Array(1)), 3L),
      +      (Array(Array(2)), 3L),
      +      (Array(Array(3)), 2L),
      +      (Array(Array(1), Array(3)), 2L),
      +      (Array(Array(1, 2)), 3L)
      +    ).map { case (pattern, count) =>
      +      (pattern.map(itemSet => itemSet.map(intToString)), count)
      +    }
      +    compareResults(expected, model.freqSequences.collect())
      +  }
      +
      +  private def compareResults[Item](
      +      expectedValue: Array[(Array[Array[Item]], Long)],
      +      actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {
      +    val expectedSet = expectedValue.map { case (pattern: Array[Array[Item]], count: Long) =>
      +      (pattern.map(itemSet => itemSet.toSet).toSeq, count)
      +    }.toSet
      +    val actualSet = actualValue.map { x =>
      +      (x.sequence.map(_.toSet).toSeq, x.freq)
      +    }.toSet
      +    assert(expectedSet === actualSet)
      +  }
      +
      +  private def compareInternalResults(
      +      expectedValue: Array[(Array[Int], Long)],
      +      actualValue: Array[(Array[Int], Long)]): Unit = {
      +    val expectedSet = expectedValue.map(x => (x._1.toSeq, x._2)).toSet
      +    val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet
      +    assert(expectedSet === actualSet)
      +  }
      +}
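The new suite covers both the internal representation consumed by `PrefixSpan.genFreqPatterns` (integer sequences with 0 or -1 itemset delimiters) and the public `PrefixSpan` API used in the last three tests. For orientation, a minimal sketch of that public API, assuming only a live `SparkContext` named `sc`:

    import org.apache.spark.mllib.fpm.PrefixSpan

    // Input is an RDD of sequences; each sequence is an Array of itemsets (Array[Item]).
    val sequences = sc.parallelize(Seq(
      Array(Array(1, 2), Array(3)),
      Array(Array(1), Array(3, 2), Array(1, 2))), 2).cache()

    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.5)       // fraction of sequences that must contain a pattern
      .setMaxPatternLength(5)   // upper bound on pattern length

    val model = prefixSpan.run(sequences)
    // Each FreqSequence carries the pattern and its absolute frequency.
    model.freqSequences.collect().foreach { fs =>
      println(fs.sequence.map(_.mkString("{", ",", "}")).mkString("<", ",", ">") + ": " + fs.freq)
    }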
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
      index d34888af2d73..e331c7598918 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
      @@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
       
         import PeriodicGraphCheckpointerSuite._
       
      -  // TODO: Do I need to call count() on the graphs' RDDs?
      -
         test("Persisting") {
           var graphsToCheck = Seq.empty[GraphToCheck]
       
           val graph1 = createGraph(sc)
      -    val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
      +    val checkpointer =
      +      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
      +    checkpointer.update(graph1)
           graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
           checkPersistence(graphsToCheck, 1)
       
           var iteration = 2
           while (iteration < 9) {
             val graph = createGraph(sc)
      -      checkpointer.updateGraph(graph)
      +      checkpointer.update(graph)
             graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
             checkPersistence(graphsToCheck, iteration)
             iteration += 1
      @@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
           var graphsToCheck = Seq.empty[GraphToCheck]
           sc.setCheckpointDir(path)
           val graph1 = createGraph(sc)
      -    val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
      +    val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
      +      checkpointInterval, graph1.vertices.sparkContext)
      +    checkpointer.update(graph1)
           graph1.edges.count()
           graph1.vertices.count()
           graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
      @@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
           var iteration = 2
           while (iteration < 9) {
             val graph = createGraph(sc)
      -      checkpointer.updateGraph(graph)
      +      checkpointer.update(graph)
             graph.vertices.count()
             graph.edges.count()
             graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
      @@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite {
             } else {
               // Graph should never be checkpointed
               assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
      -        assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
      +        assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
             }
           } catch {
             case e: AssertionError =>
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
      new file mode 100644
      index 000000000000..b2a459a68b5f
      --- /dev/null
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
      @@ -0,0 +1,173 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.mllib.impl
      +
      +import org.apache.hadoop.fs.{FileSystem, Path}
      +
      +import org.apache.spark.{SparkContext, SparkFunSuite}
      +import org.apache.spark.mllib.util.MLlibTestSparkContext
      +import org.apache.spark.rdd.RDD
      +import org.apache.spark.storage.StorageLevel
      +import org.apache.spark.util.Utils
      +
      +
      +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
      +
      +  import PeriodicRDDCheckpointerSuite._
      +
      +  test("Persisting") {
      +    var rddsToCheck = Seq.empty[RDDToCheck]
      +
      +    val rdd1 = createRDD(sc)
      +    val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
      +    checkpointer.update(rdd1)
      +    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
      +    checkPersistence(rddsToCheck, 1)
      +
      +    var iteration = 2
      +    while (iteration < 9) {
      +      val rdd = createRDD(sc)
      +      checkpointer.update(rdd)
      +      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
      +      checkPersistence(rddsToCheck, iteration)
      +      iteration += 1
      +    }
      +  }
      +
      +  test("Checkpointing") {
      +    val tempDir = Utils.createTempDir()
      +    val path = tempDir.toURI.toString
      +    val checkpointInterval = 2
      +    var rddsToCheck = Seq.empty[RDDToCheck]
      +    sc.setCheckpointDir(path)
      +    val rdd1 = createRDD(sc)
      +    val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
      +    checkpointer.update(rdd1)
      +    rdd1.count()
      +    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
      +    checkCheckpoint(rddsToCheck, 1, checkpointInterval)
      +
      +    var iteration = 2
      +    while (iteration < 9) {
      +      val rdd = createRDD(sc)
      +      checkpointer.update(rdd)
      +      rdd.count()
      +      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
      +      checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
      +      iteration += 1
      +    }
      +
      +    checkpointer.deleteAllCheckpoints()
      +    rddsToCheck.foreach { rdd =>
      +      confirmCheckpointRemoved(rdd.rdd)
      +    }
      +
      +    Utils.deleteRecursively(tempDir)
      +  }
      +}
      +
      +private object PeriodicRDDCheckpointerSuite {
      +
      +  case class RDDToCheck(rdd: RDD[Double], gIndex: Int)
      +
      +  def createRDD(sc: SparkContext): RDD[Double] = {
      +    sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
      +  }
      +
      +  def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
      +    rdds.foreach { g =>
      +      checkPersistence(g.rdd, g.gIndex, iteration)
      +    }
      +  }
      +
      +  /**
      +   * Check storage level of rdd.
      +   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
      +   * @param iteration  Total number of rdds inserted into checkpointer.
      +   */
      +  def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
      +    try {
      +      if (gIndex + 2 < iteration) {
      +        assert(rdd.getStorageLevel == StorageLevel.NONE)
      +      } else {
      +        assert(rdd.getStorageLevel != StorageLevel.NONE)
      +      }
      +    } catch {
      +      case _: AssertionError =>
      +        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
      +          s"\t gIndex = $gIndex\n" +
      +          s"\t iteration = $iteration\n" +
      +          s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
      +    }
      +  }
      +
      +  def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
      +    rdds.reverse.foreach { g =>
      +      checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
      +    }
      +  }
      +
      +  def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
      +    // Note: We cannot check rdd.isCheckpointed since that value is never updated.
      +    //       Instead, we check for the presence of the checkpoint files.
      +    //       This test should continue to work even after this rdd.isCheckpointed issue
      +    //       is fixed (though it can then be simplified and not look for the files).
      +    val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration)
      +    rdd.getCheckpointFile.foreach { checkpointFile =>
      +      assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed")
      +    }
      +  }
      +
      +  /**
      +   * Check checkpointed status of rdd.
      +   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
      +   * @param iteration  Total number of rdds inserted into checkpointer.
      +   */
      +  def checkCheckpoint(
      +      rdd: RDD[_],
      +      gIndex: Int,
      +      iteration: Int,
      +      checkpointInterval: Int): Unit = {
      +    try {
      +      if (gIndex % checkpointInterval == 0) {
      +        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
      +        // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
      +        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
      +          assert(rdd.isCheckpointed, "RDD should be checkpointed")
+          assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
      +        } else {
      +          confirmCheckpointRemoved(rdd)
      +        }
      +      } else {
      +        // RDD should never be checkpointed
      +        assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
      +        assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
      +      }
      +    } catch {
      +      case e: AssertionError =>
      +        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
      +          s"\t gIndex = $gIndex\n" +
      +          s"\t iteration = $iteration\n" +
      +          s"\t checkpointInterval = $checkpointInterval\n" +
      +          s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
      +          s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
      +          s"  AssertionError message: ${e.getMessage}")
      +    }
      +  }
      +
      +}
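Both checkpointer suites reflect the API change from a constructor taking the initial graph plus `updateGraph()` to a `(checkpointInterval, SparkContext)` constructor with a uniform `update()` call. The driving pattern the new RDD suite exercises looks roughly as follows; note this is an internal MLlib helper rather than public API, and the sketch simply mirrors how the suite uses it:

    import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
    import org.apache.spark.rdd.RDD

    sc.setCheckpointDir("/tmp/checkpoints")   // hypothetical path
    val checkpointer = new PeriodicRDDCheckpointer[Double](2, sc)  // checkpoint every 2nd update

    var state: RDD[Double] = sc.parallelize(Seq(0.0, 1.0, 2.0))
    for (iter <- 1 to 10) {
      state = state.map(_ + 1.0)    // stand-in for one iteration of an algorithm
      checkpointer.update(state)    // handles the persist/checkpoint bookkeeping
      state.count()                 // an action is needed for the checkpoint to materialize
    }
    checkpointer.deleteAllCheckpoints()   // remove leftover checkpoint files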
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
      index b0f3f71113c5..96e5ffef7a13 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
      @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite {
           }
         }
       
      +  test("spr") {
      +    // test dense vector
      +    val alpha = 0.1
      +    val x = new DenseVector(Array(1.0, 2, 2.1, 4))
      +    val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
      +    val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6))
      +
      +    spr(alpha, x, U)
      +    assert(U ~== expected absTol 1e-9)
      +
      +    val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5))
      +    withClue("Size of vector must match the rank of matrix") {
      +      intercept[Exception] {
      +        spr(alpha, x, matrix33)
      +      }
      +    }
      +
      +    // test sparse vector
      +    val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2))
      +    val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
      +    spr(0.1, sv, U2)
      +    val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4))
      +    assert(U2 ~== expectedSparse absTol 1e-15)
      +  }
      +
         test("syr") {
           val dA = new DenseMatrix(4, 4,
             Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
      @@ -200,8 +225,15 @@ class BLASSuite extends SparkFunSuite {
           val C10 = C1.copy
           val C11 = C1.copy
           val C12 = C1.copy
      +    val C13 = C1.copy
      +    val C14 = C1.copy
      +    val C15 = C1.copy
      +    val C16 = C1.copy
      +    val C17 = C1.copy
           val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
           val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
      +    val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
      +    val expected5 = C1.copy
       
           gemm(1.0, dA, B, 2.0, C1)
           gemm(1.0, sA, B, 2.0, C2)
      @@ -211,6 +243,10 @@ class BLASSuite extends SparkFunSuite {
           assert(C2 ~== expected2 absTol 1e-15)
           assert(C3 ~== expected3 absTol 1e-15)
           assert(C4 ~== expected3 absTol 1e-15)
      +    gemm(1.0, dA, B, 0.0, C17)
      +    assert(C17 ~== expected absTol 1e-15)
      +    gemm(1.0, sA, B, 0.0, C17)
      +    assert(C17 ~== expected absTol 1e-15)
       
           withClue("columns of A don't match the rows of B") {
             intercept[Exception] {
      @@ -248,6 +284,16 @@ class BLASSuite extends SparkFunSuite {
           assert(C10 ~== expected2 absTol 1e-15)
           assert(C11 ~== expected3 absTol 1e-15)
           assert(C12 ~== expected3 absTol 1e-15)
      +
      +    gemm(0, dA, B, 5, C13)
      +    gemm(0, sA, B, 5, C14)
      +    gemm(0, dA, B, 1, C15)
      +    gemm(0, sA, B, 1, C16)
      +    assert(C13 ~== expected4 absTol 1e-15)
      +    assert(C14 ~== expected4 absTol 1e-15)
      +    assert(C15 ~== expected5 absTol 1e-15)
      +    assert(C16 ~== expected5 absTol 1e-15)
      +
         }
       
         test("gemv") {
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
      index 8dbb70f5d1c4..bfd6d5495f5e 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
      @@ -74,6 +74,24 @@ class MatricesSuite extends SparkFunSuite {
           }
         }
       
      +  test("equals") {
      +    val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0))
      +    assert(dm1 === dm1)
      +    assert(dm1 !== dm1.transpose)
      +
      +    val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0))
      +    assert(dm1 === dm2.transpose)
      +
      +    val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse
      +    assert(sm1 === sm1)
      +    assert(sm1 === dm1)
      +    assert(sm1 !== sm1.transpose)
      +
      +    val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse
      +    assert(sm1 === sm2.transpose)
      +    assert(sm1 === dm2.transpose)
      +  }
      +
         test("matrix copies are deep copies") {
           val m = 3
           val n = 2
      @@ -455,4 +473,14 @@ class MatricesSuite extends SparkFunSuite {
           lines = mat.toString(5, 100).lines.toArray
           assert(lines.size == 5 && lines.forall(_.size <= 100))
         }
      +
      +  test("numNonzeros and numActives") {
      +    val dm1 = Matrices.dense(3, 2, Array(0, 0, -1, 1, 0, 1))
      +    assert(dm1.numNonzeros === 3)
      +    assert(dm1.numActives === 6)
      +
      +    val sm1 = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
      +    assert(sm1.numNonzeros === 1)
      +    assert(sm1.numActives === 3)
      +  }
       }
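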
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
      index c4ae0a16f7c0..6508ddeba420 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
      @@ -21,10 +21,10 @@ import scala.util.Random
       
       import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
       
      -import org.apache.spark.{SparkException, SparkFunSuite}
      +import org.apache.spark.{Logging, SparkException, SparkFunSuite}
       import org.apache.spark.mllib.util.TestingUtils._
       
      -class VectorsSuite extends SparkFunSuite {
      +class VectorsSuite extends SparkFunSuite with Logging {
       
         val arr = Array(0.1, 0.0, 0.3, 0.4)
         val n = 4
      @@ -57,16 +57,70 @@ class VectorsSuite extends SparkFunSuite {
           assert(vec.values === values)
         }
       
      +  test("sparse vector construction with mismatched indices/values array") {
      +    intercept[IllegalArgumentException] {
      +      Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0))
      +    }
      +    intercept[IllegalArgumentException] {
      +      Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0))
      +    }
      +  }
      +
      +  test("sparse vector construction with too many indices vs size") {
      +    intercept[IllegalArgumentException] {
      +      Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0))
      +    }
      +  }
      +
         test("dense to array") {
           val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
           assert(vec.toArray.eq(arr))
         }
       
      +  test("dense argmax") {
      +    val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
      +    assert(vec.argmax === -1)
      +
      +    val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
      +    assert(vec2.argmax === 3)
      +
      +    val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
      +    assert(vec3.argmax === 3)
      +  }
      +
         test("sparse to array") {
           val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
           assert(vec.toArray === arr)
         }
       
      +  test("sparse argmax") {
      +    val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
      +    assert(vec.argmax === -1)
      +
      +    val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
      +    assert(vec2.argmax === 3)
      +
      +    val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7))
      +    assert(vec3.argmax === 2)
      +
+    // Check the case where the sparse vector is created with only negative explicit values,
+    // i.e. the logical vector {0.0, 0.0, -1.0, -0.7, 0.0}
      +    val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7))
      +    assert(vec4.argmax === 0)
      +
      +    val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0))
      +    assert(vec5.argmax === 1)
      +
      +    val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0))
      +    assert(vec6.argmax === 2)
      +
      +    val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7))
      +    assert(vec7.argmax === 1)
      +
      +    val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
      +    assert(vec8.argmax === 0)
      +  }
      +
         test("vector equals") {
           val dv1 = Vectors.dense(arr.clone())
           val dv2 = Vectors.dense(arr.clone())
      @@ -142,7 +196,7 @@ class VectorsSuite extends SparkFunSuite {
           malformatted.foreach { s =>
             intercept[SparkException] {
               Vectors.parse(s)
      -        println(s"Didn't detect malformatted string $s.")
      +        logInfo(s"Didn't detect malformatted string $s.")
             }
           }
         }
      @@ -313,4 +367,11 @@ class VectorsSuite extends SparkFunSuite {
           val sv1c = sv1.compressed.asInstanceOf[DenseVector]
           assert(sv1 === sv1c)
         }
      +
      +  test("SparseVector.slice") {
      +    val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
      +    assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2)))
      +    assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
      +    assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
      +  }
       }
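The sparse `argmax` cases pin down the intended semantics: argmax is taken over the logical (dense) values, implicit zeros included, ties go to the smallest index, and an empty vector yields -1. A naive reference that matches the sparse expectations above, a sketch only and deliberately simpler than the library code since it densifies the vector:

    import org.apache.spark.mllib.linalg.Vectors

    def naiveArgmax(size: Int, indices: Array[Int], values: Array[Double]): Int = {
      if (size == 0) return -1
      val dense = Vectors.sparse(size, indices, values).toArray  // materialize implicit zeros
      var best = 0
      var i = 1
      while (i < size) {
        if (dense(i) > dense(best)) best = i   // strict '>' keeps the earliest maximum
        i += 1
      }
      best
    }

    // Only negative explicit values, logical vector {0.0, 0.0, -1.0, -0.7, 0.0}:
    naiveArgmax(5, Array(2, 3), Array(-1.0, -0.7))   // == 0, as asserted for vec4 above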
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
      index 4a7b99a976f0..0ecb7a221a50 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
      @@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(closeToZero(U * brzDiag(s) * V.t - localA))
         }
       
      +  test("validate matrix sizes of svd") {
      +    val k = 2
      +    val A = new IndexedRowMatrix(indexedRows)
      +    val svd = A.computeSVD(k, computeU = true)
      +    assert(svd.U.numRows() === m)
      +    assert(svd.U.numCols() === k)
      +    assert(svd.s.size === k)
      +    assert(svd.V.numRows === n)
      +    assert(svd.V.numCols === k)
      +  }
      +
         test("validate k in svd") {
           val A = new IndexedRowMatrix(indexedRows)
           intercept[IllegalArgumentException] {
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
      index b6cb53d0c743..283ffec1d49d 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
      @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
       
       import scala.util.Random
       
      +import breeze.numerics.abs
       import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
       
       import org.apache.spark.SparkFunSuite
      @@ -238,6 +239,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
             }
           }
         }
      +
      +  test("QR Decomposition") {
      +    for (mat <- Seq(denseMat, sparseMat)) {
      +      val result = mat.tallSkinnyQR(true)
      +      val expected = breeze.linalg.qr.reduced(mat.toBreeze())
      +      val calcQ = result.Q
      +      val calcR = result.R
      +      assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze())))
      +      assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]])))
      +      assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze()))
      +      // Decomposition without computing Q
      +      val rOnly = mat.tallSkinnyQR(computeQ = false)
      +      assert(rOnly.Q == null)
      +      assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]])))
      +    }
      +  }
       }
       
       class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
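The QR test compares `abs(Q)` and `abs(R)` rather than the factors themselves because a reduced QR factorization is unique only up to the signs of Q's columns (and the corresponding rows of R); only the product Q * R has to reproduce the original matrix exactly. A brief Breeze illustration of that point, as an aside:

    import breeze.linalg.{DenseMatrix => BDM, max}
    import breeze.numerics.abs

    val a = new BDM(3, 2, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))  // column-major 3x2
    val decomp = breeze.linalg.qr.reduced(a)
    val (q, r) = (decomp.q, decomp.r)
    // Negating a column of q together with the matching row of r gives another valid
    // factorization, so |q|, |r| and the product q * r are the comparable quantities.
    assert(max(abs(q * r - a)) < 1e-12)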
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
      index a5a59e9fad5a..36ac7d267243 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.optimization
       
      -import scala.collection.JavaConversions._
      +import scala.collection.JavaConverters._
       import scala.util.Random
       
       import org.scalatest.Matchers
      @@ -25,7 +25,7 @@ import org.scalatest.Matchers
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.regression._
      -import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
      +import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext}
       import org.apache.spark.mllib.util.TestingUtils._
       
       object GradientDescentSuite {
      @@ -35,7 +35,7 @@ object GradientDescentSuite {
             scale: Double,
             nPoints: Int,
             seed: Int): java.util.List[LabeledPoint] = {
      -    seqAsJavaList(generateGDInput(offset, scale, nPoints, seed))
      +    generateGDInput(offset, scale, nPoints, seed).asJava
         }
       
         // Generate input of the form Y = logistic(offset + scale * X)
      @@ -82,11 +82,11 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with
     // Add an extra variable consisting of all 1.0's for the intercept.
           val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
           val data = testData.map { case LabeledPoint(label, features) =>
      -      label -> Vectors.dense(1.0 +: features.toArray)
      +      label -> MLUtils.appendBias(features)
           }
       
           val dataRDD = sc.parallelize(data, 2).cache()
      -    val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray)
      +    val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
       
           val (_, loss) = GradientDescent.runMiniBatchSGD(
             dataRDD,
      @@ -139,6 +139,45 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with
             "The different between newWeights with/without regularization " +
               "should be initialWeightsWithIntercept.")
         }
      +
      +  test("iteration should end with convergence tolerance") {
      +    val nPoints = 10000
      +    val A = 2.0
      +    val B = -1.5
      +
      +    val initialB = -1.0
      +    val initialWeights = Array(initialB)
      +
      +    val gradient = new LogisticGradient()
      +    val updater = new SimpleUpdater()
      +    val stepSize = 1.0
      +    val numIterations = 10
      +    val regParam = 0
      +    val miniBatchFrac = 1.0
      +    val convergenceTolerance = 5.0e-1
      +
+    // Add an extra variable consisting of all 1.0's for the intercept.
      +    val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
      +    val data = testData.map { case LabeledPoint(label, features) =>
      +      label -> MLUtils.appendBias(features)
      +    }
      +
      +    val dataRDD = sc.parallelize(data, 2).cache()
      +    val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
      +
      +    val (_, loss) = GradientDescent.runMiniBatchSGD(
      +      dataRDD,
      +      gradient,
      +      updater,
      +      stepSize,
      +      numIterations,
      +      regParam,
      +      miniBatchFrac,
      +      initialWeightsWithIntercept,
      +      convergenceTolerance)
      +
      +    assert(loss.length < numIterations, "convergenceTolerance failed to stop optimization early")
      +  }
       }
       
       class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
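The new convergence test relies on the `convergenceTolerance` argument now threaded through `runMiniBatchSGD`: once successive solutions stop changing by more than the tolerance, the optimizer returns early and `loss.length` ends up smaller than `numIterations`. A minimal sketch of that kind of stopping rule (illustrative; the exact way GradientDescent measures the change may differ in detail):

    import org.apache.spark.mllib.linalg.Vector

    // One plausible criterion: relative change of the weight vector between iterations.
    def converged(previous: Vector, current: Vector, tol: Double): Boolean = {
      val prev = previous.toArray
      val curr = current.toArray
      val diffNorm = math.sqrt(prev.zip(curr).map { case (p, c) => (p - c) * (p - c) }.sum)
      val currNorm = math.sqrt(curr.map(c => c * c).sum)
      diffNorm < tol * math.max(currNorm, 1.0)
    }
    // An optimizer loop would then break out as soon as
    // converged(previousWeights, currentWeights, convergenceTol) becomes true.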
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
      index d07b9d5b8922..75ae0eb32fb7 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
      @@ -122,7 +122,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
             numGDIterations,
             regParam,
             miniBatchFrac,
      -      initialWeightsWithIntercept)
      +      initialWeightsWithIntercept,
      +      convergenceTol)
       
           assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5,
             "The first losses of LBFGS and GD should be the same.")
      @@ -221,7 +222,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
             numGDIterations,
             regParam,
             miniBatchFrac,
      -      initialWeightsWithIntercept)
      +      initialWeightsWithIntercept,
      +      convergenceTol)
       
           // for class LBFGS and the optimize method, we only look at the weights
           assert(
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
      index a5ca1518f82f..8416771552fd 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.random
       
      -import scala.math
      +import org.apache.commons.math3.special.Gamma
       
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.util.StatCounter
      @@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
               distributionChecks(gamma, expectedMean, expectedStd, 0.1)
           }
         }
      +
      +  test("WeibullGenerator") {
      +    List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map {
      +      case (alpha: Double, beta: Double) =>
      +        val weibull = new WeibullGenerator(alpha, beta)
      +        apiChecks(weibull)
      +
      +        val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta
      +        val expectedVariance = math.exp(
      +          Gamma.logGamma(1 + (2 / alpha))) * beta * beta - expectedMean * expectedMean
      +        val expectedStd = math.sqrt(expectedVariance)
      +        distributionChecks(weibull, expectedMean, expectedStd, 0.1)
      +    }
      +  }
       }
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
      index 05b87728d6fd..045135f7f8d6 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.recommendation
       
      -import scala.collection.JavaConversions._
      +import scala.collection.JavaConverters._
       import scala.math.abs
       import scala.util.Random
       
      @@ -38,7 +38,7 @@ object ALSSuite {
             negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
           val (sampledRatings, trueRatings, truePrefs) =
             generateRatings(users, products, features, samplingRate, implicitPrefs)
      -    (seqAsJavaList(sampledRatings), trueRatings, truePrefs)
      +    (sampledRatings.asJava, trueRatings, truePrefs)
         }
       
         def generateRatings(
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
      index d8364a06de4d..f8d0af8820e6 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
      @@ -31,6 +31,11 @@ class LabeledPointSuite extends SparkFunSuite {
           }
         }
       
      +  test("parse labeled points with whitespaces") {
      +    val point = LabeledPoint.parse("(0.0, [1.0, 2.0])")
      +    assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0)))
      +  }
      +
         test("parse labeled points with v0.9 format") {
           val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0")
           assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0)))
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
      index 08a152ffc7a2..39537e7bb4c7 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
      @@ -100,7 +100,7 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
           val testRDD = sc.parallelize(testData, 2).cache()
       
           val ls = new LassoWithSGD()
      -    ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
      +    ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005)
       
           val model = ls.run(testRDD, initialWeights)
           val weight0 = model.weights(0)
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
      index f5e2d31056cb..34c07ed17081 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
      @@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer
       import org.apache.spark.SparkFunSuite
       import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.util.LinearDataGenerator
      +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
       import org.apache.spark.streaming.dstream.DStream
      -import org.apache.spark.streaming.TestSuiteBase
       
       class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
       
         // use longer wait time to ensure job completion
         override def maxWaitTimeMillis: Int = 20000
       
      +  var ssc: StreamingContext = _
      +
      +  override def afterFunction() {
      +    super.afterFunction()
      +    if (ssc != null) {
      +      ssc.stop()
      +    }
      +  }
      +
         // Assert that two values are equal within tolerance epsilon
         def assertEqual(v1: Double, v2: Double, epsilon: Double) {
           def errorMessage = v1.toString + " did not equal " + v2.toString
      @@ -53,6 +62,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
             .setInitialWeights(Vectors.dense(0.0, 0.0))
             .setStepSize(0.2)
             .setNumIterations(25)
      +      .setConvergenceTol(0.0001)
       
           // generate sequence of simulated data
           val numBatches = 10
      @@ -61,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
           }
       
           // apply model training to input stream
      -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             inputDStream.count()
           })
      @@ -97,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
       
           // apply model training to input stream, storing the intermediate results
           // (we add a count to ensure the result is a DStream)
      -    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
             inputDStream.count()
      @@ -128,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
           }
       
           // apply model predictions to test stream
      -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
             model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
           })
           // collect the output as (true, estimated) tuples
      @@ -155,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
           }
       
           // train and predict
      -    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
      +    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
             model.trainOn(inputDStream)
             model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
           })
      @@ -176,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
           val numBatches = 10
           val nPoints = 100
           val emptyInput = Seq.empty[Seq[LabeledPoint]]
      -    val ssc = setupStreams(emptyInput,
      +    ssc = setupStreams(emptyInput,
             (inputDStream: DStream[LabeledPoint]) => {
               model.trainOn(inputDStream)
               model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
      index c292ced75e87..c3eeda012571 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
      @@ -19,13 +19,13 @@ package org.apache.spark.mllib.stat
       
       import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
       
      -import org.apache.spark.SparkFunSuite
      +import org.apache.spark.{Logging, SparkFunSuite}
       import org.apache.spark.mllib.linalg.Vectors
       import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
         SpearmanCorrelation}
       import org.apache.spark.mllib.util.MLlibTestSparkContext
       
      -class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
      +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
       
         // test input data
         val xData = Array(1.0, 0.0, -2.0)
      @@ -146,7 +146,7 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
         def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
           for (i <- 0 until A.rows; j <- 0 until A.cols) {
             if (!approxEqual(A(i, j), B(i, j), threshold)) {
      -        println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
      +        logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
               return false
             }
           }
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
      index b084a5fb4313..142b90e764a7 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
      @@ -19,6 +19,10 @@ package org.apache.spark.mllib.stat
       
       import java.util.Random
       
      +import org.apache.commons.math3.distribution.{ExponentialDistribution,
      +  NormalDistribution, UniformRealDistribution}
      +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
      +
       import org.apache.spark.{SparkException, SparkFunSuite}
       import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
       import org.apache.spark.mllib.regression.LabeledPoint
      @@ -153,4 +157,101 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
             Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
           }
         }
      +
      +  test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") {
      +    // Create theoretical distributions
      +    val stdNormalDist = new NormalDistribution(0, 1)
      +    val expDist = new ExponentialDistribution(0.6)
      +    val unifDist = new UniformRealDistribution()
      +
      +    // set seeds
      +    val seed = 10L
      +    stdNormalDist.reseedRandomGenerator(seed)
      +    expDist.reseedRandomGenerator(seed)
      +    unifDist.reseedRandomGenerator(seed)
      +
      +    // Sample data from the distributions and parallelize it
      +    val n = 100000
      +    val sampledNorm = sc.parallelize(stdNormalDist.sample(n), 10)
      +    val sampledExp = sc.parallelize(expDist.sample(n), 10)
      +    val sampledUnif = sc.parallelize(unifDist.sample(n), 10)
      +
+    // Use an Apache Commons Math local KS test to verify calculations
      +    val ksTest = new KolmogorovSmirnovTest()
      +    val pThreshold = 0.05
      +
      +    // Comparing a standard normal sample to a standard normal distribution
      +    val result1 = Statistics.kolmogorovSmirnovTest(sampledNorm, "norm", 0, 1)
      +    val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledNorm.collect())
      +    val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n)
+    // Verify vs Apache Commons Math KS test
      +    assert(result1.statistic ~== referenceStat1 relTol 1e-4)
      +    assert(result1.pValue ~== referencePVal1 relTol 1e-4)
      +    // Cannot reject null hypothesis
      +    assert(result1.pValue > pThreshold)
      +
      +    // Comparing an exponential sample to a standard normal distribution
      +    val result2 = Statistics.kolmogorovSmirnovTest(sampledExp, "norm", 0, 1)
      +    val referenceStat2 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledExp.collect())
      +    val referencePVal2 = 1 - ksTest.cdf(referenceStat2, n)
+    // Verify vs Apache Commons Math KS test
      +    assert(result2.statistic ~== referenceStat2 relTol 1e-4)
      +    assert(result2.pValue ~== referencePVal2 relTol 1e-4)
      +    // reject null hypothesis
      +    assert(result2.pValue < pThreshold)
      +
      +    // Testing the use of a user provided CDF function
      +    // Distribution is not serializable, so will have to create in the lambda
      +    val expCDF = (x: Double) => new ExponentialDistribution(0.2).cumulativeProbability(x)
      +
      +    // Comparing an exponential sample with mean X to an exponential distribution with mean Y
      +    // Where X != Y
      +    val result3 = Statistics.kolmogorovSmirnovTest(sampledExp, expCDF)
      +    val referenceStat3 = ksTest.kolmogorovSmirnovStatistic(new ExponentialDistribution(0.2),
      +      sampledExp.collect())
      +    val referencePVal3 = 1 - ksTest.cdf(referenceStat3, sampledExp.count().toInt)
      +    // verify vs apache math commons ks test
      +    assert(result3.statistic ~== referenceStat3 relTol 1e-4)
      +    assert(result3.pValue ~== referencePVal3 relTol 1e-4)
      +    // reject null hypothesis
      +    assert(result3.pValue < pThreshold)
      +  }
      +
      +  test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") {
      +    /*
      +      Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample
      +      > sessionInfo()
      +      R version 3.2.0 (2015-04-16)
      +      Platform: x86_64-apple-darwin13.4.0 (64-bit)
      +      > set.seed(20)
      +      > v <- rnorm(20)
      +      > v
      +       [1]  1.16268529 -0.58592447  1.78546500 -1.33259371 -0.44656677  0.56960612
      +       [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222
      +      [13] -0.62812676  1.32322085 -1.52135057 -0.43742787  0.97057758  0.02822264
      +      [19] -0.08578219  0.38921440
      +      > ks.test(v, pnorm, alternative = "two.sided")
      +
      +               One-sample Kolmogorov-Smirnov test
      +
      +      data:  v
      +      D = 0.18874, p-value = 0.4223
      +      alternative hypothesis: two-sided
      +    */
      +
      +    val rKSStat = 0.18874
      +    val rKSPVal = 0.4223
      +    val rData = sc.parallelize(
      +      Array(
      +        1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
      +        -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
      +        -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
      +        -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
      +        0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
      +      )
      +    )
      +    val rCompResult = Statistics.kolmogorovSmirnovTest(rData, "norm", 0, 1)
      +    assert(rCompResult.statistic ~== rKSStat relTol 1e-4)
      +    assert(rCompResult.pValue ~== rKSPVal relTol 1e-4)
      +  }
       }
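      For readers skimming the suite above, this is a minimal sketch of the two
      Statistics.kolmogorovSmirnovTest overloads it exercises (testing against a named
      distribution and against a user-supplied CDF). The local-mode setup and the sample data
      below are illustrative placeholders, not part of the patch.

      import org.apache.commons.math3.distribution.ExponentialDistribution
      import org.apache.spark.{SparkConf, SparkContext}
      import org.apache.spark.mllib.stat.Statistics

      object KSTestSketch {
        def main(args: Array[String]): Unit = {
          val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("ks-sketch"))
          val data = sc.parallelize(Seq(0.1, -0.4, 1.3, 0.2, -1.1, 0.8))

          // Overload 1: compare the sample against a named theoretical distribution
          // (here the standard normal, parameterized by mean 0 and stddev 1).
          val resultNorm = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1)
          println(s"statistic=${resultNorm.statistic}, pValue=${resultNorm.pValue}")

          // Overload 2: compare the sample against an arbitrary CDF. As noted in the test,
          // the distribution object is created inside the closure since it is not serializable.
          val expCdf = (x: Double) => new ExponentialDistribution(0.2).cumulativeProbability(x)
          val resultCdf = Statistics.kolmogorovSmirnovTest(data, expCdf)
          println(s"statistic=${resultCdf.statistic}, pValue=${resultCdf.pValue}")

          sc.stop()
        }
      }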
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
      index 07efde4f5e6d..b6d41db69be0 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
      @@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
           s0.merge(s1)
           assert(s0.mean(0) ~== 1.0 absTol 1e-14)
         }
      +
      +  test("merging summarizer with weighted samples") {
      +    val summarizer = (new MultivariateOnlineSummarizer)
      +      .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1)
      +      .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge(
      +        (new MultivariateOnlineSummarizer)
      +          .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15)
      +          .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05))
      +
      +    assert(summarizer.count === 4)
      +
      +    // The following values are hand-calculated using the formula from
      +    // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]],
      +    // which defines the reliability weights used for computing the unbiased estimate of the
      +    // variance of weighted instances.
      +    assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44))
      +      absTol 1E-10, "mean mismatch")
      +    assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857))
      +      absTol 1E-8, "variance mismatch")
      +    assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4))
      +      absTol 1E-10, "numNonzeros mismatch")
      +    assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch")
      +    assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")
      +    assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192)
      +      absTol 1E-8, "normL2 mismatch")
      +    assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
      +  }
       }
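      The expected mean and variance vectors in the test above are hand-calculated. The following
      standalone sketch (plain Scala, no Spark dependency) reproduces the calculation for the first
      dimension using the reliability-weight formula referenced in the comment, so the expected
      values can be checked independently.

      object WeightedStatsSketch {
        def main(args: Array[String]): Unit = {
          // First dimension of the four weighted samples from the test above.
          val values  = Array(-0.8, 0.0, -0.7, -0.5)
          val weights = Array(0.1, 0.2, 0.15, 0.05)

          val w1 = weights.sum                       // sum of weights
          val w2 = weights.map(w => w * w).sum       // sum of squared weights
          val mean = values.zip(weights).map { case (x, w) => x * w }.sum / w1

          // Unbiased weighted variance with reliability weights:
          //   Var = sum_i w_i * (x_i - mean)^2 / (W1 - W2 / W1)
          val variance = values.zip(weights)
            .map { case (x, w) => w * (x - mean) * (x - mean) }
            .sum / (w1 - w2 / w1)

          println(f"mean = $mean%.10f")              // -0.42, matching mean(0)
          println(f"variance = $variance%.10f")      // ~0.1765714286, matching variance(0)
        }
      }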
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
      index 8972c229b7ec..334bf3790fc7 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
      @@ -70,7 +70,7 @@ object EnsembleTestHelper {
             metricName: String = "mse") {
           val predictions = input.map(x => model.predict(x.features))
           val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
      -      prediction - label
      +      label - prediction
           }
           val metric = metricName match {
             case "mse" =>
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
      index 84dd3b342d4c..6fc9e8df621d 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
      @@ -17,7 +17,7 @@
       
       package org.apache.spark.mllib.tree
       
      -import org.apache.spark.SparkFunSuite
      +import org.apache.spark.{Logging, SparkFunSuite}
       import org.apache.spark.mllib.regression.LabeledPoint
       import org.apache.spark.mllib.tree.configuration.Algo._
       import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
      @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils
       /**
        * Test suite for [[GradientBoostedTrees]].
        */
      -class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext {
      +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
       
         test("Regression with continuous features: SquaredError") {
           GradientBoostedTreesSuite.testCombinations.foreach {
      @@ -50,7 +50,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
                 EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
               } catch {
                 case e: java.lang.AssertionError =>
      -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
      +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                     s" subsamplingRate=$subsamplingRate")
                   throw e
               }
      @@ -80,7 +80,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
                 EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae")
               } catch {
                 case e: java.lang.AssertionError =>
      -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
      +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                     s" subsamplingRate=$subsamplingRate")
                   throw e
               }
      @@ -111,7 +111,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
                 EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9)
               } catch {
                 case e: java.lang.AssertionError =>
      -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
      +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                     s" subsamplingRate=$subsamplingRate")
                   throw e
               }
      @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
       
           val algos = Array(Regression, Regression, Classification)
           val losses = Array(SquaredError, AbsoluteError, LogLoss)
      -    (algos zip losses) map {
      -      case (algo, loss) => {
      -        val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
      -          categoricalFeaturesInfo = Map.empty)
      -        val boostingStrategy =
      -          new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
      -        val gbtValidate = new GradientBoostedTrees(boostingStrategy)
      -          .runWithValidation(trainRdd, validateRdd)
      -        val numTrees = gbtValidate.numTrees
      -        assert(numTrees !== numIterations)
      -
      -        // Test that it performs better on the validation dataset.
      -        val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
      -        val (errorWithoutValidation, errorWithValidation) = {
      -          if (algo == Classification) {
      -            val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
      -            (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
      -          } else {
      -            (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
      -          }
      -        }
      -        assert(errorWithValidation <= errorWithoutValidation)
      -
      -        // Test that results from evaluateEachIteration comply with runWithValidation.
      -        // Note that convergenceTol is set to 0.0
      -        val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
      -        assert(evaluationArray.length === numIterations)
      -        assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
      -        var i = 1
      -        while (i < numTrees) {
      -          assert(evaluationArray(i) <= evaluationArray(i - 1))
      -          i += 1
      +    algos.zip(losses).foreach { case (algo, loss) =>
      +      val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
      +        categoricalFeaturesInfo = Map.empty)
      +      val boostingStrategy =
      +        new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
      +      val gbtValidate = new GradientBoostedTrees(boostingStrategy)
      +        .runWithValidation(trainRdd, validateRdd)
      +      val numTrees = gbtValidate.numTrees
      +      assert(numTrees !== numIterations)
      +
      +      // Test that it performs better on the validation dataset.
      +      val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
      +      val (errorWithoutValidation, errorWithValidation) = {
      +        if (algo == Classification) {
      +          val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
      +          (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
      +        } else {
      +          (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
               }
             }
      +      assert(errorWithValidation <= errorWithoutValidation)
      +
      +      // Test that results from evaluateEachIteration comply with runWithValidation.
      +      // Note that convergenceTol is set to 0.0
      +      val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
      +      assert(evaluationArray.length === numIterations)
      +      assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
      +      var i = 1
      +      while (i < numTrees) {
      +        assert(evaluationArray(i) <= evaluationArray(i - 1))
      +        i += 1
      +      }
           }
         }
       
      +  test("Checkpointing") {
      +    val tempDir = Utils.createTempDir()
      +    val path = tempDir.toURI.toString
      +    sc.setCheckpointDir(path)
      +
      +    val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
      +
      +    val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
      +      categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
      +    val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)
      +
      +    val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
      +
      +    sc.checkpointDir = None
      +    Utils.deleteRecursively(tempDir)
      +  }
      +
       }
       
       private object GradientBoostedTreesSuite {
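      The new "Checkpointing" test above enables checkpointing without commenting on intent, so here
      is a hedged sketch of how a caller might do the same outside a test. The checkpoint directory
      is a placeholder, and the behavioural note in the comment (periodic truncation of RDD lineage)
      reflects how checkpointInterval is generally used in MLlib rather than anything specific to
      this patch.

      import org.apache.spark.SparkContext
      import org.apache.spark.mllib.regression.LabeledPoint
      import org.apache.spark.mllib.tree.GradientBoostedTrees
      import org.apache.spark.mllib.tree.configuration.Algo.Regression
      import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
      import org.apache.spark.mllib.tree.impurity.Variance
      import org.apache.spark.mllib.tree.loss.SquaredError
      import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
      import org.apache.spark.rdd.RDD

      object GbtCheckpointSketch {
        def train(sc: SparkContext, data: RDD[LabeledPoint]): GradientBoostedTreesModel = {
          // Any reliable storage (e.g. HDFS) works; /tmp is a placeholder for local experiments.
          sc.setCheckpointDir("/tmp/gbt-checkpoints")

          // checkpointInterval = 2 asks the trainer to checkpoint its intermediate RDDs
          // every 2 iterations, keeping lineage short across many boosting rounds.
          val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
            categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
          val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)

          GradientBoostedTrees.train(data, boostingStrategy)
        }
      }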
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
      index 5e9101cdd380..525ab68c7921 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
      @@ -26,7 +26,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
       
         override def beforeAll() {
           val conf = new SparkConf()
      -      .setMaster("local-cluster[2, 1, 512]")
      +      .setMaster("local-cluster[2, 1, 1024]")
             .setAppName("test-cluster")
             .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
           sc = new SparkContext(conf)
      diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
      index 8dcb9ba9be10..16d7c3ab39b0 100644
      --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
      +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
      @@ -33,8 +33,15 @@ class NumericParserSuite extends SparkFunSuite {
           malformatted.foreach { s =>
             intercept[SparkException] {
               NumericParser.parse(s)
      -        println(s"Didn't detect malformatted string $s.")
      +        throw new RuntimeException(s"Didn't detect malformatted string $s.")
             }
           }
         }
      +
      +  test("parser with whitespaces") {
      +    val s = "(0.0, [1.0, 2.0])"
      +    val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]]
      +    assert(parsed(0).asInstanceOf[Double] === 0.0)
      +    assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0))
      +  }
       }
      diff --git a/network/common/pom.xml b/network/common/pom.xml
      index a85e0a66f4a3..1cc054a8936c 100644
      --- a/network/common/pom.xml
      +++ b/network/common/pom.xml
      @@ -22,7 +22,7 @@
         <parent>
           <groupId>org.apache.spark</groupId>
           <artifactId>spark-parent_2.10</artifactId>
      -    <version>1.5.0-SNAPSHOT</version>
      +    <version>1.6.0-SNAPSHOT</version>
           <relativePath>../../pom.xml</relativePath>
         </parent>
       
      @@ -48,6 +48,10 @@
             <artifactId>slf4j-api</artifactId>
             <scope>provided</scope>
           </dependency>
      +    <dependency>
      +      <groupId>com.google.code.findbugs</groupId>
      +      <artifactId>jsr305</artifactId>
      +    </dependency>
           
           <dependency>
             <groupId>org.slf4j</groupId>
      @@ -79,7 +95,7 @@
           </dependency>
           <dependency>
             <groupId>org.mockito</groupId>
      -      <artifactId>mockito-all</artifactId>
      +      <artifactId>mockito-core</artifactId>
             <scope>test</scope>
           </dependency>
         </dependencies>
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
      index e4faaf8854fc..3ddf5c3c3918 100644
      --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
      @@ -17,11 +17,12 @@
       
       package org.apache.spark.network.shuffle;
       
      +import java.io.File;
      +import java.io.IOException;
       import java.util.List;
       
       import com.google.common.annotations.VisibleForTesting;
       import com.google.common.collect.Lists;
      -import org.apache.spark.network.util.TransportConf;
       import org.slf4j.Logger;
       import org.slf4j.LoggerFactory;
       
      @@ -31,10 +32,10 @@
       import org.apache.spark.network.server.OneForOneStreamManager;
       import org.apache.spark.network.server.RpcHandler;
       import org.apache.spark.network.server.StreamManager;
      -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
      -import org.apache.spark.network.shuffle.protocol.OpenBlocks;
      -import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
      -import org.apache.spark.network.shuffle.protocol.StreamHandle;
      +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
      +import org.apache.spark.network.shuffle.protocol.*;
      +import org.apache.spark.network.util.TransportConf;
      +
       
       /**
        * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
      @@ -46,16 +47,18 @@
       public class ExternalShuffleBlockHandler extends RpcHandler {
         private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
       
      -  private final ExternalShuffleBlockResolver blockManager;
      +  @VisibleForTesting
      +  final ExternalShuffleBlockResolver blockManager;
         private final OneForOneStreamManager streamManager;
       
      -  public ExternalShuffleBlockHandler(TransportConf conf) {
      -    this(new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf));
      +  public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException {
      +    this(new OneForOneStreamManager(),
      +      new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
         }
       
         /** Enables mocking out the StreamManager and BlockManager. */
         @VisibleForTesting
      -  ExternalShuffleBlockHandler(
      +  public ExternalShuffleBlockHandler(
             OneForOneStreamManager streamManager,
             ExternalShuffleBlockResolver blockManager) {
           this.streamManager = streamManager;
      @@ -65,20 +68,28 @@ public ExternalShuffleBlockHandler(TransportConf conf) {
         @Override
         public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
           BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
      +    handleMessage(msgObj, client, callback);
      +  }
       
      +  protected void handleMessage(
      +      BlockTransferMessage msgObj,
      +      TransportClient client,
      +      RpcResponseCallback callback) {
           if (msgObj instanceof OpenBlocks) {
             OpenBlocks msg = (OpenBlocks) msgObj;
      -      List<ManagedBuffer> blocks = Lists.newArrayList();
      +      checkAuth(client, msg.appId);
       
      +      List<ManagedBuffer> blocks = Lists.newArrayList();
             for (String blockId : msg.blockIds) {
               blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
             }
      -      long streamId = streamManager.registerStream(blocks.iterator());
      +      long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
             logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
             callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
       
           } else if (msgObj instanceof RegisterExecutor) {
             RegisterExecutor msg = (RegisterExecutor) msgObj;
      +      checkAuth(client, msg.appId);
             blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
             callback.onSuccess(new byte[0]);
       
      @@ -99,4 +110,30 @@ public StreamManager getStreamManager() {
         public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
           blockManager.applicationRemoved(appId, cleanupLocalDirs);
         }
      +
      +  /**
      +   * Register an (application, executor) with the given shuffle info.
      +   *
      +   * The "re-" is meant to highlight the intended use of this method -- when this service is
      +   * restarted, this is used to restore the state of executors from before the restart.  Normal
      +   * registration will happen via a message handled in receive().
      +   *
      +   * @param appExecId the combined application id and executor id being restored
      +   * @param executorInfo the executor's shuffle info to restore
      +   */
      +  public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) {
      +    blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo);
      +  }
      +
      +  public void close() {
      +    blockManager.close();
      +  }
      +
      +  private void checkAuth(TransportClient client, String appId) {
      +    if (client.getClientId() != null && !client.getClientId().equals(appId)) {
      +      throw new SecurityException(String.format(
      +        "Client for %s not authorized for application %s.", client.getClientId(), appId));
      +    }
      +  }
      +
       }
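      A hedged sketch of the new handler lifecycle introduced above: passing a state file lets
      executor registrations survive a restart of the shuffle service (they are persisted in the
      leveldb-backed resolver shown next), while passing null keeps the previous in-memory-only
      behaviour. The file path is a placeholder, and wiring the handler into a TransportServer is
      elided.

      import java.io.File
      import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
      import org.apache.spark.network.util.{SystemPropertyConfigProvider, TransportConf}

      object ShuffleHandlerLifecycleSketch {
        def main(args: Array[String]): Unit = {
          val conf = new TransportConf(new SystemPropertyConfigProvider())

          // With a state file, registrations are written through to leveldb and can be
          // reloaded (and re-registered via reregisterExecutor) after a service restart.
          val handler = new ExternalShuffleBlockHandler(conf, new File("/tmp/registeredExecutors.ldb"))

          // ... create a TransportContext with this handler and start a TransportServer ...

          handler.close() // closes the underlying leveldb state store
        }
      }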
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
      index 022ed88a1648..79beec4429a9 100644
      --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
      @@ -17,19 +17,24 @@
       
       package org.apache.spark.network.shuffle;
       
      -import java.io.DataInputStream;
      -import java.io.File;
      -import java.io.FileInputStream;
      -import java.io.IOException;
      -import java.util.Iterator;
      -import java.util.Map;
      +import java.io.*;
      +import java.util.*;
       import java.util.concurrent.ConcurrentMap;
       import java.util.concurrent.Executor;
       import java.util.concurrent.Executors;
       
      +import com.fasterxml.jackson.annotation.JsonCreator;
      +import com.fasterxml.jackson.annotation.JsonProperty;
      +import com.fasterxml.jackson.databind.ObjectMapper;
       import com.google.common.annotations.VisibleForTesting;
      +import com.google.common.base.Charsets;
       import com.google.common.base.Objects;
       import com.google.common.collect.Maps;
      +import org.fusesource.leveldbjni.JniDBFactory;
      +import org.fusesource.leveldbjni.internal.NativeDB;
      +import org.iq80.leveldb.DB;
      +import org.iq80.leveldb.DBIterator;
      +import org.iq80.leveldb.Options;
       import org.slf4j.Logger;
       import org.slf4j.LoggerFactory;
       
      @@ -52,25 +57,87 @@
       public class ExternalShuffleBlockResolver {
         private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class);
       
      +  private static final ObjectMapper mapper = new ObjectMapper();
      +  /**
      +   * This is a common prefix for the key of each app registration we store in leveldb, so the
      +   * entries are easy to find, since leveldb supports seeking to a key prefix.
      +   */
      +  private static final String APP_KEY_PREFIX = "AppExecShuffleInfo";
      +  private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0);
      +
         // Map containing all registered executors' metadata.
      -  private final ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors;
      +  @VisibleForTesting
      +  final ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors;
       
         // Single-threaded Java executor used to perform expensive recursive directory deletion.
         private final Executor directoryCleaner;
       
         private final TransportConf conf;
       
      -  public ExternalShuffleBlockResolver(TransportConf conf) {
      -    this(conf, Executors.newSingleThreadExecutor(
      +  @VisibleForTesting
      +  final File registeredExecutorFile;
      +  @VisibleForTesting
      +  final DB db;
      +
      +  public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile)
      +      throws IOException {
      +    this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor(
               // Add `spark` prefix because it will run in NM in Yarn mode.
               NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner")));
         }
       
         // Allows tests to have more control over when directories are cleaned up.
         @VisibleForTesting
      -  ExternalShuffleBlockResolver(TransportConf conf, Executor directoryCleaner) {
      +  ExternalShuffleBlockResolver(
      +      TransportConf conf,
      +      File registeredExecutorFile,
      +      Executor directoryCleaner) throws IOException {
           this.conf = conf;
      -    this.executors = Maps.newConcurrentMap();
      +    this.registeredExecutorFile = registeredExecutorFile;
      +    if (registeredExecutorFile != null) {
      +      Options options = new Options();
      +      options.createIfMissing(false);
      +      options.logger(new LevelDBLogger());
      +      DB tmpDb;
      +      try {
      +        tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
      +      } catch (NativeDB.DBException e) {
      +        if (e.isNotFound() || e.getMessage().contains(" does not exist ")) {
      +          logger.info("Creating state database at " + registeredExecutorFile);
      +          options.createIfMissing(true);
      +          try {
      +            tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
      +          } catch (NativeDB.DBException dbExc) {
      +            throw new IOException("Unable to create state store", dbExc);
      +          }
      +        } else {
      +          // The leveldb file seems to be corrupt somehow. Let's just blow it away and create
      +          // a new one, so we can keep processing new apps.
      +          logger.error("error opening leveldb file {}.  Creating new file, will not be able to " +
      +            "recover state for existing applications", registeredExecutorFile, e);
      +          if (registeredExecutorFile.isDirectory()) {
      +            for (File f : registeredExecutorFile.listFiles()) {
      +              f.delete();
      +            }
      +          }
      +          registeredExecutorFile.delete();
      +          options.createIfMissing(true);
      +          try {
      +            tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
      +          } catch (NativeDB.DBException dbExc) {
      +            throw new IOException("Unable to create state store", dbExc);
      +          }
      +
      +        }
      +      }
      +      // if there is a version mismatch, we throw an exception, which means the service is unusable
      +      checkVersion(tmpDb);
      +      executors = reloadRegisteredExecutors(tmpDb);
      +      db = tmpDb;
      +    } else {
      +      db = null;
      +      executors = Maps.newConcurrentMap();
      +    }
           this.directoryCleaner = directoryCleaner;
         }
       
      @@ -81,6 +148,15 @@ public void registerExecutor(
             ExecutorShuffleInfo executorInfo) {
           AppExecId fullId = new AppExecId(appId, execId);
           logger.info("Registered executor {} with {}", fullId, executorInfo);
      +    try {
      +      if (db != null) {
      +        byte[] key = dbAppExecKey(fullId);
      +        byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8);
      +        db.put(key, value);
      +      }
      +    } catch (Exception e) {
      +      logger.error("Error saving registered executors", e);
      +    }
           executors.put(fullId, executorInfo);
         }
       
      @@ -136,6 +212,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
             // Only touch executors associated with the appId that was removed.
             if (appId.equals(fullId.appId)) {
               it.remove();
      +        if (db != null) {
      +          try {
      +            db.delete(dbAppExecKey(fullId));
      +          } catch (IOException e) {
      +            logger.error("Error deleting {} from executor state db", appId, e);
      +          }
      +        }
       
               if (cleanupLocalDirs) {
                 logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length);
      @@ -220,12 +303,23 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename)
           return new File(new File(localDir, String.format("%02x", subDirId)), filename);
         }
       
      +  void close() {
      +    if (db != null) {
      +      try {
      +        db.close();
      +      } catch (IOException e) {
      +        logger.error("Exception closing leveldb with registered executors", e);
      +      }
      +    }
      +  }
      +
         /** Simply encodes an executor's full ID, which is appId + execId. */
      -  private static class AppExecId {
      -    final String appId;
      -    final String execId;
      +  public static class AppExecId {
      +    public final String appId;
      +    public final String execId;
       
      -    private AppExecId(String appId, String execId) {
      +    @JsonCreator
      +    public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) {
             this.appId = appId;
             this.execId = execId;
           }
      @@ -252,4 +346,105 @@ public String toString() {
               .toString();
           }
         }
      +
      +  private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException {
      +    // we stick a common prefix on all the keys so we can find them in the DB
      +    String appExecJson = mapper.writeValueAsString(appExecId);
      +    String key = (APP_KEY_PREFIX + ";" + appExecJson);
      +    return key.getBytes(Charsets.UTF_8);
      +  }
      +
      +  private static AppExecId parseDbAppExecKey(String s) throws IOException {
      +    if (!s.startsWith(APP_KEY_PREFIX)) {
      +      throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX);
      +    }
      +    String json = s.substring(APP_KEY_PREFIX.length() + 1);
      +    AppExecId parsed = mapper.readValue(json, AppExecId.class);
      +    return parsed;
      +  }
      +
      +  @VisibleForTesting
      +  static ConcurrentMap<AppExecId, ExecutorShuffleInfo> reloadRegisteredExecutors(DB db)
      +      throws IOException {
      +    ConcurrentMap<AppExecId, ExecutorShuffleInfo> registeredExecutors = Maps.newConcurrentMap();
      +    if (db != null) {
      +      DBIterator itr = db.iterator();
      +      itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8));
      +      while (itr.hasNext()) {
      +        Map.Entry<byte[], byte[]> e = itr.next();
      +        String key = new String(e.getKey(), Charsets.UTF_8);
      +        if (!key.startsWith(APP_KEY_PREFIX)) {
      +          break;
      +        }
      +        AppExecId id = parseDbAppExecKey(key);
      +        ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class);
      +        registeredExecutors.put(id, shuffleInfo);
      +      }
      +    }
      +    return registeredExecutors;
      +  }
      +
      +  private static class LevelDBLogger implements org.iq80.leveldb.Logger {
      +    private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class);
      +
      +    @Override
      +    public void log(String message) {
      +      LOG.info(message);
      +    }
      +  }
      +
      +  /**
      +   * Simple major.minor versioning scheme.  Any incompatible changes should be across major
      +   * versions.  Minor version differences are allowed -- meaning we should be able to read
      +   * dbs that are either earlier *or* later on the minor version.
      +   */
      +  private static void checkVersion(DB db) throws IOException {
      +    byte[] bytes = db.get(StoreVersion.KEY);
      +    if (bytes == null) {
      +      storeVersion(db);
      +    } else {
      +      StoreVersion version = mapper.readValue(bytes, StoreVersion.class);
      +      if (version.major != CURRENT_VERSION.major) {
      +        throw new IOException("cannot read state DB with version " + version + ", incompatible " +
      +          "with current version " + CURRENT_VERSION);
      +      }
      +      storeVersion(db);
      +    }
      +  }
      +
      +  private static void storeVersion(DB db) throws IOException {
      +    db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION));
      +  }
      +
      +
      +  public static class StoreVersion {
      +
      +    final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8);
      +
      +    public final int major;
      +    public final int minor;
      +
      +    @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) {
      +      this.major = major;
      +      this.minor = minor;
      +    }
      +
      +    @Override
      +    public boolean equals(Object o) {
      +      if (this == o) return true;
      +      if (o == null || getClass() != o.getClass()) return false;
      +
      +      StoreVersion that = (StoreVersion) o;
      +
      +      return major == that.major && minor == that.minor;
      +    }
      +
      +    @Override
      +    public int hashCode() {
      +      int result = major;
      +      result = 31 * result + minor;
      +      return result;
      +    }
      +  }
      +
       }
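      The comment on APP_KEY_PREFIX above relies on leveldb's ordered keys: seeking to the prefix and
      scanning forward visits exactly the registration entries. The sketch below illustrates that
      layout with the leveldbjni API the resolver already uses; the database path, key, and JSON
      payload are illustrative stand-ins, not values the service itself writes.

      import java.io.File
      import com.google.common.base.Charsets
      import org.fusesource.leveldbjni.JniDBFactory
      import org.iq80.leveldb.Options

      object PrefixScanSketch {
        def main(args: Array[String]): Unit = {
          val options = new Options().createIfMissing(true)
          val db = JniDBFactory.factory.open(new File("/tmp/prefix-scan-sketch.ldb"), options)
          try {
            val prefix = "AppExecShuffleInfo"
            val key = s"""$prefix;{"appId":"app-1","execId":"0"}"""
            val value = """{"localDirs":["/tmp/d1"],"subDirsPerLocalDir":64,"shuffleManager":"sort"}"""
            db.put(key.getBytes(Charsets.UTF_8), value.getBytes(Charsets.UTF_8))

            // Seek to the prefix and scan forward; keys are sorted, so all registration
            // entries (and only those) appear before the first non-matching key.
            val itr = db.iterator()
            itr.seek(prefix.getBytes(Charsets.UTF_8))
            var done = false
            while (!done && itr.hasNext) {
              val entry = itr.next()
              val k = new String(entry.getKey, Charsets.UTF_8)
              if (k.startsWith(prefix)) {
                println(s"$k -> ${new String(entry.getValue, Charsets.UTF_8)}")
              } else {
                done = true
              }
            }
          } finally {
            db.close()
          }
        }
      }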
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
      index 612bce571a49..ea6d248d66be 100644
      --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
      @@ -50,8 +50,8 @@ public class ExternalShuffleClient extends ShuffleClient {
         private final boolean saslEncryptionEnabled;
         private final SecretKeyHolder secretKeyHolder;
       
      -  private TransportClientFactory clientFactory;
      -  private String appId;
      +  protected TransportClientFactory clientFactory;
      +  protected String appId;
       
         /**
          * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled,
      @@ -71,6 +71,10 @@ public ExternalShuffleClient(
           this.saslEncryptionEnabled = saslEncryptionEnabled;
         }
       
      +  protected void checkInit() {
      +    assert appId != null : "Called before init()";
      +  }
      +
         @Override
         public void init(String appId) {
           this.appId = appId;
      @@ -89,7 +93,7 @@ public void fetchBlocks(
             final String execId,
             String[] blockIds,
             BlockFetchingListener listener) {
      -    assert appId != null : "Called before init()";
      +    checkInit();
           logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
           try {
             RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
      @@ -132,7 +136,7 @@ public void registerWithShuffleServer(
             int port,
             String execId,
             ExecutorShuffleInfo executorInfo) throws IOException {
      -    assert appId != null : "Called before init()";
      +    checkInit();
           TransportClient client = clientFactory.createClient(host, port);
           byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
           client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
      new file mode 100644
      index 000000000000..7543b6be4f2a
      --- /dev/null
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
      @@ -0,0 +1,72 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.network.shuffle.mesos;
      +
      +import java.io.IOException;
      +
      +import org.slf4j.Logger;
      +import org.slf4j.LoggerFactory;
      +
      +import org.apache.spark.network.client.RpcResponseCallback;
      +import org.apache.spark.network.client.TransportClient;
      +import org.apache.spark.network.sasl.SecretKeyHolder;
      +import org.apache.spark.network.shuffle.ExternalShuffleClient;
      +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
      +import org.apache.spark.network.util.TransportConf;
      +
      +/**
      + * A client for talking to the external shuffle service in Mesos coarse-grained mode.
      + *
      + * This is used by the Spark driver to register with each external shuffle service on the cluster.
      + * The driver has to talk to the service so that shuffle files can be cleaned up reliably
      + * after the application exits; Mesos does not provide a good way to detect application
      + * termination for this purpose, so Spark has to handle the cleanup itself.
      + */
      +public class MesosExternalShuffleClient extends ExternalShuffleClient {
      +  private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class);
      +
      +  /**
      +   * Creates a Mesos external shuffle client that wraps the {@link ExternalShuffleClient}.
      +   * Please refer to docs on {@link ExternalShuffleClient} for more information.
      +   */
      +  public MesosExternalShuffleClient(
      +      TransportConf conf,
      +      SecretKeyHolder secretKeyHolder,
      +      boolean saslEnabled,
      +      boolean saslEncryptionEnabled) {
      +    super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
      +  }
      +
      +  public void registerDriverWithShuffleService(String host, int port) throws IOException {
      +    checkInit();
      +    byte[] registerDriver = new RegisterDriver(appId).toByteArray();
      +    TransportClient client = clientFactory.createClient(host, port);
      +    client.sendRpc(registerDriver, new RpcResponseCallback() {
      +      @Override
      +      public void onSuccess(byte[] response) {
      +        logger.info("Successfully registered app " + appId + " with external shuffle service.");
      +      }
      +
      +      @Override
      +      public void onFailure(Throwable e) {
      +        logger.warn("Unable to register app " + appId + " with external shuffle service. " +
      +          "Please manually remove shuffle data after driver exit. Error: " + e);
      +      }
      +    });
      +  }
      +}
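      A hedged sketch of how the driver side is expected to use this client. The config provider,
      app id, host, and port are placeholders, and SASL is left disabled for brevity (so no secret
      key holder is supplied).

      import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
      import org.apache.spark.network.util.{SystemPropertyConfigProvider, TransportConf}

      object MesosShuffleClientSketch {
        def main(args: Array[String]): Unit = {
          val conf = new TransportConf(new SystemPropertyConfigProvider())

          // With saslEnabled = false the secret key holder is never consulted, so null is passed.
          val client = new MesosExternalShuffleClient(conf, null, false, false)

          client.init("app-20150801-0001") // must run before any RPC (see checkInit above)
          client.registerDriverWithShuffleService("shuffle-host.example.com", 7337)

          client.close()
        }
      }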
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
      index 6c1210b33268..fcb52363e632 100644
      --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
      @@ -21,6 +21,7 @@
       import io.netty.buffer.Unpooled;
       
       import org.apache.spark.network.protocol.Encodable;
      +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
       
       /**
        * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
      @@ -37,7 +38,7 @@ public abstract class BlockTransferMessage implements Encodable {
       
         /** Preceding every serialized message is its type, which allows us to deserialize it. */
         public static enum Type {
      -    OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3);
      +    OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4);
       
           private final byte id;
       
      @@ -60,6 +61,7 @@ public static BlockTransferMessage fromByteArray(byte[] msg) {
               case 1: return UploadBlock.decode(buf);
               case 2: return RegisterExecutor.decode(buf);
               case 3: return StreamHandle.decode(buf);
      +        case 4: return RegisterDriver.decode(buf);
               default: throw new IllegalArgumentException("Unknown message type: " + type);
             }
           }
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
      index cadc8e8369c6..102d4efb8bf3 100644
      --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java
      @@ -19,6 +19,8 @@
       
       import java.util.Arrays;
       
      +import com.fasterxml.jackson.annotation.JsonCreator;
      +import com.fasterxml.jackson.annotation.JsonProperty;
       import com.google.common.base.Objects;
       import io.netty.buffer.ByteBuf;
       
      @@ -34,7 +36,11 @@ public class ExecutorShuffleInfo implements Encodable {
         /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */
         public final String shuffleManager;
       
      -  public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) {
      +  @JsonCreator
      +  public ExecutorShuffleInfo(
      +      @JsonProperty("localDirs") String[] localDirs,
      +      @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir,
      +      @JsonProperty("shuffleManager") String shuffleManager) {
           this.localDirs = localDirs;
           this.subDirsPerLocalDir = subDirsPerLocalDir;
           this.shuffleManager = shuffleManager;
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
      index cca8b17c4f12..167ef3310422 100644
      --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
      @@ -27,7 +27,7 @@
       
       /**
        * Initial registration message between an executor and its local shuffle server.
      - * Returns nothing (empty bye array).
      + * Returns nothing (empty byte array).
        */
       public class RegisterExecutor extends BlockTransferMessage {
         public final String appId;
      diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
      new file mode 100644
      index 000000000000..94a61d6caadc
      --- /dev/null
      +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
      @@ -0,0 +1,63 @@
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
      + */
      +
      +package org.apache.spark.network.shuffle.protocol.mesos;
      +
      +import com.google.common.base.Objects;
      +import io.netty.buffer.ByteBuf;
      +
      +import org.apache.spark.network.protocol.Encoders;
      +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
      +
      +// Needed by ScalaDoc. See SPARK-7726
      +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
      +
      +/**
      + * A message sent from the driver to register with the MesosExternalShuffleService.
      + */
      +public class RegisterDriver extends BlockTransferMessage {
      +  private final String appId;
      +
      +  public RegisterDriver(String appId) {
      +    this.appId = appId;
      +  }
      +
      +  public String getAppId() { return appId; }
      +
      +  @Override
      +  protected Type type() { return Type.REGISTER_DRIVER; }
      +
      +  @Override
      +  public int encodedLength() {
      +    return Encoders.Strings.encodedLength(appId);
      +  }
      +
      +  @Override
      +  public void encode(ByteBuf buf) {
      +    Encoders.Strings.encode(buf, appId);
      +  }
      +
      +  @Override
      +  public int hashCode() {
      +    return Objects.hashCode(appId);
      +  }
      +
      +  public static RegisterDriver decode(ByteBuf buf) {
      +    String appId = Encoders.Strings.decode(buf);
      +    return new RegisterDriver(appId);
      +  }
      +}
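      Since RegisterDriver reuses the BlockTransferMessage framing (a leading type byte, new id 4,
      followed by the encoded fields), a quick round-trip sketch shows how the message travels over
      the existing decoder; the app id is a placeholder.

      import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
      import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver

      object RegisterDriverRoundTripSketch {
        def main(args: Array[String]): Unit = {
          val msg = new RegisterDriver("app-20150801-0001")

          // toByteArray prepends the REGISTER_DRIVER type byte; fromByteArray reads it back
          // and dispatches to RegisterDriver.decode.
          val bytes = msg.toByteArray
          val decoded = BlockTransferMessage.Decoder.fromByteArray(bytes).asInstanceOf[RegisterDriver]

          assert(decoded.getAppId == "app-20150801-0001")
        }
      }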
      diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
      index 382f613ecbb1..5cb0e4d4a645 100644
      --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
      +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
      @@ -19,6 +19,7 @@
       
       import java.io.IOException;
       import java.util.Arrays;
      +import java.util.concurrent.atomic.AtomicReference;
       
       import com.google.common.collect.Lists;
       import org.junit.After;
      @@ -27,9 +28,12 @@
       import org.junit.Test;
       
       import static org.junit.Assert.*;
      +import static org.mockito.Mockito.*;
       
       import org.apache.spark.network.TestUtils;
       import org.apache.spark.network.TransportContext;
      +import org.apache.spark.network.buffer.ManagedBuffer;
      +import org.apache.spark.network.client.ChunkReceivedCallback;
       import org.apache.spark.network.client.RpcResponseCallback;
       import org.apache.spark.network.client.TransportClient;
       import org.apache.spark.network.client.TransportClientBootstrap;
      @@ -39,44 +43,39 @@
       import org.apache.spark.network.server.StreamManager;
       import org.apache.spark.network.server.TransportServer;
       import org.apache.spark.network.server.TransportServerBootstrap;
      +import org.apache.spark.network.shuffle.BlockFetchingListener;
       import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
      +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
      +import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
      +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
      +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
      +import org.apache.spark.network.shuffle.protocol.OpenBlocks;
      +import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
      +import org.apache.spark.network.shuffle.protocol.StreamHandle;
       import org.apache.spark.network.util.SystemPropertyConfigProvider;
       import org.apache.spark.network.util.TransportConf;
       
       public class SaslIntegrationSuite {
      -  static ExternalShuffleBlockHandler handler;
         static TransportServer server;
         static TransportConf conf;
         static TransportContext context;
      +  static SecretKeyHolder secretKeyHolder;
       
         TransportClientFactory clientFactory;
       
      -  /** Provides a secret key holder which always returns the given secret key. */
      -  static class TestSecretKeyHolder implements SecretKeyHolder {
      -
      -    private final String secretKey;
      -
      -    TestSecretKeyHolder(String secretKey) {
      -      this.secretKey = secretKey;
      -    }
      -
      -    @Override
      -    public String getSaslUser(String appId) {
      -      return "user";
      -    }
      -    @Override
      -    public String getSecretKey(String appId) {
      -      return secretKey;
      -    }
      -  }
      -
      -
         @BeforeClass
         public static void beforeAll() throws IOException {
      -    SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
           conf = new TransportConf(new SystemPropertyConfigProvider());
           context = new TransportContext(conf, new TestRpcHandler());
       
      +    secretKeyHolder = mock(SecretKeyHolder.class);
      +    when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
      +    when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
      +    when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
      +    when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
      +    when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
      +    when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password");
      +
           TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
           server = context.createServer(Arrays.asList(bootstrap));
         }
      @@ -99,7 +98,7 @@ public void afterEach() {
         public void testGoodClient() throws IOException {
           clientFactory = context.createClientFactory(
             Lists.newArrayList(
      -        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
      +        new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
       
           TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
           String msg = "Hello, World!";
      @@ -109,13 +108,17 @@ public void testGoodClient() throws IOException {
       
         @Test
         public void testBadClient() {
      +    SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class);
      +    when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app");
      +    when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password");
           clientFactory = context.createClientFactory(
             Lists.newArrayList(
      -        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
      +        new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
       
           try {
             // Bootstrap should fail on startup.
             clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
      +      fail("Connection should have failed.");
           } catch (Exception e) {
             assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
           }
      @@ -149,7 +152,7 @@ public void testNoSaslServer() {
           TransportContext context = new TransportContext(conf, handler);
           clientFactory = context.createClientFactory(
             Lists.newArrayList(
      -        new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
      +        new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
           TransportServer server = context.createServer();
           try {
             clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
      @@ -160,6 +163,110 @@ public void testNoSaslServer() {
           }
         }
       
      +  /**
      +   * This test is not actually testing SASL behavior, but testing that the shuffle service
      +   * performs correct authorization checks based on the SASL authentication data.
      +   */
      +  @Test
      +  public void testAppIsolation() throws Exception {
      +    // Start a new server with the correct RPC handler to serve block data.
      +    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
      +    ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
      +      new OneForOneStreamManager(), blockResolver);
      +    TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
      +    TransportContext blockServerContext = new TransportContext(conf, blockHandler);
      +    TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
      +
      +    TransportClient client1 = null;
      +    TransportClient client2 = null;
      +    TransportClientFactory clientFactory2 = null;
      +    try {
      +      // Create a client, and make a request to fetch blocks from a different app.
      +      clientFactory = blockServerContext.createClientFactory(
      +        Lists.newArrayList(
      +          new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
      +      client1 = clientFactory.createClient(TestUtils.getLocalHost(),
      +        blockServer.getPort());
      +
      +      final AtomicReference<Throwable> exception = new AtomicReference<>();
      +
      +      BlockFetchingListener listener = new BlockFetchingListener() {
      +        @Override
      +        public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
      +          notifyAll();
      +        }
      +
      +        @Override
      +        public synchronized void onBlockFetchFailure(String blockId, Throwable t) {
      +          exception.set(t);
      +          notifyAll();
      +        }
      +      };
      +
      +      String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" };
      +      OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0",
      +        blockIds, listener);
      +      synchronized (listener) {
      +        fetcher.start();
      +        listener.wait();
      +      }
      +      checkSecurityException(exception.get());
      +
      +      // Register an executor so that the next steps work.
      +      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
      +        new String[] { System.getProperty("java.io.tmpdir") }, 1,
      +        "org.apache.spark.shuffle.sort.SortShuffleManager");
      +      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
      +      client1.sendRpcSync(regmsg.toByteArray(), 10000);
      +
      +      // Make a successful request to fetch blocks, which creates a new stream. But do not actually
      +      // fetch any blocks, to keep the stream open.
      +      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
      +      byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000);
      +      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
      +      long streamId = stream.streamId;
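+      // Keeping this stream open lets the test verify below that a client from another app cannot read it.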
      +
      +      // Create a second client, authenticated with a different app ID, and try to read from
      +      // the stream created for the previous app.
      +      clientFactory2 = blockServerContext.createClientFactory(
+        Lists.<TransportClientBootstrap>newArrayList(
      +          new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
      +      client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
      +        blockServer.getPort());
      +
      +      ChunkReceivedCallback callback = new ChunkReceivedCallback() {
      +        @Override
      +        public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) {
      +          notifyAll();
      +        }
      +
      +        @Override
      +        public synchronized void onFailure(int chunkIndex, Throwable t) {
      +          exception.set(t);
      +          notifyAll();
      +        }
      +      };
      +
      +      exception.set(null);
      +      synchronized (callback) {
      +        client2.fetchChunk(streamId, 0, callback);
      +        callback.wait();
      +      }
      +      checkSecurityException(exception.get());
      +    } finally {
      +      if (client1 != null) {
      +        client1.close();
      +      }
      +      if (client2 != null) {
      +        client2.close();
      +      }
      +      if (clientFactory2 != null) {
      +        clientFactory2.close();
      +      }
      +      blockServer.close();
      +    }
      +  }
      +
         /** RPC handler which simply responds with the message it received. */
         public static class TestRpcHandler extends RpcHandler {
           @Override
      @@ -172,4 +279,10 @@ public StreamManager getStreamManager() {
             return new OneForOneStreamManager();
           }
         }
      +
      +  private void checkSecurityException(Throwable t) {
      +    assertNotNull("No exception was caught.", t);
      +    assertTrue("Expected SecurityException.",
      +      t.getMessage().contains(SecurityException.class.getName()));
      +  }
       }
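A note on the callback-synchronization idiom used by the new testAppIsolation test above: the
failure is stashed in an AtomicReference, and the test waits on the callback object's monitor
until notifyAll() is called. The sketch below is illustrative only -- it is not part of the
patch, the class and method names are made up, and it assumes the asynchronous operation invokes
the callback from another thread.

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

// Illustrative sketch of the wait/notify pattern used in testAppIsolation.
public final class AsyncFailureCapture {

  /** Stand-in for an asynchronous callback such as BlockFetchingListener. */
  public interface Callback {
    void onSuccess();
    void onFailure(Throwable t);
  }

  /**
   * Starts an asynchronous operation and blocks until its callback fires, returning the
   * captured failure (or null on success). The operation is started while holding the lock,
   * so the callback's notifyAll() cannot be missed.
   */
  public static Throwable runAndCapture(Consumer<Callback> operation) throws InterruptedException {
    final AtomicReference<Throwable> failure = new AtomicReference<>();
    final Object lock = new Object();
    Callback callback = new Callback() {
      @Override public void onSuccess() {
        synchronized (lock) { lock.notifyAll(); }
      }
      @Override public void onFailure(Throwable t) {
        failure.set(t);
        synchronized (lock) { lock.notifyAll(); }
      }
    };
    synchronized (lock) {
      operation.accept(callback);  // kick off the async work on another thread
      lock.wait();                 // the monitor is released while waiting
    }
    return failure.get();
  }
}
```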
      diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
      index 73374cdc77a2..e61390cf5706 100644
      --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
      +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
      @@ -90,9 +90,11 @@ public void testOpenShuffleBlocks() {
             (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
           assertEquals(2, handle.numChunks);
       
      -    ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class);
      -    verify(streamManager, times(1)).registerStream(stream.capture());
      -    Iterator buffers = (Iterator) stream.getValue();
      +    @SuppressWarnings("unchecked")
+    ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
+        (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
      +    verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
+    Iterator<ManagedBuffer> buffers = stream.getValue();
           assertEquals(block0Marker, buffers.next());
           assertEquals(block1Marker, buffers.next());
           assertFalse(buffers.hasNext());
      diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
      index d02f4f0fdb68..3c6cb367dea4 100644
      --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
      +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
      @@ -21,9 +21,12 @@
       import java.io.InputStream;
       import java.io.InputStreamReader;
       
      +import com.fasterxml.jackson.databind.ObjectMapper;
       import com.google.common.io.CharStreams;
      +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
       import org.apache.spark.network.util.SystemPropertyConfigProvider;
       import org.apache.spark.network.util.TransportConf;
      +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
       import org.junit.AfterClass;
       import org.junit.BeforeClass;
       import org.junit.Test;
      @@ -59,8 +62,8 @@ public static void afterAll() {
         }
       
         @Test
      -  public void testBadRequests() {
      -    ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf);
      +  public void testBadRequests() throws IOException {
      +    ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
           // Unregistered executor
           try {
             resolver.getBlockData("app0", "exec1", "shuffle_1_1_0");
      @@ -91,7 +94,7 @@ public void testBadRequests() {
       
         @Test
         public void testSortShuffleBlocks() throws IOException {
      -    ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf);
      +    ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
           resolver.registerExecutor("app0", "exec0",
             dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager"));
       
      @@ -110,7 +113,7 @@ public void testSortShuffleBlocks() throws IOException {
       
         @Test
         public void testHashShuffleBlocks() throws IOException {
      -    ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf);
      +    ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
           resolver.registerExecutor("app0", "exec0",
             dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager"));
       
      @@ -126,4 +129,28 @@ public void testHashShuffleBlocks() throws IOException {
           block1Stream.close();
           assertEquals(hashBlock1, block1);
         }
      +
      +  @Test
      +  public void jsonSerializationOfExecutorRegistration() throws IOException {
      +    ObjectMapper mapper = new ObjectMapper();
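+    // Round-trip both recovery-file record types through Jackson and check equals() still holds.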
      +    AppExecId appId = new AppExecId("foo", "bar");
      +    String appIdJson = mapper.writeValueAsString(appId);
      +    AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class);
      +    assertEquals(parsedAppId, appId);
      +
      +    ExecutorShuffleInfo shuffleInfo =
      +      new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash");
      +    String shuffleJson = mapper.writeValueAsString(shuffleInfo);
      +    ExecutorShuffleInfo parsedShuffleInfo =
      +      mapper.readValue(shuffleJson, ExecutorShuffleInfo.class);
      +    assertEquals(parsedShuffleInfo, shuffleInfo);
      +
+    // Intentionally keep these hard-coded strings in here, to check backwards-compatibility.
+    // It's not legacy yet, but we keep it here in case anybody changes the format.
      +    String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}";
      +    assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class));
      +    String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " +
      +      "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}";
      +    assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class));
      +  }
       }
      diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
      index d9d9c1bf2f17..2f4f1d0df478 100644
      --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
      +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java
      @@ -42,7 +42,7 @@ public void noCleanupAndCleanup() throws IOException {
           TestShuffleDataContext dataContext = createSomeData();
       
           ExternalShuffleBlockResolver resolver =
      -      new ExternalShuffleBlockResolver(conf, sameThreadExecutor);
      +      new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
           resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr"));
           resolver.applicationRemoved("app", false /* cleanup */);
       
      @@ -65,7 +65,8 @@ public void cleanupUsesExecutor() throws IOException {
             @Override public void execute(Runnable runnable) { cleanupCalled.set(true); }
           };
       
      -    ExternalShuffleBlockResolver manager = new ExternalShuffleBlockResolver(conf, noThreadExecutor);
      +    ExternalShuffleBlockResolver manager =
      +      new ExternalShuffleBlockResolver(conf, null, noThreadExecutor);
       
           manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr"));
           manager.applicationRemoved("app", true);
      @@ -83,7 +84,7 @@ public void cleanupMultipleExecutors() throws IOException {
           TestShuffleDataContext dataContext1 = createSomeData();
       
           ExternalShuffleBlockResolver resolver =
      -      new ExternalShuffleBlockResolver(conf, sameThreadExecutor);
      +      new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
       
           resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr"));
           resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr"));
      @@ -99,7 +100,7 @@ public void cleanupOnlyRemovedApp() throws IOException {
           TestShuffleDataContext dataContext1 = createSomeData();
       
           ExternalShuffleBlockResolver resolver =
      -      new ExternalShuffleBlockResolver(conf, sameThreadExecutor);
      +      new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor);
       
           resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr"));
           resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr"));
      diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
      index 39aa49911d9c..a3f9a38b1aeb 100644
      --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
      +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
      @@ -92,7 +92,7 @@ public static void beforeAll() throws IOException {
           dataContext1.insertHashShuffleData(1, 0, exec1Blocks);
       
           conf = new TransportConf(new SystemPropertyConfigProvider());
      -    handler = new ExternalShuffleBlockHandler(conf);
      +    handler = new ExternalShuffleBlockHandler(conf, null);
           TransportContext transportContext = new TransportContext(conf, handler);
           server = transportContext.createServer();
         }
      diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
      index d4ec1956c1e2..aa99efda9494 100644
      --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
      +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
      @@ -43,8 +43,9 @@ public class ExternalShuffleSecuritySuite {
         TransportServer server;
       
         @Before
      -  public void beforeEach() {
      -    TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf));
      +  public void beforeEach() throws IOException {
      +    TransportContext context =
      +      new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null));
           TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
               new TestSecretKeyHolder("my-app-id", "secret"));
           this.server = context.createServer(Arrays.asList(bootstrap));
      diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
      index 1ad0d72ae5ec..06e46f924109 100644
      --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
      +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
      @@ -20,7 +20,9 @@
       
       import java.io.IOException;
       import java.nio.ByteBuffer;
      +import java.util.Arrays;
       import java.util.LinkedHashSet;
      +import java.util.List;
       import java.util.Map;
       
       import com.google.common.collect.ImmutableMap;
      @@ -67,13 +69,13 @@ public void afterEach() {
         public void testNoFailures() throws IOException {
           BlockFetchingListener listener = mock(BlockFetchingListener.class);
       
-    Map<String, Object>[] interactions = new Map[] {
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
             // Immediately return both blocks successfully.
       ImmutableMap.<String, Object>builder()
               .put("b0", block0)
               .put("b1", block1)
      -        .build(),
      -    };
      +        .build()
      +      );
       
           performInteractions(interactions, listener);
       
      @@ -86,13 +88,13 @@ public void testNoFailures() throws IOException {
         public void testUnrecoverableFailure() throws IOException {
           BlockFetchingListener listener = mock(BlockFetchingListener.class);
       
-    Map<String, Object>[] interactions = new Map[] {
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
             // b0 throws a non-IOException error, so it will be failed without retry.
       ImmutableMap.<String, Object>builder()
               .put("b0", new RuntimeException("Ouch!"))
               .put("b1", block1)
      -        .build(),
      -    };
      +        .build()
      +    );
       
           performInteractions(interactions, listener);
       
      @@ -105,7 +107,7 @@ public void testUnrecoverableFailure() throws IOException {
         public void testSingleIOExceptionOnFirst() throws IOException {
           BlockFetchingListener listener = mock(BlockFetchingListener.class);
       
-    Map<String, Object>[] interactions = new Map[] {
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
             // IOException will cause a retry. Since b0 fails, we will retry both.
       ImmutableMap.<String, Object>builder()
               .put("b0", new IOException("Connection failed or something"))
      @@ -114,8 +116,8 @@ public void testSingleIOExceptionOnFirst() throws IOException {
       ImmutableMap.<String, Object>builder()
               .put("b0", block0)
               .put("b1", block1)
      -        .build(),
      -    };
      +        .build()
      +    );
       
           performInteractions(interactions, listener);
       
      @@ -128,7 +130,7 @@ public void testSingleIOExceptionOnFirst() throws IOException {
         public void testSingleIOExceptionOnSecond() throws IOException {
           BlockFetchingListener listener = mock(BlockFetchingListener.class);
       
-    Map<String, Object>[] interactions = new Map[] {
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
             // IOException will cause a retry. Since b1 fails, we will not retry b0.
       ImmutableMap.<String, Object>builder()
               .put("b0", block0)
      @@ -136,8 +138,8 @@ public void testSingleIOExceptionOnSecond() throws IOException {
               .build(),
       ImmutableMap.<String, Object>builder()
               .put("b1", block1)
      -        .build(),
      -    };
      +        .build()
      +    );
       
           performInteractions(interactions, listener);
       
      @@ -150,7 +152,7 @@ public void testSingleIOExceptionOnSecond() throws IOException {
         public void testTwoIOExceptions() throws IOException {
           BlockFetchingListener listener = mock(BlockFetchingListener.class);
       
-    Map<String, Object>[] interactions = new Map[] {
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
             // b0's IOException will trigger retry, b1's will be ignored.
       ImmutableMap.<String, Object>builder()
               .put("b0", new IOException())
      @@ -164,8 +166,8 @@ public void testTwoIOExceptions() throws IOException {
             // b1 returns successfully within 2 retries.
       ImmutableMap.<String, Object>builder()
               .put("b1", block1)
      -        .build(),
      -    };
      +        .build()
      +    );
       
           performInteractions(interactions, listener);
       
      @@ -178,7 +180,7 @@ public void testTwoIOExceptions() throws IOException {
         public void testThreeIOExceptions() throws IOException {
           BlockFetchingListener listener = mock(BlockFetchingListener.class);
       
-    Map<String, Object>[] interactions = new Map[] {
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
             // b0's IOException will trigger retry, b1's will be ignored.
       ImmutableMap.<String, Object>builder()
               .put("b0", new IOException())
      @@ -196,8 +198,8 @@ public void testThreeIOExceptions() throws IOException {
             // This is not reached -- b1 has failed.
       ImmutableMap.<String, Object>builder()
               .put("b1", block1)
      -        .build(),
      -    };
      +        .build()
      +    );
       
           performInteractions(interactions, listener);
       
      @@ -210,7 +212,7 @@ public void testThreeIOExceptions() throws IOException {
         public void testRetryAndUnrecoverable() throws IOException {
           BlockFetchingListener listener = mock(BlockFetchingListener.class);
       
-    Map<String, Object>[] interactions = new Map[] {
+    List<? extends Map<String, Object>> interactions = Arrays.asList(
             // b0's IOException will trigger retry, subsequent messages will be ignored.
       ImmutableMap.<String, Object>builder()
               .put("b0", new IOException())
      @@ -226,8 +228,8 @@ public void testRetryAndUnrecoverable() throws IOException {
             // b2 succeeds in its last retry.
       ImmutableMap.<String, Object>builder()
               .put("b2", block2)
      -        .build(),
      -    };
      +        .build()
      +    );
       
           performInteractions(interactions, listener);
       
      @@ -248,7 +250,8 @@ public void testRetryAndUnrecoverable() throws IOException {
          * subset of the original blocks in a second interaction.
          */
         @SuppressWarnings("unchecked")
-  private void performInteractions(final Map<String, Object>[] interactions, BlockFetchingListener listener)
+  private static void performInteractions(List<? extends Map<String, Object>> interactions,
      +                                          BlockFetchingListener listener)
           throws IOException {
       
           TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
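The recurring change in RetryingBlockFetcherSuite above replaces a Map<String, Object>[] with a
List<? extends Map<String, Object>>, which removes the unchecked generic-array creation the old
code needed. The sketch below is illustrative only (the class name is made up and is not part of
the patch); it shows the difference under that assumption.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import com.google.common.collect.ImmutableMap;

// Illustrative only: why a List of maps is preferable to a generic array of maps.
class InteractionsExample {

  @SuppressWarnings("unchecked")
  static Map<String, Object>[] asArray() {
    // Generic array creation is not allowed, so the old style fell back to a raw
    // new Map[] plus an unchecked conversion.
    return new Map[] { ImmutableMap.<String, Object>of("b0", new Object()) };
  }

  static List<? extends Map<String, Object>> asList() {
    // The list form is fully type-checked and needs no @SuppressWarnings.
    return Arrays.asList(ImmutableMap.<String, Object>of("b0", new Object()));
  }
}
```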
      diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
      index a99f7c4392d3..e745180eace7 100644
      --- a/network/yarn/pom.xml
      +++ b/network/yarn/pom.xml
      @@ -22,7 +22,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.5.0-SNAPSHOT</version>
+    <version>1.6.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
       
      diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
      index 463f99ef3352..11ea7f3fd3cf 100644
      --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
      +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
      @@ -17,25 +17,21 @@
       
       package org.apache.spark.network.yarn;
       
      +import java.io.File;
       import java.nio.ByteBuffer;
       import java.util.List;
       
      +import com.google.common.annotations.VisibleForTesting;
       import com.google.common.collect.Lists;
       import org.apache.hadoop.conf.Configuration;
      -import org.apache.hadoop.yarn.api.records.ApplicationId;
       import org.apache.hadoop.yarn.api.records.ContainerId;
      -import org.apache.hadoop.yarn.server.api.AuxiliaryService;
      -import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
      -import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext;
      -import org.apache.hadoop.yarn.server.api.ContainerInitializationContext;
      -import org.apache.hadoop.yarn.server.api.ContainerTerminationContext;
      +import org.apache.hadoop.yarn.server.api.*;
       import org.slf4j.Logger;
       import org.slf4j.LoggerFactory;
       
       import org.apache.spark.network.TransportContext;
       import org.apache.spark.network.sasl.SaslServerBootstrap;
       import org.apache.spark.network.sasl.ShuffleSecretManager;
      -import org.apache.spark.network.server.RpcHandler;
       import org.apache.spark.network.server.TransportServer;
       import org.apache.spark.network.server.TransportServerBootstrap;
       import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
      @@ -79,11 +75,26 @@ public class YarnShuffleService extends AuxiliaryService {
         private TransportServer shuffleServer = null;
       
         // Handles registering executors and opening shuffle blocks
      -  private ExternalShuffleBlockHandler blockHandler;
      +  @VisibleForTesting
      +  ExternalShuffleBlockHandler blockHandler;
      +
      +  // Where to store & reload executor info for recovering state after an NM restart
      +  @VisibleForTesting
      +  File registeredExecutorFile;
      +
      +  // just for testing when you want to find an open port
      +  @VisibleForTesting
      +  static int boundPort = -1;
      +
+  // Exposed only for integration tests that want to look at the registered executor file through
+  // the running service -- in general, keeping this in a static is not sensible.
      +  @VisibleForTesting
      +  static YarnShuffleService instance;
       
         public YarnShuffleService() {
           super("spark_shuffle");
           logger.info("Initializing YARN shuffle service for Spark");
      +    instance = this;
         }
       
         /**
      @@ -100,11 +111,24 @@ private boolean isAuthenticationEnabled() {
          */
         @Override
         protected void serviceInit(Configuration conf) {
      +
+    // If this NM was killed while there were running Spark applications, we need to restore
+    // lost state for the existing executors. We look for an existing file in the NM's local dirs;
+    // if we don't find one, we choose a file to use to save the state next time. Even if
+    // an application was stopped while the NM was down, we expect YARN to call stopApplication()
+    // when it comes back up.
      +    registeredExecutorFile =
      +      findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs"));
      +
           TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf));
           // If authentication is enabled, set up the shuffle server to use a
           // special RPC handler that filters out unauthenticated fetch requests
           boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
      -    blockHandler = new ExternalShuffleBlockHandler(transportConf);
      +    try {
      +      blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile);
      +    } catch (Exception e) {
      +      logger.error("Failed to initialize external shuffle service", e);
      +    }
       
     List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
           if (authEnabled) {
      @@ -116,9 +140,13 @@ protected void serviceInit(Configuration conf) {
             SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
           TransportContext transportContext = new TransportContext(transportConf, blockHandler);
           shuffleServer = transportContext.createServer(port, bootstraps);
+    // The port should normally be fixed, but for tests it's useful to find an open port.
      +    port = shuffleServer.getPort();
      +    boundPort = port;
           String authEnabledString = authEnabled ? "enabled" : "not enabled";
           logger.info("Started YARN shuffle service for Spark on port {}. " +
      -      "Authentication is {}.", port, authEnabledString);
      +      "Authentication is {}.  Registered executor file is {}", port, authEnabledString,
      +      registeredExecutorFile);
         }
       
         @Override
      @@ -161,6 +189,16 @@ public void stopContainer(ContainerTerminationContext context) {
           logger.info("Stopping container {}", containerId);
         }
       
      +  private File findRegisteredExecutorFile(String[] localDirs) {
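+    // Prefer a local dir that already holds recovery state from a previous NM run;
+    // otherwise fall back to creating the file in the first local dir.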
      +    for (String dir: localDirs) {
      +      File f = new File(dir, "registeredExecutors.ldb");
      +      if (f.exists()) {
      +        return f;
      +      }
      +    }
      +    return new File(localDirs[0], "registeredExecutors.ldb");
      +  }
      +
         /**
          * Close the shuffle server to clean up any associated state.
          */
      @@ -170,6 +208,9 @@ protected void serviceStop() {
             if (shuffleServer != null) {
               shuffleServer.close();
             }
      +      if (blockHandler != null) {
      +        blockHandler.close();
      +      }
           } catch (Exception e) {
             logger.error("Exception when stopping service", e);
           }
      @@ -180,5 +221,4 @@ protected void serviceStop() {
         public ByteBuffer getMetaData() {
           return ByteBuffer.allocate(0);
         }
      -
       }
      diff --git a/pom.xml b/pom.xml
      index 6d4f717d4931..653599464114 100644
      --- a/pom.xml
      +++ b/pom.xml
      @@ -26,7 +26,7 @@
   </parent>
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent_2.10</artifactId>
-  <version>1.5.0-SNAPSHOT</version>
+  <version>1.6.0-SNAPSHOT</version>
   <packaging>pom</packaging>
   <name>Spark Project Parent POM</name>
   <url>http://spark.apache.org/</url>
      @@ -59,7 +59,7 @@
         
       
         
      -    3.0.4
      +    ${maven.version}
         
       
         
      @@ -87,7 +87,7 @@
       
         
           core
      -    bagel
      +    bagel 
           graphx
           mllib
           tools
      @@ -102,7 +102,9 @@
           external/twitter
           external/flume
           external/flume-sink
      +    external/flume-assembly
           external/mqtt
      +    external/mqtt-assembly
           external/zeromq
           examples
           repl
      @@ -117,6 +119,7 @@
           com.typesafe.akka
           2.3.11
           1.7
      +    3.3.3
           spark
           0.21.1
           shaded-protobuf
      @@ -127,40 +130,55 @@
           ${hadoop.version}
           0.98.7-hadoop2
           hbase
      -    1.4.0
      +    1.6.0
           3.4.5
           2.4.0
           org.spark-project.hive
           
      -    0.13.1a
      +    1.2.1.spark
           
      -    0.13.1
      +    1.2.1
           10.10.1.1
           1.7.0
      +    1.6.0
           1.2.4
           8.1.14.v20131031
           3.0.0.v201112011016
           0.5.0
           2.4.0
           2.0.8
      -    3.1.0
      +    3.1.2
           1.7.7
           hadoop2
           0.7.1
           1.9.16
           1.2.1
      +    
           4.3.2
      +    
      +    3.1
           3.4.1
      -    ${project.build.directory}/spark-test-classpath.txt
           2.10.4
           2.10
           ${scala.version}
           org.scala-lang
      -    3.6.3
           1.9.13
           2.4.4
           1.1.1.7
           1.1.2
      +    1.2.0-incubating
      +    1.10
      +    
      +    2.6
      +    
      +    3.3.2
      +    3.2.10
      +    2.7.8
      +    1.9
      +    2.5
      +    3.5.2
      +    1.3.9
      +    0.9.2
       
           ${java.home}
       
      @@ -177,6 +195,7 @@
           compile
           compile
           compile
      +    test
       
           
           
             spring-releases
             Spring Release Repository
             https://repo.spring.io/libs-release
             
      -        true
      +        false
             
             
               false
             
           
      -    
      +    
           
      -      spark-1.4-staging
      -      Spark 1.4 RC4 Staging Repository
      -      https://repository.apache.org/content/repositories/orgapachespark-1112
      +      twttr-repo
      +      Twttr Repository
      +      http://maven.twttr.com
             
               true
             
      @@ -303,17 +330,6 @@
             unused
             1.0.0
           
      -    
      -    
      -      org.codehaus.groovy
      -      groovy-all
      -      2.3.7
      -      provided
      -    
           
             
      @@ -587,7 +618,7 @@
       <dependency>
         <groupId>io.netty</groupId>
         <artifactId>netty-all</artifactId>
-        <version>4.0.28.Final</version>
+        <version>4.0.29.Final</version>
       </dependency>
       <dependency>
         <groupId>org.apache.derby</groupId>
      @@ -624,11 +655,16 @@
               jackson-databind
               ${fasterxml.jackson.version}
             
+      <dependency>
+        <groupId>com.fasterxml.jackson.core</groupId>
+        <artifactId>jackson-annotations</artifactId>
+        <version>${fasterxml.jackson.version}</version>
+      </dependency>
             
             
               com.fasterxml.jackson.module
      -        jackson-module-scala_2.10
      +        jackson-module-scala_${scala.binary.version}
               ${fasterxml.jackson.version}
               
                 
      @@ -640,15 +676,26 @@
             
               com.sun.jersey
               jersey-server
      -        1.9
      +        ${jersey.version}
               ${hadoop.deps.scope}
             
             
               com.sun.jersey
               jersey-core
      -        1.9
      +        ${jersey.version}
               ${hadoop.deps.scope}
             
      +      
      +        com.sun.jersey
      +        jersey-json
      +        ${jersey.version}
      +        
      +          
      +            stax
      +            stax-api
      +          
      +        
      +      
             
               org.scala-lang
               scala-compiler
      @@ -682,7 +729,7 @@
       </dependency>
       <dependency>
         <groupId>org.mockito</groupId>
-        <artifactId>mockito-all</artifactId>
+        <artifactId>mockito-core</artifactId>
         <version>1.9.5</version>
         <scope>test</scope>
       </dependency>
      @@ -738,6 +785,12 @@
               curator-framework
               ${curator.version}
             
+      <dependency>
+        <groupId>org.apache.curator</groupId>
+        <artifactId>curator-test</artifactId>
+        <version>${curator.version}</version>
+        <scope>test</scope>
+      </dependency>
             
               org.apache.hadoop
               hadoop-client
      @@ -748,6 +801,10 @@
                   asm
                   asm
                 
+          <exclusion>
+            <groupId>org.codehaus.jackson</groupId>
+            <artifactId>jackson-mapper-asl</artifactId>
+          </exclusion>
                 
                   org.ow2.asm
                   asm
      @@ -760,6 +817,10 @@
                   commons-logging
                   commons-logging
                 
+          <exclusion>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-all</artifactId>
+          </exclusion>
                 
                   org.mortbay.jetty
                   servlet-api-2.5
      @@ -1024,58 +1085,503 @@
               hive-beeline
               ${hive.version}
               ${hive.deps.scope}
      +        
      +          
      +            ${hive.group}
      +            hive-common
      +          
      +          
      +            ${hive.group}
      +            hive-exec
      +          
      +          
      +            ${hive.group}
      +            hive-jdbc
      +          
      +          
      +            ${hive.group}
      +            hive-metastore
      +          
      +          
      +            ${hive.group}
      +            hive-service
      +          
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +          
      +            log4j
      +            log4j
      +          
      +          
      +            commons-logging
      +            commons-logging
      +          
      +        
             
             
               ${hive.group}
               hive-cli
               ${hive.version}
               ${hive.deps.scope}
      +        
      +          
      +            ${hive.group}
      +            hive-common
      +          
      +          
      +            ${hive.group}
      +            hive-exec
      +          
      +          
      +            ${hive.group}
      +            hive-jdbc
      +          
      +          
      +            ${hive.group}
      +            hive-metastore
      +          
      +          
      +            ${hive.group}
      +            hive-serde
      +          
      +          
      +            ${hive.group}
      +            hive-service
      +          
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +          
      +            log4j
      +            log4j
      +          
      +          
      +            commons-logging
      +            commons-logging
      +          
      +        
             
             
               ${hive.group}
      -        hive-exec
      +        hive-common
               ${hive.version}
               ${hive.deps.scope}
               
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            org.apache.ant
      +            ant
      +          
      +          
      +            org.apache.zookeeper
      +            zookeeper
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +          
      +            log4j
      +            log4j
      +          
                 
                   commons-logging
                   commons-logging
                 
      +        
      +      
      +
      +      
      +        ${hive.group}
      +        hive-exec
      +
      +        ${hive.version}
      +        ${hive.deps.scope}
      +        
      +
      +          
      +          
      +            ${hive.group}
      +            hive-metastore
      +          
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            ${hive.group}
      +            hive-ant
      +          
      +          
      +          
      +            ${hive.group}
      +            spark-client
      +          
      +
      +          
      +          
      +            ant
      +            ant
      +          
      +          
      +            org.apache.ant
      +            ant
      +          
                 
                   com.esotericsoftware.kryo
                   kryo
                 
      +          
      +            commons-codec
      +            commons-codec
      +          
      +          
      +            commons-httpclient
      +            commons-httpclient
      +          
                 
                   org.apache.avro
                   avro-mapred
                 
      +          
      +          
      +            org.apache.calcite
      +            calcite-core
      +          
      +          
      +            org.apache.curator
      +            apache-curator
      +          
      +          
      +            org.apache.curator
      +            curator-client
      +          
      +          
      +            org.apache.curator
      +            curator-framework
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +          
      +            org.apache.thrift
      +            libfb303
      +          
      +          
      +            org.apache.zookeeper
      +            zookeeper
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +          
      +            log4j
      +            log4j
      +          
      +          
      +            commons-logging
      +            commons-logging
      +          
               
             
             
               ${hive.group}
               hive-jdbc
               ${hive.version}
      -        ${hive.deps.scope}
      +        
      +          
      +            ${hive.group}
      +            hive-common
      +          
      +          
      +            ${hive.group}
      +            hive-common
      +          
      +          
      +            ${hive.group}
      +            hive-metastore
      +          
      +          
      +            ${hive.group}
      +            hive-serde
      +          
      +          
      +            ${hive.group}
      +            hive-service
      +          
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            org.apache.httpcomponents
      +            httpclient
      +          
      +          
      +            org.apache.httpcomponents
      +            httpcore
      +          
      +          
      +            org.apache.curator
      +            curator-framework
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +          
      +            org.apache.thrift
      +            libfb303
      +          
      +          
      +            org.apache.zookeeper
      +            zookeeper
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +          
      +            log4j
      +            log4j
      +          
      +          
      +            commons-logging
      +            commons-logging
      +          
      +        
             
      +
             
               ${hive.group}
               hive-metastore
               ${hive.version}
               ${hive.deps.scope}
      +        
      +          
      +            ${hive.group}
      +            hive-serde
      +          
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            org.apache.thrift
      +            libfb303
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +          
      +            org.mortbay.jetty
      +            servlet-api
      +          
      +          
      +            com.google.guava
      +            guava
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +        
             
      +
             
               ${hive.group}
               hive-serde
               ${hive.version}
               ${hive.deps.scope}
               
      +          
      +            ${hive.group}
      +            hive-common
      +          
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            commons-codec
      +            commons-codec
      +          
      +          
      +            com.google.code.findbugs
      +            jsr305
      +          
      +          
      +            org.apache.avro
      +            avro
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +          
      +            org.apache.thrift
      +            libfb303
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +          
      +            log4j
      +            log4j
      +          
                 
                   commons-logging
                   commons-logging
                 
      +        
      +      
      +
      +      
      +        ${hive.group}
      +        hive-service
      +        ${hive.version}
      +        ${hive.deps.scope}
      +        
      +          
      +            ${hive.group}
      +            hive-common
      +          
      +          
      +            ${hive.group}
      +            hive-exec
      +          
      +          
      +            ${hive.group}
      +            hive-metastore
      +          
      +          
      +            ${hive.group}
      +            hive-shims
      +          
      +          
      +            commons-codec
      +            commons-codec
      +          
      +          
      +            org.apache.curator
      +            curator-framework
      +          
      +          
      +            org.apache.curator
      +            curator-recipes
      +          
      +          
      +            org.apache.thrift
      +            libfb303
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +        
      +      
      +
      +      
      +      
      +        ${hive.group}
      +        hive-shims
      +        ${hive.version}
      +        ${hive.deps.scope}
      +        
      +          
      +            com.google.guava
      +            guava
      +          
      +          
      +            org.apache.hadoop
      +            hadoop-yarn-server-resourcemanager
      +          
      +          
      +            org.apache.curator
      +            curator-framework
      +          
      +          
      +            org.apache.thrift
      +            libthrift
      +          
      +          
      +            org.apache.zookeeper
      +            zookeeper
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +          
      +            org.slf4j
      +            slf4j-log4j12
      +          
      +          
      +            log4j
      +            log4j
      +          
                 
                   commons-logging
      -            commons-logging-api
      +            commons-logging
                 
               
             
      @@ -1091,6 +1597,18 @@
               ${parquet.version}
               ${parquet.deps.scope}
             
      +      
      +        org.apache.parquet
      +        parquet-avro
      +        ${parquet.version}
      +        ${parquet.test.deps.scope}
      +      
      +      
      +        com.twitter
      +        parquet-hadoop-bundle
      +        ${hive.parquet.version}
      +        compile
      +      
             
               org.apache.flume
               flume-ng-core
      @@ -1101,6 +1619,10 @@
                   io.netty
                   netty
                 
      +          
      +            org.apache.flume
      +            flume-ng-auth
      +          
                 
                   org.apache.thrift
                   libthrift
      @@ -1127,6 +1649,125 @@
                 
               
             
      +      
      +        org.apache.calcite
      +        calcite-core
      +        ${calcite.version}
      +        
      +          
      +            com.fasterxml.jackson.core
      +            jackson-annotations
      +          
      +          
      +            com.fasterxml.jackson.core
      +            jackson-core
      +          
      +          
      +            com.fasterxml.jackson.core
      +            jackson-databind
      +          
      +          
      +            com.google.guava
      +            guava
      +          
      +          
      +            com.google.code.findbugs
      +            jsr305
      +          
      +          
      +            org.codehaus.janino
      +            janino
      +          
      +          
      +          
      +            org.hsqldb
      +            hsqldb
      +          
      +          
      +            org.pentaho
      +            pentaho-aggdesigner-algorithm
      +          
      +        
      +      
      +      
      +        org.apache.calcite
      +        calcite-avatica
      +        ${calcite.version}
      +        
      +          
      +            com.fasterxml.jackson.core
      +            jackson-annotations
      +          
      +          
      +            com.fasterxml.jackson.core
      +            jackson-core
      +          
      +          
      +            com.fasterxml.jackson.core
      +            jackson-databind
      +          
      +        
      +      
      +      
      +        org.codehaus.janino
      +        janino
      +        ${janino.version}
      +      
      +      
      +        joda-time
      +        joda-time
      +        ${joda.version}
      +      
      +      
      +        org.jodd
      +        jodd-core
      +        ${jodd.version}
      +      
      +      
      +        org.datanucleus
      +        datanucleus-core
      +        ${datanucleus-core.version}
      +      
      +      
      +        org.apache.thrift
      +        libthrift
      +        ${libthrift.version}
      +        
      +          
      +            org.apache.httpcomponents
      +            httpclient
      +          
      +          
      +            org.apache.httpcomponents
      +            httpcore
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +        
      +      
      +      
      +        org.apache.thrift
      +        libfb303
      +        ${libthrift.version}
      +        
      +          
      +            org.apache.httpcomponents
      +            httpclient
      +          
      +          
      +            org.apache.httpcomponents
      +            httpcore
      +          
      +          
      +            org.slf4j
      +            slf4j-api
      +          
      +        
      +      
           
         
       
      @@ -1146,7 +1787,7 @@
                     
                       
                         
      -                    3.0.4
      +                    ${maven.version}
                         
                         
                           ${java.version}
      @@ -1164,7 +1805,7 @@
               
                 net.alchim31.maven
                 scala-maven-plugin
      -          3.2.0
      +          3.2.2
                 
                   
                     eclipse-add-source
      @@ -1215,6 +1856,7 @@
                     ${java.version}
                     -target
                     ${java.version}
      +              -Xlint:all,-serial,-path
                   
                 
               
      @@ -1228,6 +1870,9 @@
                   UTF-8
                   1024m
                   true
      +            
      +              -Xlint:all,-serial,-path
      +            
                 
               
               
      @@ -1251,6 +1896,8 @@
                       launched by the tests have access to the correct test-time classpath.
                     -->
                     ${test_classpath}
      +              1
      +              1
                     ${test.java.home}
                   
                   
      @@ -1259,10 +1906,13 @@
                     ${project.build.directory}/tmp
                     ${spark.test.home}
                     1
      +              false
                     false
                     false
                     true
                     true
      +              
      +              src
                   
                   false
                 
      @@ -1285,6 +1935,8 @@
                       launched by the tests have access to the correct test-time classpath.
                     -->
                     ${test_classpath}
      +              1
      +              1
                     ${test.java.home}
                   
                   
      @@ -1296,6 +1948,9 @@
                     false
                     false
                     true
      +              true
      +              
      +              __not_used__
                   
                 
                 
      @@ -1365,7 +2020,12 @@
               
                 org.apache.maven.plugins
                 maven-assembly-plugin
      -          2.5.3
      +          2.5.5
      +        
      +        
      +          org.apache.maven.plugins
      +          maven-shade-plugin
      +          2.4.1
               
               
                 org.apache.maven.plugins
      @@ -1377,6 +2037,58 @@
                 maven-deploy-plugin
                 2.8.2
               
      +        
      +        
      +        
      +          org.eclipse.m2e
      +          lifecycle-mapping
      +          1.0.0
      +          
      +            
      +              
      +                
      +                  
      +                    org.apache.maven.plugins
      +                    maven-dependency-plugin
      +                    [2.8,)
      +                    
      +                      build-classpath
      +                    
      +                  
      +                  
      +                    
      +                  
      +                
      +                
      +                  
      +                    org.apache.maven.plugins
      +                    maven-jar-plugin
      +                    [2.6,)
      +                    
      +                      test-jar
      +                    
      +                  
      +                  
      +                    
      +                  
      +                
      +                
      +                  
      +                    org.apache.maven.plugins
      +                    maven-antrun-plugin
      +                    [1.8,)
      +                    
      +                      run
      +                    
      +                  
      +                  
      +                    
      +                  
      +                
      +              
      +            
      +          
      +        
             
           
       
      @@ -1394,34 +2106,12 @@
                   
                   
                     test
      -              ${test_classpath_file}
      +              test_classpath
                   
                 
               
             
       
      -      
      -      
      -        org.codehaus.gmavenplus
      -        gmavenplus-plugin
      -        1.5
      -        
      -          
      -            process-test-classes
      -            
      -              execute
      -            
      -            
      -              
      -                
      -              
      -            
      -          
      -        
      -      
             
      -          false
                 
                   
                     
      @@ -1488,36 +2175,6 @@
               org.apache.maven.plugins
               maven-enforcer-plugin
             
      -      
      -        org.codehaus.mojo
      -        build-helper-maven-plugin
      -        
      -          
      -            add-scala-sources
      -            generate-sources
      -            
      -              add-source
      -            
      -            
      -              
      -                src/main/scala
      -              
      -            
      -          
      -          
      -            add-scala-test-sources
      -            generate-test-sources
      -            
      -              add-test-source
      -            
      -            
      -              
      -                src/test/scala
      -              
      -            
      -          
      -        
      -      
             
               net.alchim31.maven
               scala-maven-plugin
      @@ -1631,6 +2288,7 @@
             kinesis-asl
             
               extras/kinesis-asl
      +        extras/kinesis-asl-assembly
             
           
       
      @@ -1687,7 +2345,7 @@
           
             hadoop-1
             
      -        1.0.4
      +        1.2.1
               2.4.1
               0.98.7-hadoop1
               hadoop1
      @@ -1707,7 +2365,6 @@
             
               2.3.0
               0.9.3
      -        3.1.1
             
           
       
      @@ -1716,7 +2373,6 @@
             
               2.4.0
               0.9.3
      -        3.1.1
             
           
       
      @@ -1725,7 +2381,6 @@
             
               2.6.0
               0.9.3
      -        3.1.1
               3.4.6
               2.6.0
             
      @@ -1739,44 +2394,6 @@
             
           
       
      -    
      -      mapr3
      -      
      -        1.0.3-mapr-3.0.3
      -        2.4.1-mapr-1408
      -        0.98.4-mapr-1408
      -        3.4.5-mapr-1406
      -      
      -    
      -
      -    
      -      mapr4
      -      
      -        2.4.1-mapr-1408
      -        2.4.1-mapr-1408
      -        0.98.4-mapr-1408
      -        3.4.5-mapr-1406
      -      
      -      
      -        
      -          org.apache.curator
      -          curator-recipes
      -          ${curator.version}
      -          
      -            
      -              org.apache.zookeeper
      -              zookeeper
      -            
      -          
      -        
      -        
      -          org.apache.zookeeper
      -          zookeeper
      -          3.4.5-mapr-1406
      -        
      -      
      -    
      -
           
             hive-thriftserver
             
      @@ -1795,6 +2412,15 @@
               ${scala.version}
               org.scala-lang
             
      +      
      +        
      +          
      +            ${jline.groupid}
      +            jline
      +            ${jline.version}
      +          
      +        
      +      
           
       
           
@@ -1813,10 +2439,8 @@
         <property><name>scala-2.11</name></property>
       </activation>
       <properties>
-        <scala.version>2.11.6</scala.version>
+        <scala.version>2.11.7</scala.version>
         <scala.binary.version>2.11</scala.binary.version>
-        <jline.version>2.12.1</jline.version>
-        <jline.groupid>jline</jline.groupid>
       </properties>
     </profile>
       
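
Aside on the jline change above: the scala-2.10 profile now pins the org.scala-lang jline fork through <dependencyManagement>, keyed off ${jline.groupid} and ${jline.version} (i.e. the full Scala version), while the scala-2.11 profile simply drops its jline override. For readers who follow the sbt build rather than Maven, here is a minimal sketch of the same idea using only standard sbt keys inside an ordinary settings block; this fragment is illustrative and not part of the patch:

    // Only the Scala 2.10 build needs the org.scala-lang jline fork, pinned to
    // the full Scala version; Scala 2.11 builds need no explicit jline override.
    libraryDependencies ++= {
      if (scalaBinaryVersion.value == "2.10")
        Seq("org.scala-lang" % "jline" % scalaVersion.value)
      else
        Seq.empty
    }
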
      diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
      index 5812b72f0aa7..519052620246 100644
      --- a/project/MimaBuild.scala
      +++ b/project/MimaBuild.scala
      @@ -91,8 +91,7 @@ object MimaBuild {
       
         def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
           val organization = "org.apache.spark"
      -    // TODO: Change this once Spark 1.4.0 is released
      -    val previousSparkVersion = "1.4.0-rc4"
      +    val previousSparkVersion = "1.5.0"
           val fullId = "spark-" + projectRef.project + "_2.10"
           mimaDefaultSettings ++
           Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
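
For context on how these two files cooperate: MimaBuild.mimaSettings points MiMa at the corresponding artifact of the previous release (now 1.5.0), and MimaExcludes.excludes(version) supplies the ProblemFilters that silence known, intentional breaks for the version being built. A condensed sketch of that wiring, assuming an sbt Build.scala context and the sbt-mima-plugin 0.1.x key names already used above (the spark-core coordinates are illustrative only):

    import sbt._
    import sbt.Keys._
    import com.typesafe.tools.mima.core._
    import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact}
    import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings

    object ExampleMimaSettings {
      // Each module is diffed against its published artifact from this release.
      val previousSparkVersion = "1.5.0"

      def settings = mimaDefaultSettings ++ Seq(
        previousArtifact := Some("org.apache.spark" % "spark-core_2.10" % previousSparkVersion),
        // Filters of exactly the shape listed in MimaExcludes below.
        binaryIssueFilters ++= Seq(
          ProblemFilters.exclude[MissingMethodProblem](
            "org.apache.spark.api.java.JavaRDDLike.partitioner")))
    }
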
      diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
      index 015d0296dd36..1c96b0958586 100644
      --- a/project/MimaExcludes.scala
      +++ b/project/MimaExcludes.scala
      @@ -32,496 +32,689 @@ import com.typesafe.tools.mima.core.ProblemFilters._
        * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap")
        */
       object MimaExcludes {
      -    def excludes(version: String) =
      -      version match {
      -        case v if v.startsWith("1.5") =>
      -          Seq(
      -            MimaBuild.excludeSparkPackage("deploy"),
      -            // These are needed if checking against the sbt build, since they are part of
      -            // the maven-generated artifacts in 1.3.
      -            excludePackage("org.spark-project.jetty"),
      -            MimaBuild.excludeSparkPackage("unused"),
      -            // JavaRDDLike is not meant to be extended by user programs
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.partitioner"),
      -            // Modification of private static method
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"),
      -            // Mima false positive (was a private[spark] class)
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.util.collection.PairIterator"),
      -            // Removing a testing method from a private class
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
      -            // SQL execution is considered private.
      -            excludePackage("org.apache.spark.sql.execution")
      -          )
      -        case v if v.startsWith("1.4") =>
      -          Seq(
      -            MimaBuild.excludeSparkPackage("deploy"),
      -            MimaBuild.excludeSparkPackage("ml"),
      -            // SPARK-7910 Adding a method to get the partioner to JavaRDD,
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"),
      -            // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
      -            // These are needed if checking against the sbt build, since they are part of
      -            // the maven-generated artifacts in 1.3.
      -            excludePackage("org.spark-project.jetty"),
      -            MimaBuild.excludeSparkPackage("unused"),
      -            ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"),
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.rdd.JdbcRDD.compute"),
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"),
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor")
      -          ) ++ Seq(
      -            // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though
      -            // the stage class is defined as private[spark]
      -            ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage")
      -          ) ++ Seq(
      -            // SPARK-6510 Add a Graph#minus method acting as Set#difference
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus")
      -          ) ++ Seq(
      -            // SPARK-6492 Fix deadlock in SparkContext.stop()
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" +
      -                "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK")
      -          )++ Seq(
      -            // SPARK-6693 add tostring with max lines and width for matrix
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Matrix.toString")
      -          )++ Seq(
      -            // SPARK-6703 Add getOrCreate method to SparkContext
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem]
      -                ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext")
      -          )++ Seq(
      -            // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.mllib.clustering.LDA$EMOptimizer")
      -          ) ++ Seq(
      -            // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Vector.compressed"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Vector.toDense"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Vector.numNonzeros"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Vector.toSparse"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Vector.numActives"),
      -            // SPARK-7681 add SparseVector support for gemv
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Matrix.multiply"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.SparseMatrix.multiply")
      -          ) ++ Seq(
      -            // Execution should never be included as its always internal.
      -            MimaBuild.excludeSparkPackage("sql.execution"),
      -            // This `protected[sql]` method was removed in 1.3.1
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.sql.SQLContext.checkAnalysis"),
      -            // These `private[sql]` class were removed in 1.4.0:
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.execution.AddExchange"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.execution.AddExchange$"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.PartitionSpec"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.PartitionSpec$"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.Partition"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.Partition$"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetRelation2"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetRelation2$"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"),
      -            // These test support classes were moved out of src/main and into src/test:
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetTestData"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetTestData$"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.TestGroupWriteSupport"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"),
      -            // TODO: Remove the following rule once ParquetTest has been moved to src/test.
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.sql.parquet.ParquetTest")
      -          ) ++ Seq(
      -            // SPARK-7530 Added StreamingContext.getState()
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.streaming.StreamingContext.state_=")
      -          ) ++ Seq(
      -            // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
      -            // unnecessary type bounds in order to fix some compiler warnings that occurred when
      -            // implementing this interface in Java. Note that ShuffleWriter is private[spark].
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.shuffle.ShuffleWriter")
      -          ) ++ Seq(
      -            // SPARK-6888 make jdbc driver handling user definable
      -            // This patch renames some classes to API friendly names.
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks")
      -          )
      +  def excludes(version: String) = version match {
      +    case v if v.startsWith("1.6") =>
      +      Seq(
      +        MimaBuild.excludeSparkPackage("deploy"),
      +        MimaBuild.excludeSparkPackage("network"),
      +        // These are needed if checking against the sbt build, since they are part of
      +        // the maven-generated artifacts in 1.3.
      +        excludePackage("org.spark-project.jetty"),
      +        MimaBuild.excludeSparkPackage("unused"),
      +        // SQL execution is considered private.
      +        excludePackage("org.apache.spark.sql.execution")
      +      ) ++
      +      MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++
      +      MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++
      +      Seq(
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.ml.classification.LogisticCostFun.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.ml.classification.LogisticAggregator.add"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.ml.classification.LogisticAggregator.count")
      +      ) ++ Seq(
      +        // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message.
      +        // This class is marked as `private` but MiMa still seems to be confused by the change.
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply")
      +      )
      +    case v if v.startsWith("1.5") =>
      +      Seq(
      +        MimaBuild.excludeSparkPackage("network"),
      +        MimaBuild.excludeSparkPackage("deploy"),
      +        // These are needed if checking against the sbt build, since they are part of
      +        // the maven-generated artifacts in 1.3.
      +        excludePackage("org.spark-project.jetty"),
      +        MimaBuild.excludeSparkPackage("unused"),
      +        // JavaRDDLike is not meant to be extended by user programs
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.partitioner"),
      +        // Modification of private static method
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"),
      +        // Mima false positive (was a private[spark] class)
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.util.collection.PairIterator"),
      +        // Removing a testing method from a private class
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
+        // While private, MiMa is still not happy about the changes,
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.ml.classification.LogisticCostFun.this"),
      +        // SQL execution is considered private.
      +        excludePackage("org.apache.spark.sql.execution"),
      +        // The old JSON RDD is removed in favor of streaming Jackson
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"),
      +        // local function inside a method
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
      +      ) ++ Seq(
      +        // SPARK-8479 Add numNonzeros and numActives to Matrix.
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrix.numNonzeros"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrix.numActives")
      +      ) ++ Seq(
      +        // SPARK-8914 Remove RDDApi
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi")
      +      ) ++ Seq(
      +        // SPARK-7292 Provide operator to truncate lineage cheaply
      +        ProblemFilters.exclude[AbstractClassProblem](
      +          "org.apache.spark.rdd.RDDCheckpointData"),
      +        ProblemFilters.exclude[AbstractClassProblem](
      +          "org.apache.spark.rdd.CheckpointRDD")
      +      ) ++ Seq(
      +        // SPARK-8701 Add input metadata in the batch page.
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.streaming.scheduler.InputInfo$"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.streaming.scheduler.InputInfo")
      +      ) ++ Seq(
      +        // SPARK-6797 Support YARN modes for SparkR
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.r.PairwiseRRDD.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.r.RRDD.createRWorker"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.r.RRDD.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.r.StringRRDD.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.r.BaseRRDD.this")
      +      ) ++ Seq(
      +        // SPARK-7422 add argmax for sparse vectors
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Vector.argmax")
      +      ) ++ Seq(
      +        // SPARK-8906 Move all internal data source classes into execution.datasources
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"),
      +        // SPARK-9763 Minimize exposure of internal SQL classes
      +        excludePackage("org.apache.spark.sql.parquet"),
      +        excludePackage("org.apache.spark.sql.json"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation")
      +      ) ++ Seq(
      +        // SPARK-4751 Dynamic allocation for standalone mode
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.SparkContext.supportDynamicAllocation")
      +      ) ++ Seq(
      +        // SPARK-9580: Remove SQL test singletons
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.test.LocalSQLContext$SQLSession"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.test.LocalSQLContext"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.test.TestSQLContext"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.test.TestSQLContext$")
      +      ) ++ Seq(
      +        // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.mllib.linalg.VectorUDT.serialize")
      +      ) ++ Seq(
      +        // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message.
      +        // This class is marked as `private` but MiMa still seems to be confused by the change.
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply")
      +      )
       
      -        case v if v.startsWith("1.3") =>
      -          Seq(
      -            MimaBuild.excludeSparkPackage("deploy"),
      -            MimaBuild.excludeSparkPackage("ml"),
      -            // These are needed if checking against the sbt build, since they are part of
      -            // the maven-generated artifacts in the 1.2 build.
      -            MimaBuild.excludeSparkPackage("unused"),
      -            ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional")
      -          ) ++ Seq(
      -            // SPARK-2321
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.SparkStageInfoImpl.this"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.SparkStageInfo.submissionTime")
      -          ) ++ Seq(
      -            // SPARK-4614
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Matrices.randn"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Matrices.rand")
      -          ) ++ Seq(
      -            // SPARK-5321
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Matrix.transpose"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"),
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." +
      -                "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Matrix.isTransposed"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.linalg.Matrix.foreachActive")
      -          ) ++ Seq(
      -            // SPARK-5540
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"),
      -            // SPARK-5536
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock")
      -          ) ++ Seq(
      -            // SPARK-3325
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.streaming.api.java.JavaDStreamLike.print"),
      -            // SPARK-2757
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." +
      -                "removeAndGetProcessor")
      -          ) ++ Seq(
      -            // SPARK-5123 (SparkSQL data type change) - alpha component only
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.ml.feature.HashingTF.outputDataType"),
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.ml.feature.Tokenizer.outputDataType"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.ml.feature.Tokenizer.validateInputType"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema")
      -          ) ++ Seq(
      -            // SPARK-4014
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.TaskContext.taskAttemptId"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.TaskContext.attemptNumber")
      -          ) ++ Seq(
      -            // SPARK-5166 Spark SQL API stabilization
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"),
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"),
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"),
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate")
      -          ) ++ Seq(
      -            // SPARK-5270
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.isEmpty")
      -          ) ++ Seq(
      -            // SPARK-5430
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.treeReduce"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.treeAggregate")
      -          ) ++ Seq(
      -            // SPARK-5297 Java FileStream do not work with custom key/values
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")
      -          ) ++ Seq(
      -            // SPARK-5315 Spark Streaming Java API returns Scala DStream
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow")
      -          ) ++ Seq(
      -            // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.graphx.Graph.getCheckpointFiles"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.graphx.Graph.isCheckpointed")
      -          ) ++ Seq(
      -            // SPARK-4789 Standardize ML Prediction APIs
      -            ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"),
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType")
      -          ) ++ Seq(
      -            // SPARK-5814
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank")
      -          ) ++ Seq(
      -            // SPARK-4682
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock")
      -          ) ++ Seq(
      -            // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff")
      -          )
      +    case v if v.startsWith("1.4") =>
      +      Seq(
      +        MimaBuild.excludeSparkPackage("deploy"),
      +        MimaBuild.excludeSparkPackage("ml"),
+        // SPARK-7910 Adding a method to get the partitioner to JavaRDD
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"),
      +        // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
      +        // These are needed if checking against the sbt build, since they are part of
      +        // the maven-generated artifacts in 1.3.
      +        excludePackage("org.spark-project.jetty"),
      +        MimaBuild.excludeSparkPackage("unused"),
      +        ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.rdd.JdbcRDD.compute"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint")
      +      ) ++ Seq(
+        // SPARK-4655 - Making Stage an Abstract class broke binary compatibility even though
      +        // the stage class is defined as private[spark]
      +        ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage")
      +      ) ++ Seq(
      +        // SPARK-6510 Add a Graph#minus method acting as Set#difference
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus")
      +      ) ++ Seq(
      +        // SPARK-6492 Fix deadlock in SparkContext.stop()
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" +
      +            "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK")
+      ) ++ Seq(
+        // SPARK-6693 add toString with max lines and width for matrix
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrix.toString")
+      ) ++ Seq(
      +        // SPARK-6703 Add getOrCreate method to SparkContext
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem]
      +            ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext")
+      ) ++ Seq(
      +        // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.mllib.clustering.LDA$EMOptimizer")
      +      ) ++ Seq(
      +        // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Vector.compressed"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Vector.toDense"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Vector.numNonzeros"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Vector.toSparse"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Vector.numActives"),
      +        // SPARK-7681 add SparseVector support for gemv
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrix.multiply"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.SparseMatrix.multiply")
      +      ) ++ Seq(
+        // Execution should never be included as it's always internal.
      +        MimaBuild.excludeSparkPackage("sql.execution"),
      +        // This `protected[sql]` method was removed in 1.3.1
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.sql.SQLContext.checkAnalysis"),
+        // These `private[sql]` classes were removed in 1.4.0:
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.execution.AddExchange"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.execution.AddExchange$"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.PartitionSpec"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.PartitionSpec$"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.Partition"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.Partition$"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetRelation2"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetRelation2$"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"),
      +        // These test support classes were moved out of src/main and into src/test:
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetTestData"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetTestData$"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.TestGroupWriteSupport"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"),
      +        // TODO: Remove the following rule once ParquetTest has been moved to src/test.
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.sql.parquet.ParquetTest")
      +      ) ++ Seq(
      +        // SPARK-7530 Added StreamingContext.getState()
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.streaming.StreamingContext.state_=")
      +      ) ++ Seq(
      +        // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
      +        // unnecessary type bounds in order to fix some compiler warnings that occurred when
      +        // implementing this interface in Java. Note that ShuffleWriter is private[spark].
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.shuffle.ShuffleWriter")
      +      ) ++ Seq(
      +        // SPARK-6888 make jdbc driver handling user definable
      +        // This patch renames some classes to API friendly names.
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks")
      +      )
       
      -        case v if v.startsWith("1.2") =>
      -          Seq(
      -            MimaBuild.excludeSparkPackage("deploy"),
      -            MimaBuild.excludeSparkPackage("graphx")
      -          ) ++
      -          MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++
      -          MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++
      -          Seq(
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.scheduler.TaskLocation"),
      -            // Added normL1 and normL2 to trait MultivariateStatisticalSummary
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"),
      -            // MapStatus should be private[spark]
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.scheduler.MapStatus"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.network.netty.PathResolver"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.spark.network.netty.client.BlockClientListener"),
      +    case v if v.startsWith("1.3") =>
      +      Seq(
      +        MimaBuild.excludeSparkPackage("deploy"),
      +        MimaBuild.excludeSparkPackage("ml"),
      +        // These are needed if checking against the sbt build, since they are part of
      +        // the maven-generated artifacts in the 1.2 build.
      +        MimaBuild.excludeSparkPackage("unused"),
      +        ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional")
      +      ) ++ Seq(
      +        // SPARK-2321
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.SparkStageInfoImpl.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.SparkStageInfo.submissionTime")
      +      ) ++ Seq(
      +        // SPARK-4614
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrices.randn"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrices.rand")
      +      ) ++ Seq(
      +        // SPARK-5321
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrix.transpose"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"),
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." +
      +            "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrix.isTransposed"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.linalg.Matrix.foreachActive")
      +      ) ++ Seq(
      +        // SPARK-5540
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"),
      +        // SPARK-5536
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock")
      +      ) ++ Seq(
      +        // SPARK-3325
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.streaming.api.java.JavaDStreamLike.print"),
      +        // SPARK-2757
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." +
      +            "removeAndGetProcessor")
      +      ) ++ Seq(
      +        // SPARK-5123 (SparkSQL data type change) - alpha component only
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.ml.feature.HashingTF.outputDataType"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.ml.feature.Tokenizer.outputDataType"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.ml.feature.Tokenizer.validateInputType"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema")
      +      ) ++ Seq(
      +        // SPARK-4014
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.TaskContext.taskAttemptId"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.TaskContext.attemptNumber")
      +      ) ++ Seq(
      +        // SPARK-5166 Spark SQL API stabilization
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"),
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"),
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"),
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate")
      +      ) ++ Seq(
      +        // SPARK-5270
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.isEmpty")
      +      ) ++ Seq(
      +        // SPARK-5430
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.treeReduce"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.treeAggregate")
      +      ) ++ Seq(
+        // SPARK-5297 Java FileStream does not work with custom key/values
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")
      +      ) ++ Seq(
      +        // SPARK-5315 Spark Streaming Java API returns Scala DStream
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow")
      +      ) ++ Seq(
      +        // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.graphx.Graph.getCheckpointFiles"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.graphx.Graph.isCheckpointed")
      +      ) ++ Seq(
      +        // SPARK-4789 Standardize ML Prediction APIs
      +        ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType")
      +      ) ++ Seq(
      +        // SPARK-5814
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank")
      +      ) ++ Seq(
      +        // SPARK-4682
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock")
      +      ) ++ Seq(
      +        // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff")
      +      )
       
      -            // TaskContext was promoted to Abstract class
      -            ProblemFilters.exclude[AbstractClassProblem](
      -              "org.apache.spark.TaskContext"),
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.util.collection.SortDataFormat")
      -          ) ++ Seq(
      -            // Adding new methods to the JavaRDDLike trait:
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.takeAsync"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.countAsync"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.foreachAsync"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.collectAsync")
      -          ) ++ Seq(
      -            // SPARK-3822
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler")
      -          ) ++ Seq(
      -            // SPARK-1209
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"),
      -            ProblemFilters.exclude[MissingClassProblem](
      -              "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"),
      -            ProblemFilters.exclude[MissingTypesProblem](
      -              "org.apache.spark.rdd.PairRDDFunctions")
      -          ) ++ Seq(
      -            // SPARK-4062
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this")
      -          )
      +    case v if v.startsWith("1.2") =>
      +      Seq(
      +        MimaBuild.excludeSparkPackage("deploy"),
      +        MimaBuild.excludeSparkPackage("graphx")
      +      ) ++
      +      MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++
      +      MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++
      +      Seq(
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.scheduler.TaskLocation"),
      +        // Added normL1 and normL2 to trait MultivariateStatisticalSummary
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"),
      +        // MapStatus should be private[spark]
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.scheduler.MapStatus"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.network.netty.PathResolver"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.spark.network.netty.client.BlockClientListener"),
       
      -        case v if v.startsWith("1.1") =>
      -          Seq(
      -            MimaBuild.excludeSparkPackage("deploy"),
      -            MimaBuild.excludeSparkPackage("graphx")
      -          ) ++
      -          Seq(
      -            // Adding new method to JavaRDLike trait - we should probably mark this as a developer API.
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"),
      -            // Should probably mark this as Experimental
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.foreachAsync"),
      -            // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values
      -            // for countApproxDistinct* functions, which does not work in Java. We later removed
      -            // them, and use the following to tell Mima to not care about them.
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
      -            ProblemFilters.exclude[IncompatibleResultTypeProblem](
      -              "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.storage.DiskStore.getValues"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.storage.MemoryStore.Entry")
      -          ) ++
      -          Seq(
      -            // Serializer interface change. See SPARK-3045.
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.serializer.DeserializationStream"),
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.serializer.Serializer"),
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.serializer.SerializationStream"),
      -            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      -              "org.apache.spark.serializer.SerializerInstance")
      -          )++
      -          Seq(
      -            // Renamed putValues -> putArray + putIterator
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.storage.MemoryStore.putValues"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.storage.DiskStore.putValues"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.storage.TachyonStore.putValues")
      -          ) ++
      -          Seq(
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.streaming.flume.FlumeReceiver.this"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.streaming.kafka.KafkaUtils.createStream"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.streaming.kafka.KafkaReceiver.this")
      -          ) ++
      -          Seq( // Ignore some private methods in ALS.
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
      -            ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments.
      -              "org.apache.spark.mllib.recommendation.ALS.this"),
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures")
      -          ) ++
      -          MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++
      -          MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
      -          MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
      -          MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
      -          MimaBuild.excludeSparkClass("storage.Values") ++
      -          MimaBuild.excludeSparkClass("storage.Entry") ++
      -          MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++
      -          // Class was missing "@DeveloperApi" annotation in 1.0.
      -          MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++
      -          Seq(
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.mllib.tree.impurity.Gini.calculate"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem](
      -              "org.apache.spark.mllib.tree.impurity.Variance.calculate")
      -          ) ++
      -          Seq( // Package-private classes removed in SPARK-2341
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"),
      -            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$")
      -          ) ++
      -          Seq( // package-private classes removed in MLlib
      -            ProblemFilters.exclude[MissingMethodProblem](
      -              "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne")
      -          ) ++
      -          Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector)
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy")
      -          ) ++
      -          Seq( // synthetic methods generated in LabeledPoint
      -            ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"),
      -            ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"),
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString")
      -          ) ++
      -          Seq ( // Scala 2.11 compatibility fix
      -            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2")
      -          )
      -        case v if v.startsWith("1.0") =>
      -          Seq(
      -            MimaBuild.excludeSparkPackage("api.java"),
      -            MimaBuild.excludeSparkPackage("mllib"),
      -            MimaBuild.excludeSparkPackage("streaming")
      -          ) ++
      -          MimaBuild.excludeSparkClass("rdd.ClassTags") ++
      -          MimaBuild.excludeSparkClass("util.XORShiftRandom") ++
      -          MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++
      -          MimaBuild.excludeSparkClass("graphx.VertexRDD") ++
      -          MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++
      -          MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++
      -          MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++
      -          MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++
      -          MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
      -          MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++
      -          MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
      -          MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++
      -          MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
      -        case _ => Seq()
      -      }
      -}
      +        // TaskContext was promoted to Abstract class
      +        ProblemFilters.exclude[AbstractClassProblem](
      +          "org.apache.spark.TaskContext"),
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.util.collection.SortDataFormat")
      +      ) ++ Seq(
      +        // Adding new methods to the JavaRDDLike trait:
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.takeAsync"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.countAsync"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.foreachAsync"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.collectAsync")
      +      ) ++ Seq(
      +        // SPARK-3822
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler")
      +      ) ++ Seq(
      +        // SPARK-1209
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"),
      +        ProblemFilters.exclude[MissingClassProblem](
      +          "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"),
      +        ProblemFilters.exclude[MissingTypesProblem](
      +          "org.apache.spark.rdd.PairRDDFunctions")
      +      ) ++ Seq(
      +        // SPARK-4062
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this")
      +      )
      +
      +    case v if v.startsWith("1.1") =>
      +      Seq(
      +        MimaBuild.excludeSparkPackage("deploy"),
      +        MimaBuild.excludeSparkPackage("graphx")
      +      ) ++
      +      Seq(
       +        // Adding new method to JavaRDDLike trait - we should probably mark this as a developer API.
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"),
      +        // Should probably mark this as Experimental
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.foreachAsync"),
       +        // We made a mistake earlier (ed06500d3) in the Java API by using default parameter values
       +        // for the countApproxDistinct* functions, which do not work in Java. We later removed
       +        // them, and use the following to tell Mima not to care about them.
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
      +        ProblemFilters.exclude[IncompatibleResultTypeProblem](
      +          "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.storage.DiskStore.getValues"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.storage.MemoryStore.Entry")
      +      ) ++
      +      Seq(
      +        // Serializer interface change. See SPARK-3045.
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.serializer.DeserializationStream"),
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.serializer.Serializer"),
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.serializer.SerializationStream"),
      +        ProblemFilters.exclude[IncompatibleTemplateDefProblem](
      +          "org.apache.spark.serializer.SerializerInstance")
       +      ) ++
      +      Seq(
      +        // Renamed putValues -> putArray + putIterator
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.storage.MemoryStore.putValues"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.storage.DiskStore.putValues"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.storage.TachyonStore.putValues")
      +      ) ++
      +      Seq(
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.streaming.flume.FlumeReceiver.this"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.streaming.kafka.KafkaUtils.createStream"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.streaming.kafka.KafkaReceiver.this")
      +      ) ++
      +      Seq( // Ignore some private methods in ALS.
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
      +        ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments.
      +          "org.apache.spark.mllib.recommendation.ALS.this"),
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures")
      +      ) ++
      +      MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++
      +      MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
      +      MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
      +      MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
      +      MimaBuild.excludeSparkClass("storage.Values") ++
      +      MimaBuild.excludeSparkClass("storage.Entry") ++
      +      MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++
      +      // Class was missing "@DeveloperApi" annotation in 1.0.
      +      MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++
      +      Seq(
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.mllib.tree.impurity.Gini.calculate"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem](
      +          "org.apache.spark.mllib.tree.impurity.Variance.calculate")
      +      ) ++
      +      Seq( // Package-private classes removed in SPARK-2341
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"),
      +        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$")
      +      ) ++
      +      Seq( // package-private classes removed in MLlib
      +        ProblemFilters.exclude[MissingMethodProblem](
      +          "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne")
      +      ) ++
      +      Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector)
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy")
      +      ) ++
      +      Seq( // synthetic methods generated in LabeledPoint
      +        ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"),
      +        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"),
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString")
      +      ) ++
       +      Seq( // Scala 2.11 compatibility fix
      +        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2")
      +      )
      +    case v if v.startsWith("1.0") =>
      +      Seq(
      +        MimaBuild.excludeSparkPackage("api.java"),
      +        MimaBuild.excludeSparkPackage("mllib"),
      +        MimaBuild.excludeSparkPackage("streaming")
      +      ) ++
      +      MimaBuild.excludeSparkClass("rdd.ClassTags") ++
      +      MimaBuild.excludeSparkClass("util.XORShiftRandom") ++
      +      MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++
      +      MimaBuild.excludeSparkClass("graphx.VertexRDD") ++
      +      MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++
      +      MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++
      +      MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++
      +      MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++
      +      MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
      +      MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++
      +      MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
      +      MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++
      +      MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
      +    case _ => Seq()
      +  }
      +}
      \ No newline at end of file
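
The hunk above reworks project/MimaExcludes.scala, the file listing binary-compatibility breaks that MiMa is told to ignore for each previously released Spark version. As a reading aid, here is a minimal sketch of the pattern that file follows; the `ExampleExcludes` object, the "1.5" version branch, the SPARK-NNNNN/SPARK-MMMMM ticket numbers, and the excluded member and class names are hypothetical placeholders, not entries from the real file.

```scala
import com.typesafe.tools.mima.core._
import com.typesafe.tools.mima.core.ProblemFilters._

// Sketch only: one case per previous release, each returning the filters to apply.
object ExampleExcludes {
  def excludes(version: String) = version match {
    case v if v.startsWith("1.5") =>                      // hypothetical version branch
      Seq(
        // SPARK-NNNNN: note the JIRA ticket that justifies the break
        ProblemFilters.exclude[MissingMethodProblem](
          "org.apache.spark.SomeClass.someRemovedMethod") // hypothetical member
      ) ++ Seq(
        // SPARK-MMMMM: each ticket gets its own Seq, chained with ++
        ProblemFilters.exclude[MissingClassProblem](
          "org.apache.spark.SomeRemovedClass")            // hypothetical class
      )
    case _ => Seq()
  }
}
```

Grouping the filters one Seq per ticket keeps the justification for every exclusion next to the entries it covers, which is the convention the re-indented file preserves.
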
      diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
      index e01720296fed..901cfa538d23 100644
      --- a/project/SparkBuild.scala
      +++ b/project/SparkBuild.scala
      @@ -18,13 +18,13 @@
       import java.io._
       
       import scala.util.Properties
      -import scala.collection.JavaConversions._
      +import scala.collection.JavaConverters._
       
       import sbt._
       import sbt.Classpaths.publishTask
       import sbt.Keys._
       import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
      -import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys}
      +import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
       import net.virtualvoid.sbt.graph.Plugin.graphSettings
       
       import spray.revolver.RevolverPlugin._
      @@ -42,11 +42,11 @@ object BuildCommons {
             "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _))
       
         val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl,
      -    sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
      -    "kinesis-asl").map(ProjectRef(buildLocation, _))
      +    streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
      +    "streaming-kinesis-asl").map(ProjectRef(buildLocation, _))
       
      -  val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) =
      -    Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly")
      +  val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) =
      +    Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly")
             .map(ProjectRef(buildLocation, _))
       
         val tools = ProjectRef(buildLocation, "tools")
      @@ -69,6 +69,7 @@ object SparkBuild extends PomBuild {
           import scala.collection.mutable
           var isAlphaYarn = false
           var profiles: mutable.Seq[String] = mutable.Seq("sbt")
      +    // scalastyle:off println
           if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) {
             println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.")
             profiles ++= Seq("spark-ganglia-lgpl")
      @@ -88,6 +89,7 @@ object SparkBuild extends PomBuild {
             println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.")
             profiles ++= Seq("yarn")
           }
      +    // scalastyle:on println
           profiles
         }
       
      @@ -96,8 +98,10 @@ object SparkBuild extends PomBuild {
           case None => backwardCompatibility
           case Some(v) =>
             if (backwardCompatibility.nonEmpty)
      +        // scalastyle:off println
               println("Note: We ignore environment variables, when use of profile is detected in " +
                 "conjunction with environment variable.")
      +        // scalastyle:on println
             v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
           }
       
      @@ -116,7 +120,7 @@ object SparkBuild extends PomBuild {
           case _ =>
         }
       
      -  override val userPropertiesMap = System.getProperties.toMap
      +  override val userPropertiesMap = System.getProperties.asScala.toMap
       
         lazy val MavenCompile = config("m2r") extend(Compile)
         lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
      @@ -150,7 +154,38 @@ object SparkBuild extends PomBuild {
             if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty
           },
       
      -    javacOptions in Compile ++= Seq("-encoding", "UTF-8")
      +    javacOptions in Compile ++= Seq("-encoding", "UTF-8"),
      +
      +    // Implements -Xfatal-warnings, ignoring deprecation warnings.
      +    // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410.
      +    compile in Compile := {
      +      val analysis = (compile in Compile).value
      +      val s = streams.value
      +
      +      def logProblem(l: (=> String) => Unit, f: File, p: xsbti.Problem) = {
      +        l(f.toString + ":" + p.position.line.fold("")(_ + ":") + " " + p.message)
      +        l(p.position.lineContent)
      +        l("")
      +      }
      +
      +      var failed = 0
      +      analysis.infos.allInfos.foreach { case (k, i) =>
      +        i.reportedProblems foreach { p =>
      +          val deprecation = p.message.contains("is deprecated")
      +
      +          if (!deprecation) {
      +            failed = failed + 1
      +          }
      +
      +          logProblem(if (deprecation) s.log.warn else s.log.error, k, p)
      +        }
      +      }
      +
      +      if (failed > 0) {
      +        sys.error(s"$failed fatal warnings")
      +      }
      +      analysis
      +    }
         )
       
         def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
      @@ -161,14 +196,13 @@ object SparkBuild extends PomBuild {
         // Note ordering of these settings matter.
         /* Enable shared settings on all projects */
         (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools))
      -    .foreach(enable(sharedSettings ++ ExludedDependencies.settings ++ Revolver.settings))
      +    .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings))
       
         /* Enable tests settings for all projects except examples, assembly and tools */
         (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
       
      -  // TODO: remove launcher from this list after 1.4.0
         allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl,
      -    networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach {
      +    networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach {
             x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
           }
       
      @@ -178,6 +212,9 @@ object SparkBuild extends PomBuild {
         /* Enable Assembly for all assembly projects */
         assemblyProjects.foreach(enable(Assembly.settings))
       
      +  /* Enable Assembly for streamingMqtt test */
      +  enable(inConfig(Test)(Assembly.settings))(streamingMqtt)
      +
         /* Package pyspark artifacts in a separate zip file for YARN. */
         enable(PySparkAssembly.settings)(assembly)
       
      @@ -207,7 +244,7 @@ object SparkBuild extends PomBuild {
           fork := true,
           outputStrategy in run := Some (StdoutOutput),
       
      -    javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"),
      +    javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m"),
       
           sparkShell := {
             (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value
      @@ -247,7 +284,7 @@ object Flume {
         This excludes library dependencies in sbt, which are specified in maven but are
         not needed by sbt build.
         */
      -object ExludedDependencies {
      +object ExcludedDependencies {
         lazy val settings = Seq(
           libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }
         )
      @@ -282,6 +319,8 @@ object SQL {
         lazy val settings = Seq(
           initialCommands in console :=
             """
      +        |import org.apache.spark.SparkContext
      +        |import org.apache.spark.sql.SQLContext
               |import org.apache.spark.sql.catalyst.analysis._
               |import org.apache.spark.sql.catalyst.dsl._
               |import org.apache.spark.sql.catalyst.errors._
      @@ -291,20 +330,23 @@ object SQL {
               |import org.apache.spark.sql.catalyst.util._
               |import org.apache.spark.sql.execution
               |import org.apache.spark.sql.functions._
      -        |import org.apache.spark.sql.test.TestSQLContext._
      -        |import org.apache.spark.sql.types._""".stripMargin,
      -    cleanupCommands in console := "sparkContext.stop()"
      +        |import org.apache.spark.sql.types._
      +        |
      +        |val sc = new SparkContext("local[*]", "dev-shell")
      +        |val sqlContext = new SQLContext(sc)
      +        |import sqlContext.implicits._
      +        |import sqlContext._
      +      """.stripMargin,
      +    cleanupCommands in console := "sc.stop()"
         )
       }
       
       object Hive {
       
         lazy val settings = Seq(
      -    javaOptions += "-XX:MaxPermSize=1g",
      +    javaOptions += "-XX:MaxPermSize=256m",
           // Specially disable assertions since some Hive tests fail them
           javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
      -    // Multiple queries rely on the TestHive singleton. See comments there for more details.
      -    parallelExecution in Test := false,
           // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
           // only for this subproject.
           scalacOptions <<= scalacOptions map { currentOpts: Seq[String] =>
      @@ -312,6 +354,7 @@ object Hive {
           },
           initialCommands in console :=
             """
      +        |import org.apache.spark.SparkContext
               |import org.apache.spark.sql.catalyst.analysis._
               |import org.apache.spark.sql.catalyst.dsl._
               |import org.apache.spark.sql.catalyst.errors._
      @@ -348,13 +391,16 @@ object Assembly {
               .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
           },
           jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
      -      if (mName.contains("streaming-kafka-assembly")) {
      +      if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) {
               // This must match the same name used in maven (see external/kafka-assembly/pom.xml)
               s"${mName}-${v}.jar"
             } else {
               s"${mName}-${v}-hadoop${hv}.jar"
             }
           },
      +    jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
      +      s"${mName}-test-${v}.jar"
      +    },
           mergeStrategy in assembly := {
             case PathList("org", "datanucleus", xs @ _*)             => MergeStrategy.discard
             case m if m.toLowerCase.endsWith("manifest.mf")          => MergeStrategy.discard
      @@ -478,8 +524,8 @@ object Unidoc {
               "mllib.tree.impurity", "mllib.tree.model", "mllib.util",
               "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation",
               "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss",
      -        "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param",
      -        "ml.recommendation", "ml.regression", "ml.tuning"
      +        "ml", "ml.attribute", "ml.classification", "ml.clustering", "ml.evaluation", "ml.feature",
      +        "ml.param", "ml.recommendation", "ml.regression", "ml.tuning"
             ),
             "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"),
             "-noqualifier", "java.lang"
      @@ -501,18 +547,21 @@ object TestSettings {
           envVars in Test ++= Map(
             "SPARK_DIST_CLASSPATH" ->
               (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
      +      "SPARK_PREPEND_CLASSES" -> "1",
      +      "SPARK_TESTING" -> "1",
             "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))),
           javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir",
           javaOptions in Test += "-Dspark.test.home=" + sparkHome,
           javaOptions in Test += "-Dspark.testing=1",
           javaOptions in Test += "-Dspark.port.maxRetries=100",
      +    javaOptions in Test += "-Dspark.master.rest.enabled=false",
           javaOptions in Test += "-Dspark.ui.enabled=false",
           javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
           javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
           javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
           javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
           javaOptions in Test += "-Dderby.system.durability=test",
      -    javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
      +    javaOptions in Test ++= System.getProperties.asScala.filter(_._1.startsWith("spark"))
             .map { case (k,v) => s"-D$k=$v" }.toSeq,
           javaOptions in Test += "-ea",
           javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
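
Several hunks above replace scala.collection.JavaConversions with scala.collection.JavaConverters in SparkBuild.scala, so Java collections are converted with explicit .asScala calls (for example System.getProperties.asScala.toMap, and the spark.* property filter in TestSettings) instead of through implicit conversions. A minimal, self-contained sketch of that pattern follows; the object and value names are illustrative only.

```scala
import scala.collection.JavaConverters._

object ConvertersExample {
  // java.util.Properties viewed as a Scala map, then copied into an immutable Map
  val userProperties: Map[String, String] = System.getProperties.asScala.toMap

  // Forward only the spark.* system properties as -D JVM options,
  // mirroring the TestSettings hunk above
  val sparkTestOptions: Seq[String] = System.getProperties.asScala
    .filter { case (k, _) => k.startsWith("spark") }
    .map { case (k, v) => s"-D$k=$v" }
    .toSeq
}
```

Keeping the conversions explicit makes every Java-to-Scala boundary visible at the call site, which is the point of the migration applied above.
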
      diff --git a/project/plugins.sbt b/project/plugins.sbt
      index 51820460ca1a..c06687d8f197 100644
      --- a/project/plugins.sbt
      +++ b/project/plugins.sbt
      @@ -1,5 +1,3 @@
      -scalaVersion := "2.10.4"
      -
       resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns)
       
       resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/"
      diff --git a/pylintrc b/pylintrc
      new file mode 100644
      index 000000000000..6a675770da69
      --- /dev/null
      +++ b/pylintrc
      @@ -0,0 +1,404 @@
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +[MASTER]
      +
      +# Specify a configuration file.
      +#rcfile=
      +
      +# Python code to execute, usually for sys.path manipulation such as
      +# pygtk.require().
      +#init-hook=
      +
      +# Profiled execution.
      +profile=no
      +
      +# Add files or directories to the blacklist. They should be base names, not
      +# paths.
      +ignore=pyspark.heapq3
      +
      +# Pickle collected data for later comparisons.
      +persistent=yes
      +
      +# List of plugins (as comma separated values of python modules names) to load,
      +# usually to register additional checkers.
      +load-plugins=
      +
      +# Use multiple processes to speed up Pylint.
      +jobs=1
      +
      +# Allow loading of arbitrary C extensions. Extensions are imported into the
      +# active Python interpreter and may run arbitrary code.
      +unsafe-load-any-extension=no
      +
      +# A comma-separated list of package or module names from where C extensions may
      +# be loaded. Extensions are loading into the active Python interpreter and may
      +# run arbitrary code
      +extension-pkg-whitelist=
      +
      +# Allow optimization of some AST trees. This will activate a peephole AST
      +# optimizer, which will apply various small optimizations. For instance, it can
      +# be used to obtain the result of joining multiple strings with the addition
      +# operator. Joining a lot of strings can lead to a maximum recursion error in
      +# Pylint and this flag can prevent that. It has one side effect, the resulting
      +# AST will be different than the one from reality.
      +optimize-ast=no
      +
      +
      +[MESSAGES CONTROL]
      +
      +# Only show warnings with the listed confidence levels. Leave empty to show
      +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
      +confidence=
      +
      +# Enable the message, report, category or checker with the given id(s). You can
      +# either give multiple identifier separated by comma (,) or put this option
      +# multiple time. See also the "--disable" option for examples.
      +enable=
      +
      +# Disable the message, report, category or checker with the given id(s). You
      +# can either give multiple identifiers separated by comma (,) or put this
      +# option multiple times (only on the command line, not in the configuration
       +# file where it should appear only once). You can also use "--disable=all" to
      +# disable everything first and then reenable specific checks. For example, if
      +# you want to run only the similarities checker, you can use "--disable=all
      +# --enable=similarities". If you want to run only the classes checker, but have
      +# no Warning level messages displayed, use"--disable=all --enable=classes
      +# --disable=W"
      +
       +# These errors are arranged in order of the number of warnings reported by pylint.
       +# If you would like to improve the code quality of pyspark, remove any of these disabled errors,
       +# run ./dev/lint-python, and see if the errors raised by pylint can be fixed.
      +
      +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable
      +
      +
      +[REPORTS]
      +
      +# Set the output format. Available formats are text, parseable, colorized, msvs
      +# (visual studio) and html. You can also give a reporter class, eg
      +# mypackage.mymodule.MyReporterClass.
      +output-format=text
      +
      +# Put messages in a separate file for each module / package specified on the
      +# command line instead of printing them on stdout. Reports (if any) will be
      +# written in a file name "pylint_global.[txt|html]".
      +files-output=no
      +
      +# Tells whether to display a full report or only the messages
      +reports=no
      +
      +# Python expression which should return a note less than 10 (10 is the highest
      +# note). You have access to the variables errors warning, statement which
      +# respectively contain the number of errors / warnings messages and the total
      +# number of statements analyzed. This is used by the global evaluation report
      +# (RP0004).
      +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
      +
      +# Add a comment according to your evaluation note. This is used by the global
      +# evaluation report (RP0004).
      +comment=no
      +
      +# Template used to display messages. This is a python new-style format string
      +# used to format the message information. See doc for all details
      +#msg-template=
      +
      +
      +[MISCELLANEOUS]
      +
      +# List of note tags to take in consideration, separated by a comma.
      +notes=FIXME,XXX,TODO
      +
      +
      +[BASIC]
      +
      +# Required attributes for module, separated by a comma
      +required-attributes=
      +
      +# List of builtins function names that should not be used, separated by a comma
      +bad-functions=
      +
      +# Good variable names which should always be accepted, separated by a comma
      +good-names=i,j,k,ex,Run,_
      +
      +# Bad variable names which should always be refused, separated by a comma
      +bad-names=baz,toto,tutu,tata
      +
      +# Colon-delimited sets of names that determine each other's naming style when
      +# the name regexes allow several styles.
      +name-group=
      +
      +# Include a hint for the correct naming format with invalid-name
      +include-naming-hint=no
      +
      +# Regular expression matching correct function names
      +function-rgx=[a-z_][a-z0-9_]{2,30}$
      +
      +# Naming hint for function names
      +function-name-hint=[a-z_][a-z0-9_]{2,30}$
      +
      +# Regular expression matching correct variable names
      +variable-rgx=[a-z_][a-z0-9_]{2,30}$
      +
      +# Naming hint for variable names
      +variable-name-hint=[a-z_][a-z0-9_]{2,30}$
      +
      +# Regular expression matching correct constant names
      +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
      +
      +# Naming hint for constant names
      +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
      +
      +# Regular expression matching correct attribute names
      +attr-rgx=[a-z_][a-z0-9_]{2,30}$
      +
      +# Naming hint for attribute names
      +attr-name-hint=[a-z_][a-z0-9_]{2,30}$
      +
      +# Regular expression matching correct argument names
      +argument-rgx=[a-z_][a-z0-9_]{2,30}$
      +
      +# Naming hint for argument names
      +argument-name-hint=[a-z_][a-z0-9_]{2,30}$
      +
      +# Regular expression matching correct class attribute names
      +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
      +
      +# Naming hint for class attribute names
      +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
      +
      +# Regular expression matching correct inline iteration names
      +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
      +
      +# Naming hint for inline iteration names
      +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
      +
      +# Regular expression matching correct class names
      +class-rgx=[A-Z_][a-zA-Z0-9]+$
      +
      +# Naming hint for class names
      +class-name-hint=[A-Z_][a-zA-Z0-9]+$
      +
      +# Regular expression matching correct module names
      +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
      +
      +# Naming hint for module names
      +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
      +
      +# Regular expression matching correct method names
      +method-rgx=[a-z_][a-z0-9_]{2,30}$
      +
      +# Naming hint for method names
      +method-name-hint=[a-z_][a-z0-9_]{2,30}$
      +
      +# Regular expression which should only match function or class names that do
      +# not require a docstring.
      +no-docstring-rgx=__.*__
      +
      +# Minimum line length for functions/classes that require docstrings, shorter
      +# ones are exempt.
      +docstring-min-length=-1
      +
      +
      +[FORMAT]
      +
      +# Maximum number of characters on a single line.
      +max-line-length=100
      +
      +# Regexp for a line that is allowed to be longer than the limit.
       +ignore-long-lines=^\s*(# )?<?https?://\S+>?$
      +
      +# Allow the body of an if to be on the same line as the test if there is no
      +# else.
      +single-line-if-stmt=no
      +
      +# List of optional constructs for which whitespace checking is disabled
      +no-space-check=trailing-comma,dict-separator
      +
      +# Maximum number of lines in a module
      +max-module-lines=1000
      +
      +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
      +# tab).
      +indent-string='    '
      +
      +# Number of spaces of indent required inside a hanging or continued line.
      +indent-after-paren=4
      +
      +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
      +expected-line-ending-format=
      +
      +
      +[SIMILARITIES]
      +
      +# Minimum lines number of a similarity.
      +min-similarity-lines=4
      +
      +# Ignore comments when computing similarities.
      +ignore-comments=yes
      +
      +# Ignore docstrings when computing similarities.
      +ignore-docstrings=yes
      +
      +# Ignore imports when computing similarities.
      +ignore-imports=no
      +
      +
      +[VARIABLES]
      +
      +# Tells whether we should check for unused import in __init__ files.
      +init-import=no
      +
      +# A regular expression matching the name of dummy variables (i.e. expectedly
      +# not used).
      +dummy-variables-rgx=_$|dummy
      +
      +# List of additional names supposed to be defined in builtins. Remember that
      +# you should avoid to define new builtins when possible.
      +additional-builtins=
      +
      +# List of strings which can identify a callback function by name. A callback
      +# name must start or end with one of those strings.
      +callbacks=cb_,_cb
      +
      +
      +[SPELLING]
      +
       +# Spelling dictionary name. Available dictionaries: none. To make it work,
       +# install the python-enchant package.
      +spelling-dict=
      +
      +# List of comma separated words that should not be checked.
      +spelling-ignore-words=
      +
      +# A path to a file that contains private dictionary; one word per line.
      +spelling-private-dict-file=
      +
      +# Tells whether to store unknown words to indicated private dictionary in
      +# --spelling-private-dict-file option instead of raising a message.
      +spelling-store-unknown-words=no
      +
      +
      +[LOGGING]
      +
      +# Logging modules to check that the string format arguments are in logging
      +# function parameter format
      +logging-modules=logging
      +
      +
      +[TYPECHECK]
      +
      +# Tells whether missing members accessed in mixin class should be ignored. A
      +# mixin class is detected if its name ends with "mixin" (case insensitive).
      +ignore-mixin-members=yes
      +
      +# List of module names for which member attributes should not be checked
      +# (useful for modules/projects where namespaces are manipulated during runtime
      +# and thus existing member attributes cannot be deduced by static analysis
      +ignored-modules=
      +
      +# List of classes names for which member attributes should not be checked
      +# (useful for classes with attributes dynamically set).
      +ignored-classes=SQLObject
      +
      +# When zope mode is activated, add a predefined set of Zope acquired attributes
      +# to generated-members.
      +zope=no
      +
      +# List of members which are set dynamically and missed by pylint inference
      +# system, and so shouldn't trigger E0201 when accessed. Python regular
      +# expressions are accepted.
      +generated-members=REQUEST,acl_users,aq_parent
      +
      +
      +[CLASSES]
      +
      +# List of interface methods to ignore, separated by a comma. This is used for
      +# instance to not check methods defines in Zope's Interface base class.
      +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by
      +
      +# List of method names used to declare (i.e. assign) instance attributes.
      +defining-attr-methods=__init__,__new__,setUp
      +
      +# List of valid names for the first argument in a class method.
      +valid-classmethod-first-arg=cls
      +
      +# List of valid names for the first argument in a metaclass class method.
      +valid-metaclass-classmethod-first-arg=mcs
      +
      +# List of member names, which should be excluded from the protected access
      +# warning.
      +exclude-protected=_asdict,_fields,_replace,_source,_make
      +
      +
      +[IMPORTS]
      +
      +# Deprecated modules which should not be used, separated by a comma
      +deprecated-modules=regsub,TERMIOS,Bastion,rexec
      +
      +# Create a graph of every (i.e. internal and external) dependencies in the
      +# given file (report RP0402 must not be disabled)
      +import-graph=
      +
      +# Create a graph of external dependencies in the given file (report RP0402 must
      +# not be disabled)
      +ext-import-graph=
      +
      +# Create a graph of internal dependencies in the given file (report RP0402 must
      +# not be disabled)
      +int-import-graph=
      +
      +
      +[DESIGN]
      +
      +# Maximum number of arguments for function / method
      +max-args=5
      +
      +# Argument names that match this expression will be ignored. Default to name
      +# with leading underscore
      +ignored-argument-names=_.*
      +
      +# Maximum number of locals for function / method body
      +max-locals=15
      +
      +# Maximum number of return / yield for function / method body
      +max-returns=6
      +
      +# Maximum number of branch for function / method body
      +max-branches=12
      +
      +# Maximum number of statements in function / method body
      +max-statements=50
      +
      +# Maximum number of parents for a class (see R0901).
      +max-parents=7
      +
      +# Maximum number of attributes for a class (see R0902).
      +max-attributes=7
      +
      +# Minimum number of public methods for a class (see R0903).
      +min-public-methods=2
      +
      +# Maximum number of public methods for a class (see R0904).
      +max-public-methods=20
      +
      +
      +[EXCEPTIONS]
      +
      +# Exceptions that will emit a warning when being caught. Defaults to
      +# "Exception"
      +overgeneral-exceptions=Exception
      diff --git a/python/docs/index.rst b/python/docs/index.rst
      index f7eede9c3c82..306ffdb0e0f1 100644
      --- a/python/docs/index.rst
      +++ b/python/docs/index.rst
      @@ -29,6 +29,14 @@ Core classes:
       
           A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
       
      +    :class:`pyspark.streaming.StreamingContext`
      +
      +    Main entry point for Spark Streaming functionality.
      +
      +    :class:`pyspark.streaming.DStream`
      +
      +    A Discretized Stream (DStream), the basic abstraction in Spark Streaming.
      +
           :class:`pyspark.sql.SQLContext`
       
           Main entry point for DataFrame and SQL functionality.
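For orientation, here is a minimal PySpark sketch (not part of this patch; the host and port are placeholders) touching the two streaming entry points the index now lists:

```python
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "index-rst-example")
ssc = StreamingContext(sc, 1)                      # StreamingContext, 1-second batches
lines = ssc.socketTextStream("localhost", 9999)    # a DStream of text lines
lines.count().pprint()                             # print per-batch counts
# ssc.start(); ssc.awaitTermination()              # commented out: needs a live socket source
```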
      diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
      index 518b8e774dd5..86d4186a2c79 100644
      --- a/python/docs/pyspark.ml.rst
      +++ b/python/docs/pyspark.ml.rst
      @@ -33,6 +33,14 @@ pyspark.ml.classification module
           :undoc-members:
           :inherited-members:
       
      +pyspark.ml.clustering module
      +----------------------------
      +
      +.. automodule:: pyspark.ml.clustering
      +    :members:
      +    :undoc-members:
      +    :inherited-members:
      +
       pyspark.ml.recommendation module
       --------------------------------
       
      diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst
      index 26ece4c2c389..2d54ab118b94 100644
      --- a/python/docs/pyspark.mllib.rst
      +++ b/python/docs/pyspark.mllib.rst
      @@ -46,6 +46,14 @@ pyspark.mllib.linalg module
           :undoc-members:
           :show-inheritance:
       
      +pyspark.mllib.linalg.distributed module
      +---------------------------------------
      +
      +.. automodule:: pyspark.mllib.linalg.distributed
      +    :members:
      +    :undoc-members:
      +    :show-inheritance:
      +
       pyspark.mllib.random module
       ---------------------------
       
      diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst
      index 50822c93faba..fc52a647543e 100644
      --- a/python/docs/pyspark.streaming.rst
      +++ b/python/docs/pyspark.streaming.rst
      @@ -15,3 +15,24 @@ pyspark.streaming.kafka module
           :members:
           :undoc-members:
           :show-inheritance:
      +
      +pyspark.streaming.kinesis module
      +--------------------------------
      +.. automodule:: pyspark.streaming.kinesis
      +    :members:
      +    :undoc-members:
      +    :show-inheritance:
      +
+pyspark.streaming.flume module
      +------------------------------
      +.. automodule:: pyspark.streaming.flume
      +    :members:
      +    :undoc-members:
      +    :show-inheritance:
      +
      +pyspark.streaming.mqtt module
      +-----------------------------
      +.. automodule:: pyspark.streaming.mqtt
      +    :members:
      +    :undoc-members:
      +    :show-inheritance:
      diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
      index 5f70ac6ed8fe..8475dfb1c6ad 100644
      --- a/python/pyspark/__init__.py
      +++ b/python/pyspark/__init__.py
      @@ -48,6 +48,22 @@
       from pyspark.status import *
       from pyspark.profiler import Profiler, BasicProfiler
       
      +
      +def since(version):
      +    """
+    A decorator that annotates a function with the Spark version in which it was added.
      +    """
      +    import re
      +    indent_p = re.compile(r'\n( +)')
      +
      +    def deco(f):
      +        indents = indent_p.findall(f.__doc__)
      +        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
      +        f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
      +        return f
      +    return deco
      +
      +
       # for back compatibility
       from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
       
      diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
      index adca90ddaf39..6ef8cf53cc74 100644
      --- a/python/pyspark/accumulators.py
      +++ b/python/pyspark/accumulators.py
      @@ -264,4 +264,6 @@ def _start_update_server():
       
       if __name__ == "__main__":
           import doctest
      -    doctest.testmod()
      +    (failure_count, test_count) = doctest.testmod()
      +    if failure_count:
      +        exit(-1)
      diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
      index 3de4615428bb..663c9abe0881 100644
      --- a/python/pyspark/broadcast.py
      +++ b/python/pyspark/broadcast.py
      @@ -115,4 +115,6 @@ def __reduce__(self):
       
       if __name__ == "__main__":
           import doctest
      -    doctest.testmod()
      +    (failure_count, test_count) = doctest.testmod()
      +    if failure_count:
      +        exit(-1)
      diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
      index 9ef93071d2e7..95b3abc74244 100644
      --- a/python/pyspark/cloudpickle.py
      +++ b/python/pyspark/cloudpickle.py
      @@ -350,7 +350,31 @@ def save_global(self, obj, name=None, pack=struct.pack):
                   if new_override:
                       d['__new__'] = obj.__new__
       
      -            self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
      +            # workaround for namedtuple (hijacked by PySpark)
      +            if getattr(obj, '_is_namedtuple_', False):
      +                self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields))
      +                return
      +
      +            self.save(_load_class)
      +            self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj)
      +            d.pop('__doc__', None)
      +            # handle property and staticmethod
      +            dd = {}
      +            for k, v in d.items():
      +                if isinstance(v, property):
      +                    k = ('property', k)
      +                    v = (v.fget, v.fset, v.fdel, v.__doc__)
      +                elif isinstance(v, staticmethod) and hasattr(v, '__func__'):
      +                    k = ('staticmethod', k)
      +                    v = v.__func__
      +                elif isinstance(v, classmethod) and hasattr(v, '__func__'):
      +                    k = ('classmethod', k)
      +                    v = v.__func__
      +                dd[k] = v
      +            self.save(dd)
      +            self.write(pickle.TUPLE2)
      +            self.write(pickle.REDUCE)
      +
               else:
                   raise pickle.PicklingError("Can't pickle %r" % obj)
       
      @@ -363,7 +387,7 @@ def save_instancemethod(self, obj):
                   self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
               else:
                   self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
      -                         obj=obj)
      +                             obj=obj)
           dispatch[types.MethodType] = save_instancemethod
       
           def save_inst(self, obj):
      @@ -708,6 +732,31 @@ def _make_skel_func(code, closures, base_globals = None):
                                     None, None, closure)
       
       
      +def _load_class(cls, d):
      +    """
      +    Loads additional properties into class `cls`.
      +    """
      +    for k, v in d.items():
      +        if isinstance(k, tuple):
      +            typ, k = k
      +            if typ == 'property':
      +                v = property(*v)
      +            elif typ == 'staticmethod':
      +                v = staticmethod(v)
      +            elif typ == 'classmethod':
      +                v = classmethod(v)
      +        setattr(cls, k, v)
      +    return cls
      +
      +
      +def _load_namedtuple(name, fields):
      +    """
      +    Loads a class generated by namedtuple
      +    """
      +    from collections import namedtuple
      +    return namedtuple(name, fields)
      +
      +
       """Constructors for 3rd party libraries
       Note: These can never be renamed due to client compatibility issues"""
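A standalone sketch (plain Python, mirroring the helpers above, not part of the patch) of what the unpickling side reconstructs: namedtuple classes are rebuilt from just the class name and field names, and tagged entries such as `('property', name)` are turned back into real descriptors.

```python
from collections import namedtuple

def _load_namedtuple(name, fields):            # same shape as the helper above
    return namedtuple(name, fields)

Point = _load_namedtuple("Point", ["x", "y"])
print(Point(1, 2))                             # Point(x=1, y=2)

# What _load_class re-applies for a ('property', 'area') entry:
Shape = type("Shape", (object,), {})
setattr(Shape, "area", property(lambda self: 42.0))
print(Shape().area)                            # 42.0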
       
      diff --git a/python/pyspark/context.py b/python/pyspark/context.py
      index 90b2fffbb9c7..a0a1ccbeefb0 100644
      --- a/python/pyspark/context.py
      +++ b/python/pyspark/context.py
      @@ -152,6 +152,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
               self.master = self._conf.get("spark.master")
               self.appName = self._conf.get("spark.app.name")
               self.sparkHome = self._conf.get("spark.home", None)
      +
      +        # Let YARN know it's a pyspark app, so it distributes needed libraries.
      +        if self.master == "yarn-client":
      +            self._conf.set("spark.yarn.isPython", "true")
      +
               for (k, v) in self._conf.getAll():
                   if k.startswith("spark.executorEnv."):
                       varName = k[len("spark.executorEnv."):]
      @@ -250,7 +255,7 @@ def __getnewargs__(self):
               # This method is called when attempting to pickle SparkContext, which is always an error:
               raise Exception(
                   "It appears that you are attempting to reference SparkContext from a broadcast "
      -            "variable, action, or transforamtion. SparkContext can only be used on the driver, "
      +            "variable, action, or transformation. SparkContext can only be used on the driver, "
                   "not in code that it run on workers. For more information, see SPARK-5063."
               )
       
      @@ -291,6 +296,21 @@ def version(self):
               """
               return self._jsc.version()
       
      +    @property
      +    @ignore_unicode_prefix
      +    def applicationId(self):
      +        """
      +        A unique identifier for the Spark application.
      +        Its format depends on the scheduler implementation.
      +
      +        * in case of local spark app something like 'local-1433865536131'
      +        * in case of YARN something like 'application_1433865536131_34483'
      +
      +        >>> sc.applicationId  # doctest: +ELLIPSIS
      +        u'local-...'
      +        """
      +        return self._jsc.sc().applicationId()
      +
           @property
           def startTime(self):
               """Return the epoch time when the Spark Context was started."""
      @@ -893,8 +913,7 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
               # by runJob() in order to avoid having to pass a Python lambda into
               # SparkContext#runJob.
               mappedRDD = rdd.mapPartitions(partitionFunc)
      -        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions,
      -                                          allowLocal)
      +        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
               return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
       
           def show_profiles(self):
      diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py
      index 4ef2afe03544..b27e91a4cc25 100644
      --- a/python/pyspark/heapq3.py
      +++ b/python/pyspark/heapq3.py
      @@ -883,6 +883,7 @@ def nlargest(n, iterable, key=None):
       
       
       if __name__ == "__main__":
      -
           import doctest
      -    print(doctest.testmod())
      +    (failure_count, test_count) = doctest.testmod()
      +    if failure_count:
      +        exit(-1)
      diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
      index 3cee4ea6e3a3..cd4c55f79f18 100644
      --- a/python/pyspark/java_gateway.py
      +++ b/python/pyspark/java_gateway.py
      @@ -51,6 +51,11 @@ def launch_gateway():
               on_windows = platform.system() == "Windows"
               script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
               submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
      +        if os.environ.get("SPARK_TESTING"):
      +            submit_args = ' '.join([
      +                "--conf spark.ui.enabled=false",
      +                submit_args
      +            ])
               command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
       
               # Start a socket that will be used by PythonGatewayServer to communicate its port to us
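A small sketch (outside the patch) of the effect of the SPARK_TESTING branch above: the Spark UI is disabled for test runs by prepending a `--conf` flag to whatever PYSPARK_SUBMIT_ARGS already holds.

```python
import os

os.environ.setdefault("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
os.environ["SPARK_TESTING"] = "1"

submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
if os.environ.get("SPARK_TESTING"):
    submit_args = " ".join(["--conf spark.ui.enabled=false", submit_args])
print(submit_args)   # --conf spark.ui.enabled=false pyspark-shell
```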
      diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
      index 7abbde8b260e..88815e561f57 100644
      --- a/python/pyspark/ml/classification.py
      +++ b/python/pyspark/ml/classification.py
      @@ -18,20 +18,25 @@
       from pyspark.ml.util import keyword_only
       from pyspark.ml.wrapper import JavaEstimator, JavaModel
       from pyspark.ml.param.shared import *
      -from pyspark.ml.regression import RandomForestParams
      +from pyspark.ml.regression import (
      +    RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
       from pyspark.mllib.common import inherit_doc
       
       
       __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier',
                  'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel',
      -           'RandomForestClassifier', 'RandomForestClassificationModel']
      +           'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes',
      +           'NaiveBayesModel', 'MultilayerPerceptronClassifier',
      +           'MultilayerPerceptronClassificationModel']
       
       
       @inherit_doc
       class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
      -                         HasRegParam, HasTol, HasProbabilityCol):
      +                         HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
      +                         HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds):
           """
           Logistic regression.
      +    Currently, this class only supports binary classification.
       
           >>> from pyspark.sql import Row
           >>> from pyspark.mllib.linalg import Vectors
      @@ -40,13 +45,18 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
           ...     Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
           >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
           >>> model = lr.fit(df)
      -    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
      -    >>> model.transform(test0).head().prediction
      -    0.0
           >>> model.weights
           DenseVector([5.5...])
           >>> model.intercept
           -2.68...
      +    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
      +    >>> result = model.transform(test0).head()
      +    >>> result.prediction
      +    0.0
      +    >>> result.probability
      +    DenseVector([0.99..., 0.00...])
      +    >>> result.rawPrediction
      +    DenseVector([8.22..., -8.22...])
           >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
           >>> model.transform(test1).head().prediction
           1.0
      @@ -57,96 +67,116 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
           """
       
           # a placeholder to make it appear in the generated doc
      -    elasticNetParam = \
      -        Param(Params._dummy(), "elasticNetParam",
      -              "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
      -              "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
      -    fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
           threshold = Param(Params._dummy(), "threshold",
      -                      "threshold in binary classification prediction, in range [0, 1].")
      +                      "Threshold in binary classification prediction, in range [0, 1]." +
      +                      " If threshold and thresholds are both set, they must match.")
       
           @keyword_only
           def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                        maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
      -                 threshold=0.5, probabilityCol="probability"):
      +                 threshold=0.5, thresholds=None, probabilityCol="probability",
      +                 rawPredictionCol="rawPrediction", standardization=True):
               """
               __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                        maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
      -                 threshold=0.5, probabilityCol="probability")
      +                 threshold=0.5, thresholds=None, probabilityCol="probability", \
      +                 rawPredictionCol="rawPrediction", standardization=True)
      +        If the threshold and thresholds Params are both set, they must be equivalent.
               """
               super(LogisticRegression, self).__init__()
               self._java_obj = self._new_java_obj(
                   "org.apache.spark.ml.classification.LogisticRegression", self.uid)
      -        #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
      -        #  is an L2 penalty. For alpha = 1, it is an L1 penalty.
      -        self.elasticNetParam = \
      -            Param(self, "elasticNetParam",
      -                  "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " +
      -                  "is an L2 penalty. For alpha = 1, it is an L1 penalty.")
      -        #: param for whether to fit an intercept term.
      -        self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
      -        #: param for threshold in binary classification prediction, in range [0, 1].
      +        #: param for threshold in binary classification, in range [0, 1].
               self.threshold = Param(self, "threshold",
      -                               "threshold in binary classification prediction, in range [0, 1].")
      -        self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
      -                         fitIntercept=True, threshold=0.5)
      +                               "Threshold in binary classification prediction, in range [0, 1]." +
      +                               " If threshold and thresholds are both set, they must match.")
      +        self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5)
               kwargs = self.__init__._input_kwargs
               self.setParams(**kwargs)
      +        self._checkThresholdConsistency()
       
           @keyword_only
           def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                         maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
      -                  threshold=0.5, probabilityCol="probability"):
      +                  threshold=0.5, thresholds=None, probabilityCol="probability",
      +                  rawPredictionCol="rawPrediction", standardization=True):
               """
               setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                         maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
      -                 threshold=0.5, probabilityCol="probability")
      +                  threshold=0.5, thresholds=None, probabilityCol="probability", \
      +                  rawPredictionCol="rawPrediction", standardization=True)
               Sets params for logistic regression.
      +        If the threshold and thresholds Params are both set, they must be equivalent.
               """
               kwargs = self.setParams._input_kwargs
      -        return self._set(**kwargs)
      +        self._set(**kwargs)
      +        self._checkThresholdConsistency()
      +        return self
       
           def _create_model(self, java_model):
               return LogisticRegressionModel(java_model)
       
      -    def setElasticNetParam(self, value):
      -        """
      -        Sets the value of :py:attr:`elasticNetParam`.
      -        """
      -        self._paramMap[self.elasticNetParam] = value
      -        return self
      -
      -    def getElasticNetParam(self):
      -        """
      -        Gets the value of elasticNetParam or its default value.
      -        """
      -        return self.getOrDefault(self.elasticNetParam)
      -
      -    def setFitIntercept(self, value):
      -        """
      -        Sets the value of :py:attr:`fitIntercept`.
      -        """
      -        self._paramMap[self.fitIntercept] = value
      -        return self
      -
      -    def getFitIntercept(self):
      -        """
      -        Gets the value of fitIntercept or its default value.
      -        """
      -        return self.getOrDefault(self.fitIntercept)
      -
           def setThreshold(self, value):
               """
               Sets the value of :py:attr:`threshold`.
      +        Clears value of :py:attr:`thresholds` if it has been set.
               """
               self._paramMap[self.threshold] = value
      +        if self.isSet(self.thresholds):
      +            del self._paramMap[self.thresholds]
               return self
       
           def getThreshold(self):
               """
               Gets the value of threshold or its default value.
               """
      -        return self.getOrDefault(self.threshold)
      +        self._checkThresholdConsistency()
      +        if self.isSet(self.thresholds):
      +            ts = self.getOrDefault(self.thresholds)
      +            if len(ts) != 2:
      +                raise ValueError("Logistic Regression getThreshold only applies to" +
      +                                 " binary classification, but thresholds has length != 2." +
+                                 "  thresholds: " + ",".join(map(str, ts)))
      +            return 1.0/(1.0 + ts[0]/ts[1])
      +        else:
      +            return self.getOrDefault(self.threshold)
      +
      +    def setThresholds(self, value):
      +        """
      +        Sets the value of :py:attr:`thresholds`.
      +        Clears value of :py:attr:`threshold` if it has been set.
      +        """
      +        self._paramMap[self.thresholds] = value
      +        if self.isSet(self.threshold):
      +            del self._paramMap[self.threshold]
      +        return self
      +
      +    def getThresholds(self):
      +        """
      +        If :py:attr:`thresholds` is set, return its value.
      +        Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
      +        classification: (1-threshold, threshold).
      +        If neither are set, throw an error.
      +        """
      +        self._checkThresholdConsistency()
      +        if not self.isSet(self.thresholds) and self.isSet(self.threshold):
      +            t = self.getOrDefault(self.threshold)
      +            return [1.0-t, t]
      +        else:
      +            return self.getOrDefault(self.thresholds)
      +
      +    def _checkThresholdConsistency(self):
      +        if self.isSet(self.threshold) and self.isSet(self.thresholds):
+            ts = self.getOrDefault(self.thresholds)
+            if len(ts) != 2:
+                raise ValueError("Logistic Regression getThreshold only applies to" +
+                                 " binary classification, but thresholds has length != 2." +
+                                 " thresholds: " + ",".join(map(str, ts)))
+            t = 1.0/(1.0 + ts[0]/ts[1])
+            t2 = self.getOrDefault(self.threshold)
      +            if abs(t2 - t) >= 1E-5:
      +                raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
      +                                 " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
       
       
       class LogisticRegressionModel(JavaModel):
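A worked check (plain Python, not part of the patch) of the conversion the getters above rely on in the binary case: thresholds `[t0, t1]` correspond to threshold `1 / (1 + t0/t1)`, and a single threshold `t` corresponds to thresholds `[1 - t, t]`; `_checkThresholdConsistency` raises when the two settings disagree by 1e-5 or more.

```python
t0, t1 = 0.4, 0.6
threshold = 1.0 / (1.0 + t0 / t1)      # thresholds -> threshold
print(threshold)                       # approximately 0.6
print([1.0 - threshold, threshold])    # threshold -> thresholds, round-trips to [0.4, 0.6]
```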
      @@ -185,7 +215,8 @@ class GBTParams(object):
       
       @inherit_doc
       class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
      -                             DecisionTreeParams, HasCheckpointInterval):
      +                             HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
      +                             HasCheckpointInterval):
           """
           `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
           learning algorithm for classification.
      @@ -202,9 +233,18 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
           >>> td = si_model.transform(df)
           >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
           >>> model = dt.fit(td)
      +    >>> model.numNodes
      +    3
      +    >>> model.depth
      +    1
           >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
      -    >>> model.transform(test0).head().prediction
      +    >>> result = model.transform(test0).head()
      +    >>> result.prediction
           0.0
      +    >>> result.probability
      +    DenseVector([1.0, 0.0])
      +    >>> result.rawPrediction
      +    DenseVector([1.0, 0.0])
           >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
           >>> model.transform(test1).head().prediction
           1.0
      @@ -217,10 +257,12 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
       
           @keyword_only
           def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                        maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                        maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"):
               """
               __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                        maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                        maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
               """
      @@ -240,11 +282,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
       
           @keyword_only
           def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                         maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="gini"):
               """
               setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                         maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
               Sets params for the DecisionTreeClassifier.
      @@ -269,7 +313,8 @@ def getImpurity(self):
               return self.getOrDefault(self.impurity)
       
       
      -class DecisionTreeClassificationModel(JavaModel):
      +@inherit_doc
      +class DecisionTreeClassificationModel(DecisionTreeModel):
           """
           Model fitted by DecisionTreeClassifier.
           """
      @@ -277,6 +322,7 @@ class DecisionTreeClassificationModel(JavaModel):
       
       @inherit_doc
       class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
      +                             HasRawPredictionCol, HasProbabilityCol,
                                    DecisionTreeParams, HasCheckpointInterval):
           """
           `http://en.wikipedia.org/wiki/Random_forest  Random Forest`
      @@ -284,6 +330,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
           It supports both binary and multiclass labels, as well as both continuous and categorical
           features.
       
      +    >>> import numpy
      +    >>> from numpy import allclose
           >>> from pyspark.mllib.linalg import Vectors
           >>> from pyspark.ml.feature import StringIndexer
           >>> df = sqlContext.createDataFrame([
      @@ -292,11 +340,18 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
           >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
           >>> si_model = stringIndexer.fit(df)
           >>> td = si_model.transform(df)
      -    >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
      +    >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
           >>> model = rf.fit(td)
      +    >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
      +    True
           >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
      -    >>> model.transform(test0).head().prediction
      +    >>> result = model.transform(test0).head()
      +    >>> result.prediction
           0.0
      +    >>> numpy.argmax(result.probability)
      +    0
      +    >>> numpy.argmax(result.rawPrediction)
      +    0
           >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
           >>> model.transform(test1).head().prediction
           1.0
      @@ -317,11 +372,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
       
           @keyword_only
           def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                        maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                        maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                        numTrees=20, featureSubsetStrategy="auto", seed=None):
               """
               __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                        maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                        maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                        numTrees=20, featureSubsetStrategy="auto", seed=None)
      @@ -354,11 +411,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
       
           @keyword_only
           def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                         maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
                         impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
               """
               setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                         maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
                         impurity="gini", numTrees=20, featureSubsetStrategy="auto")
      @@ -423,7 +482,7 @@ def getFeatureSubsetStrategy(self):
               return self.getOrDefault(self.featureSubsetStrategy)
       
       
      -class RandomForestClassificationModel(JavaModel):
      +class RandomForestClassificationModel(TreeEnsembleModels):
           """
           Model fitted by RandomForestClassifier.
           """
      @@ -438,6 +497,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
           It supports binary labels, as well as both continuous and categorical features.
           Note: Multiclass labels are not currently supported.
       
      +    >>> from numpy import allclose
           >>> from pyspark.mllib.linalg import Vectors
           >>> from pyspark.ml.feature import StringIndexer
           >>> df = sqlContext.createDataFrame([
      @@ -448,6 +508,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
           >>> td = si_model.transform(df)
           >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
           >>> model = gbt.fit(td)
      +    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
      +    True
           >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
           >>> model.transform(test0).head().prediction
           0.0
      @@ -558,12 +620,271 @@ def getStepSize(self):
               return self.getOrDefault(self.stepSize)
       
       
      -class GBTClassificationModel(JavaModel):
      +class GBTClassificationModel(TreeEnsembleModels):
           """
           Model fitted by GBTClassifier.
           """
       
       
      +@inherit_doc
      +class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
      +                 HasRawPredictionCol):
      +    """
      +    Naive Bayes Classifiers.
      +    It supports both Multinomial and Bernoulli NB. Multinomial NB
      +    (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`)
      +    can handle finitely supported discrete data. For example, by converting documents into
      +    TF-IDF vectors, it can be used for document classification. By making every vector a
+    binary (0/1) vector, it can also be used as Bernoulli NB
      +    (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`).
      +    The input feature values must be nonnegative.
      +
      +    >>> from pyspark.sql import Row
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> df = sqlContext.createDataFrame([
      +    ...     Row(label=0.0, features=Vectors.dense([0.0, 0.0])),
      +    ...     Row(label=0.0, features=Vectors.dense([0.0, 1.0])),
      +    ...     Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
      +    >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
      +    >>> model = nb.fit(df)
      +    >>> model.pi
      +    DenseVector([-0.51..., -0.91...])
      +    >>> model.theta
      +    DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1)
      +    >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
      +    >>> result = model.transform(test0).head()
      +    >>> result.prediction
      +    1.0
      +    >>> result.probability
      +    DenseVector([0.42..., 0.57...])
      +    >>> result.rawPrediction
      +    DenseVector([-1.60..., -1.32...])
      +    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
      +    >>> model.transform(test1).head().prediction
      +    1.0
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
      +                      "default is 1.0")
      +    modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
      +                      "(case-sensitive). Supported options: multinomial (default) and bernoulli.")
      +
      +    @keyword_only
      +    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
      +                 modelType="multinomial"):
      +        """
      +        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
      +                 modelType="multinomial")
      +        """
      +        super(NaiveBayes, self).__init__()
      +        self._java_obj = self._new_java_obj(
      +            "org.apache.spark.ml.classification.NaiveBayes", self.uid)
      +        #: param for the smoothing parameter.
      +        self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " +
      +                               "default is 1.0")
      +        #: param for the model type.
      +        self.modelType = Param(self, "modelType", "The model type which is a string " +
      +                               "(case-sensitive). Supported options: multinomial (default) " +
      +                               "and bernoulli.")
      +        self._setDefault(smoothing=1.0, modelType="multinomial")
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
      +                  modelType="multinomial"):
      +        """
      +        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
      +                  modelType="multinomial")
      +        Sets params for Naive Bayes.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def _create_model(self, java_model):
      +        return NaiveBayesModel(java_model)
      +
      +    def setSmoothing(self, value):
      +        """
      +        Sets the value of :py:attr:`smoothing`.
      +        """
      +        self._paramMap[self.smoothing] = value
      +        return self
      +
      +    def getSmoothing(self):
      +        """
      +        Gets the value of smoothing or its default value.
      +        """
      +        return self.getOrDefault(self.smoothing)
      +
      +    def setModelType(self, value):
      +        """
      +        Sets the value of :py:attr:`modelType`.
      +        """
      +        self._paramMap[self.modelType] = value
      +        return self
      +
      +    def getModelType(self):
      +        """
      +        Gets the value of modelType or its default value.
      +        """
      +        return self.getOrDefault(self.modelType)
      +
      +
      +class NaiveBayesModel(JavaModel):
      +    """
      +    Model fitted by NaiveBayes.
      +    """
      +
      +    @property
      +    def pi(self):
      +        """
      +        log of class priors.
      +        """
      +        return self._call_java("pi")
      +
      +    @property
      +    def theta(self):
      +        """
      +        log of class conditional probabilities.
      +        """
      +        return self._call_java("theta")
      +
      +
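A quick arithmetic check (plain Python, not part of the patch) of the `pi` values in the NaiveBayes doctest above: with class counts [2, 1] over three rows and multinomial smoothing 1.0, the log class priors are log(3/5) ≈ -0.511 and log(2/5) ≈ -0.916, matching `DenseVector([-0.51..., -0.91...])`.

```python
import math

counts, smoothing, num_classes = [2, 1], 1.0, 2            # from the doctest's training data
denom = sum(counts) + smoothing * num_classes
pi = [math.log((c + smoothing) / denom) for c in counts]
print(pi)   # [-0.5108..., -0.9162...]
```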
      +@inherit_doc
      +class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
      +                                     HasMaxIter, HasTol, HasSeed):
      +    """
      +    Classifier trainer based on the Multilayer Perceptron.
+    Each layer has a sigmoid activation function; the output layer has softmax.
+    The number of inputs has to be equal to the size of the feature vectors.
+    The number of outputs has to be equal to the total number of labels.
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> df = sqlContext.createDataFrame([
      +    ...     (0.0, Vectors.dense([0.0, 0.0])),
      +    ...     (1.0, Vectors.dense([0.0, 1.0])),
      +    ...     (1.0, Vectors.dense([1.0, 0.0])),
      +    ...     (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
      +    >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11)
      +    >>> model = mlp.fit(df)
      +    >>> model.layers
      +    [2, 5, 2]
      +    >>> model.weights.size
      +    27
      +    >>> testDF = sqlContext.createDataFrame([
      +    ...     (Vectors.dense([1.0, 0.0]),),
      +    ...     (Vectors.dense([0.0, 0.0]),)], ["features"])
      +    >>> model.transform(testDF).show()
      +    +---------+----------+
      +    | features|prediction|
      +    +---------+----------+
      +    |[1.0,0.0]|       1.0|
      +    |[0.0,0.0]|       0.0|
      +    +---------+----------+
      +    ...
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
      +                   "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
      +                   "neurons and output layer of 10 neurons, default is [1, 1].")
      +    blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " +
      +                      "matrices. Data is stacked within partitions. If block size is more than " +
      +                      "remaining data in a partition then it is adjusted to the size of this " +
      +                      "data. Recommended size is between 10 and 1000, default is 128.")
      +
      +    @keyword_only
      +    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                 maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128):
      +        """
      +        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                 maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128)
      +        """
      +        super(MultilayerPerceptronClassifier, self).__init__()
      +        self._java_obj = self._new_java_obj(
      +            "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
      +        self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " +
      +                            "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " +
      +                            "100 neurons and output layer of 10 neurons, default is [1, 1].")
      +        self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " +
      +                               "matrices. Data is stacked within partitions. If block size is " +
      +                               "more than remaining data in a partition then it is adjusted to " +
      +                               "the size of this data. Recommended size is between 10 and 1000, " +
      +                               "default is 128.")
      +        self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128)
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      +                  maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128):
      +        """
      +        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      +                  maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128)
      +        Sets params for MultilayerPerceptronClassifier.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        if layers is None:
      +            return self._set(**kwargs).setLayers([1, 1])
      +        else:
      +            return self._set(**kwargs)
      +
      +    def _create_model(self, java_model):
      +        return MultilayerPerceptronClassificationModel(java_model)
      +
      +    def setLayers(self, value):
      +        """
      +        Sets the value of :py:attr:`layers`.
      +        """
      +        self._paramMap[self.layers] = value
      +        return self
      +
      +    def getLayers(self):
      +        """
      +        Gets the value of layers or its default value.
      +        """
      +        return self.getOrDefault(self.layers)
      +
      +    def setBlockSize(self, value):
      +        """
      +        Sets the value of :py:attr:`blockSize`.
      +        """
      +        self._paramMap[self.blockSize] = value
      +        return self
      +
      +    def getBlockSize(self):
      +        """
      +        Gets the value of blockSize or its default value.
      +        """
      +        return self.getOrDefault(self.blockSize)
      +
      +
      +class MultilayerPerceptronClassificationModel(JavaModel):
      +    """
      +    Model fitted by MultilayerPerceptronClassifier.
      +    """
      +
      +    @property
      +    def layers(self):
      +        """
      +        array of layer sizes including input and output layers.
      +        """
      +        return self._call_java("javaLayers")
      +
      +    @property
      +    def weights(self):
      +        """
      +        vector of initial weights for the model that consists of the weights of layers.
      +        """
      +        return self._call_java("weights")
      +
      +
       if __name__ == "__main__":
           import doctest
           from pyspark.context import SparkContext
      diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
      new file mode 100644
      index 000000000000..cb4c16e25a7a
      --- /dev/null
      +++ b/python/pyspark/ml/clustering.py
      @@ -0,0 +1,171 @@
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +from pyspark.ml.util import keyword_only
      +from pyspark.ml.wrapper import JavaEstimator, JavaModel
      +from pyspark.ml.param.shared import *
      +from pyspark.mllib.common import inherit_doc
      +
      +__all__ = ['KMeans', 'KMeansModel']
      +
      +
      +class KMeansModel(JavaModel):
      +    """
      +    Model fitted by KMeans.
      +    """
      +
      +    def clusterCenters(self):
      +        """Get the cluster centers, represented as a list of NumPy arrays."""
      +        return [c.toArray() for c in self._call_java("clusterCenters")]
      +
      +
      +@inherit_doc
      +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
      +    """
      +    K-means clustering with support for multiple parallel runs and a k-means++ like initialization
      +    mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
      +    they are executed together with joint passes over the data for efficiency.
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
      +    ...         (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
      +    >>> df = sqlContext.createDataFrame(data, ["features"])
      +    >>> kmeans = KMeans(k=2, seed=1)
      +    >>> model = kmeans.fit(df)
      +    >>> centers = model.clusterCenters()
      +    >>> len(centers)
      +    2
      +    >>> transformed = model.transform(df).select("features", "prediction")
      +    >>> rows = transformed.collect()
      +    >>> rows[0].prediction == rows[1].prediction
      +    True
      +    >>> rows[2].prediction == rows[3].prediction
      +    True
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    k = Param(Params._dummy(), "k", "number of clusters to create")
      +    initMode = Param(Params._dummy(), "initMode",
      +                     "the initialization algorithm. This can be either \"random\" to " +
      +                     "choose random points as initial cluster centers, or \"k-means||\" " +
      +                     "to use a parallel variant of k-means++")
      +    initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode")
      +
      +    @keyword_only
      +    def __init__(self, featuresCol="features", predictionCol="prediction", k=2,
      +                 initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None):
      +        """
      +        __init__(self, featuresCol="features", predictionCol="prediction", k=2, \
      +                 initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None)
      +        """
      +        super(KMeans, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid)
      +        self.k = Param(self, "k", "number of clusters to create")
      +        self.initMode = Param(self, "initMode",
      +                              "the initialization algorithm. This can be either \"random\" to " +
      +                              "choose random points as initial cluster centers, or \"k-means||\" " +
      +                              "to use a parallel variant of k-means++")
      +        self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode")
      +        self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20)
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    def _create_model(self, java_model):
      +        return KMeansModel(java_model)
      +
      +    @keyword_only
      +    def setParams(self, featuresCol="features", predictionCol="prediction", k=2,
      +                  initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None):
      +        """
      +        setParams(self, featuresCol="features", predictionCol="prediction", k=2, \
      +                  initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None)
      +
      +        Sets params for KMeans.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setK(self, value):
      +        """
      +        Sets the value of :py:attr:`k`.
      +
      +        >>> algo = KMeans().setK(10)
      +        >>> algo.getK()
      +        10
      +        """
      +        self._paramMap[self.k] = value
      +        return self
      +
      +    def getK(self):
      +        """
      +        Gets the value of `k`
      +        """
      +        return self.getOrDefault(self.k)
      +
      +    def setInitMode(self, value):
      +        """
      +        Sets the value of :py:attr:`initMode`.
      +
      +        >>> algo = KMeans()
      +        >>> algo.getInitMode()
      +        'k-means||'
      +        >>> algo = algo.setInitMode("random")
      +        >>> algo.getInitMode()
      +        'random'
      +        """
      +        self._paramMap[self.initMode] = value
      +        return self
      +
      +    def getInitMode(self):
      +        """
      +        Gets the value of `initMode`
      +        """
      +        return self.getOrDefault(self.initMode)
      +
      +    def setInitSteps(self, value):
      +        """
      +        Sets the value of :py:attr:`initSteps`.
      +
      +        >>> algo = KMeans().setInitSteps(10)
      +        >>> algo.getInitSteps()
      +        10
      +        """
      +        self._paramMap[self.initSteps] = value
      +        return self
      +
      +    def getInitSteps(self):
      +        """
      +        Gets the value of `initSteps`
      +        """
      +        return self.getOrDefault(self.initSteps)
      +
      +
      +if __name__ == "__main__":
      +    import doctest
      +    from pyspark.context import SparkContext
      +    from pyspark.sql import SQLContext
      +    globs = globals().copy()
      +    # The small batch size here ensures that we see multiple batches,
      +    # even in these small test examples:
      +    sc = SparkContext("local[2]", "ml.clustering tests")
      +    sqlContext = SQLContext(sc)
      +    globs['sc'] = sc
      +    globs['sqlContext'] = sqlContext
      +    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
      +    sc.stop()
      +    if failure_count:
      +        exit(-1)
      diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
      index 595593a7f2cd..cb3b07947e48 100644
      --- a/python/pyspark/ml/evaluation.py
      +++ b/python/pyspark/ml/evaluation.py
      @@ -23,7 +23,8 @@
       from pyspark.ml.util import keyword_only
       from pyspark.mllib.common import inherit_doc
       
      -__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator']
      +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
      +           'MulticlassClassificationEvaluator']
       
       
       @inherit_doc
      @@ -45,7 +46,7 @@ def _evaluate(self, dataset):
               """
               raise NotImplementedError()
       
      -    def evaluate(self, dataset, params={}):
      +    def evaluate(self, dataset, params=None):
               """
               Evaluates the output with optional parameters.
       
      @@ -55,6 +56,8 @@ def evaluate(self, dataset, params={}):
                              params
               :return: metric
               """
      +        if params is None:
      +            params = dict()
               if isinstance(params, dict):
                   if params:
                       return self.copy(params)._evaluate(dataset)
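A generic sketch (names are illustrative) of why `params={}` is replaced by `params=None` plus the `if params is None` guard above: a mutable default argument is created once at function definition time and shared across calls.

```python
def evaluate_bad(dataset, params={}):        # shared dict across calls
    params["calls"] = params.get("calls", 0) + 1
    return params["calls"]

def evaluate_good(dataset, params=None):     # fresh dict on every call
    if params is None:
        params = dict()
    params["calls"] = params.get("calls", 0) + 1
    return params["calls"]

print(evaluate_bad(None), evaluate_bad(None))      # 1 2  -- state leaks between calls
print(evaluate_good(None), evaluate_good(None))    # 1 1
```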
      @@ -63,6 +66,14 @@ def evaluate(self, dataset, params={}):
               else:
                   raise ValueError("Params must be a param map but got %s." % type(params))
       
      +    def isLargerBetter(self):
      +        """
      +        Indicates whether the metric returned by :py:meth:`evaluate` should be maximized
      +        (True, default) or minimized (False).
      +        A given evaluator may support multiple metrics which may be maximized or minimized.
      +        """
      +        return True
      +
       
       @inherit_doc
       class JavaEvaluator(Evaluator, JavaWrapper):
      @@ -82,6 +93,10 @@ def _evaluate(self, dataset):
               self._transfer_params_to_java()
               return self._java_obj.evaluate(dataset._jdf)
       
      +    def isLargerBetter(self):
      +        self._transfer_params_to_java()
      +        return self._java_obj.isLargerBetter()
      +
       
       @inherit_doc
       class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
      @@ -160,11 +175,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
           ...
           >>> evaluator = RegressionEvaluator(predictionCol="raw")
           >>> evaluator.evaluate(dataset)
      -    -2.842...
      +    2.842...
           >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
           0.993...
           >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
      -    -2.649...
      +    2.649...
           """
           # Because we will maximize evaluation value (ref: `CrossValidator`),
           # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
      @@ -214,6 +229,72 @@ def setParams(self, predictionCol="prediction", labelCol="label",
               kwargs = self.setParams._input_kwargs
               return self._set(**kwargs)
       
      +
      +@inherit_doc
      +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
      +    """
      +    Evaluator for Multiclass Classification, which expects two input
+    columns: prediction and label.
+
      +    >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
      +    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
      +    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"])
      +    ...
      +    >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
      +    >>> evaluator.evaluate(dataset)
      +    0.66...
      +    >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"})
      +    0.66...
      +    >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"})
      +    0.66...
      +    """
      +    # a placeholder to make it appear in the generated doc
      +    metricName = Param(Params._dummy(), "metricName",
      +                       "metric name in evaluation "
      +                       "(f1|precision|recall|weightedPrecision|weightedRecall)")
      +
      +    @keyword_only
      +    def __init__(self, predictionCol="prediction", labelCol="label",
      +                 metricName="f1"):
      +        """
      +        __init__(self, predictionCol="prediction", labelCol="label", \
      +                 metricName="f1")
      +        """
      +        super(MulticlassClassificationEvaluator, self).__init__()
      +        self._java_obj = self._new_java_obj(
      +            "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
      +        # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall)
      +        self.metricName = Param(self, "metricName",
      +                                "metric name in evaluation"
      +                                " (f1|precision|recall|weightedPrecision|weightedRecall)")
      +        self._setDefault(predictionCol="prediction", labelCol="label",
      +                         metricName="f1")
      +        kwargs = self.__init__._input_kwargs
      +        self._set(**kwargs)
      +
      +    def setMetricName(self, value):
      +        """
      +        Sets the value of :py:attr:`metricName`.
      +        """
      +        self._paramMap[self.metricName] = value
      +        return self
      +
      +    def getMetricName(self):
      +        """
      +        Gets the value of metricName or its default value.
      +        """
      +        return self.getOrDefault(self.metricName)
      +
      +    @keyword_only
      +    def setParams(self, predictionCol="prediction", labelCol="label",
      +                  metricName="f1"):
      +        """
      +        setParams(self, predictionCol="prediction", labelCol="label", \
      +                  metricName="f1")
      +        Sets params for multiclass classification evaluator.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
       if __name__ == "__main__":
           import doctest
           from pyspark.context import SparkContext
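
The new `isLargerBetter` hook tells model-selection code such as `CrossValidator` whether the reported metric should be maximized or minimized, which is why the RegressionEvaluator doctest above no longer negates its metrics. A minimal sketch of a pure-Python evaluator reporting a minimized metric; the class name and column names are made up for illustration, only `Evaluator` and `isLargerBetter` come from this patch:

```python
# Hypothetical pure-Python evaluator built on the new isLargerBetter contract.
from pyspark.ml.evaluation import Evaluator

class MeanAbsoluteErrorEvaluator(Evaluator):
    """Reports mean absolute error between the "prediction" and "label" columns."""

    def _evaluate(self, dataset):
        # average absolute difference over the DataFrame's rows
        return dataset.rdd.map(lambda r: abs(r.prediction - r.label)).mean()

    def isLargerBetter(self):
        # smaller error is better, so model selection should minimize this metric
        return False
```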
      diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
      index ddb33f427ac6..92db8df80280 100644
      --- a/python/pyspark/ml/feature.py
      +++ b/python/pyspark/ml/feature.py
      @@ -15,21 +15,30 @@
       # limitations under the License.
       #
       
      +import sys
      +if sys.version > '3':
      +    basestring = str
      +
       from pyspark.rdd import ignore_unicode_prefix
       from pyspark.ml.param.shared import *
       from pyspark.ml.util import keyword_only
      -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
      +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
       from pyspark.mllib.common import inherit_doc
      +from pyspark.mllib.linalg import _convert_to_vector
       
      -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder',
      -           'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
      -           'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
      -           'Word2Vec', 'Word2VecModel']
      +__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel',
      +           'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer',
      +           'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer',
      +           'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
      +           'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer',
      +           'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel']
       
       
       @inherit_doc
       class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           Binarize a column of continuous features given a threshold.
       
           >>> df = sqlContext.createDataFrame([(0.5,)], ["values"])
      @@ -86,6 +95,8 @@ def getThreshold(self):
       @inherit_doc
       class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           Maps a column of continuous features to a column of feature buckets.
       
           >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
      @@ -160,9 +171,135 @@ def getSplits(self):
               return self.getOrDefault(self.splits)
       
       
      +@inherit_doc
      +class DCT(JavaTransformer, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
      +    A feature transformer that takes the 1D discrete cosine transform
      +    of a real vector. No zero padding is performed on the input vector.
      +    It returns a real vector of the same length representing the DCT.
      +    The return vector is scaled such that the transform matrix is
      +    unitary (aka scaled DCT-II).
      +
+    More information can be found on
+    `Wikipedia <https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II>`_.
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> df1 = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
      +    >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec")
      +    >>> df2 = dct.transform(df1)
      +    >>> df2.head().resultVec
      +    DenseVector([10.969..., -0.707..., -2.041...])
      +    >>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2)
      +    >>> df3.head().origVec
      +    DenseVector([5.0, 8.0, 6.0])
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " +
      +                    "default False.")
      +
      +    @keyword_only
      +    def __init__(self, inverse=False, inputCol=None, outputCol=None):
      +        """
      +        __init__(self, inverse=False, inputCol=None, outputCol=None)
      +        """
      +        super(DCT, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid)
      +        self.inverse = Param(self, "inverse", "Set transformer to perform inverse DCT, " +
      +                             "default False.")
      +        self._setDefault(inverse=False)
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, inverse=False, inputCol=None, outputCol=None):
      +        """
      +        setParams(self, inverse=False, inputCol=None, outputCol=None)
      +        Sets params for this DCT.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setInverse(self, value):
      +        """
      +        Sets the value of :py:attr:`inverse`.
      +        """
      +        self._paramMap[self.inverse] = value
      +        return self
      +
      +    def getInverse(self):
      +        """
      +        Gets the value of inverse or its default value.
      +        """
      +        return self.getOrDefault(self.inverse)
      +
      +
      +@inherit_doc
      +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
      +    Outputs the Hadamard product (i.e., the element-wise product) of each input vector
      +    with a provided "weight" vector. In other words, it scales each column of the dataset
      +    by a scalar multiplier.
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> df = sqlContext.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"])
      +    >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]),
      +    ...     inputCol="values", outputCol="eprod")
      +    >>> ep.transform(df).head().eprod
      +    DenseVector([2.0, 2.0, 9.0])
      +    >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod
      +    DenseVector([4.0, 3.0, 15.0])
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " +
      +                       "it must be MLlib Vector type.")
      +
      +    @keyword_only
      +    def __init__(self, scalingVec=None, inputCol=None, outputCol=None):
      +        """
      +        __init__(self, scalingVec=None, inputCol=None, outputCol=None)
      +        """
      +        super(ElementwiseProduct, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct",
      +                                            self.uid)
      +        self.scalingVec = Param(self, "scalingVec", "vector for hadamard product, " +
      +                                "it must be MLlib Vector type.")
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, scalingVec=None, inputCol=None, outputCol=None):
      +        """
      +        setParams(self, scalingVec=None, inputCol=None, outputCol=None)
      +        Sets params for this ElementwiseProduct.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setScalingVec(self, value):
      +        """
      +        Sets the value of :py:attr:`scalingVec`.
      +        """
      +        self._paramMap[self.scalingVec] = value
      +        return self
      +
      +    def getScalingVec(self):
      +        """
      +        Gets the value of scalingVec or its default value.
      +        """
      +        return self.getOrDefault(self.scalingVec)
      +
      +
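
DCT and ElementwiseProduct are plain transformers, so they compose with the rest of pyspark.ml like any other stage. A hedged sketch chaining the two in a Pipeline; the data, column names, and weight vector are invented for the example:

```python
# Illustrative only: composing the two new transformers in a Pipeline.
from pyspark.ml import Pipeline
from pyspark.ml.feature import DCT, ElementwiseProduct
from pyspark.mllib.linalg import Vectors

df = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
pipeline = Pipeline(stages=[
    DCT(inputCol="vec", outputCol="dct"),
    ElementwiseProduct(scalingVec=Vectors.dense([1.0, 0.5, 0.25]),
                       inputCol="dct", outputCol="scaled"),
])
print(pipeline.fit(df).transform(df).head().scaled)
```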
       @inherit_doc
       class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
           """
      +    .. note:: Experimental
      +
           Maps a sequence of terms to their term frequencies using the
           hashing trick.
       
      @@ -201,6 +338,8 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
       @inherit_doc
       class IDF(JavaEstimator, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           Compute the Inverse Document Frequency (IDF) given a collection of documents.
       
           >>> from pyspark.mllib.linalg import DenseVector
      @@ -261,13 +400,182 @@ def _create_model(self, java_model):
       
       class IDFModel(JavaModel):
           """
      +    .. note:: Experimental
      +
           Model fitted by IDF.
           """
       
       
      +@inherit_doc
      +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
      +    Rescale each feature individually to a common range [min, max] linearly using column summary
      +    statistics, which is also known as min-max normalization or Rescaling. The rescaled value for
      +    feature E is calculated as,
      +
      +    Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min
      +
      +    For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min)
      +
      +    Note that since zero values will probably be transformed to non-zero values, output of the
      +    transformer will be DenseVector even for sparse input.
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
      +    >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled")
      +    >>> model = mmScaler.fit(df)
      +    >>> model.transform(df).show()
      +    +-----+------+
      +    |    a|scaled|
      +    +-----+------+
      +    |[0.0]| [0.0]|
      +    |[2.0]| [1.0]|
      +    +-----+------+
      +    ...
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    min = Param(Params._dummy(), "min", "Lower bound of the output feature range")
      +    max = Param(Params._dummy(), "max", "Upper bound of the output feature range")
      +
      +    @keyword_only
      +    def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None):
      +        """
      +        __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None)
      +        """
      +        super(MinMaxScaler, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid)
      +        self.min = Param(self, "min", "Lower bound of the output feature range")
      +        self.max = Param(self, "max", "Upper bound of the output feature range")
      +        self._setDefault(min=0.0, max=1.0)
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None):
      +        """
      +        setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None)
      +        Sets params for this MinMaxScaler.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setMin(self, value):
      +        """
      +        Sets the value of :py:attr:`min`.
      +        """
      +        self._paramMap[self.min] = value
      +        return self
      +
      +    def getMin(self):
      +        """
      +        Gets the value of min or its default value.
      +        """
      +        return self.getOrDefault(self.min)
      +
      +    def setMax(self, value):
      +        """
      +        Sets the value of :py:attr:`max`.
      +        """
      +        self._paramMap[self.max] = value
      +        return self
      +
      +    def getMax(self):
      +        """
      +        Gets the value of max or its default value.
      +        """
      +        return self.getOrDefault(self.max)
      +
      +    def _create_model(self, java_model):
      +        return MinMaxScalerModel(java_model)
      +
      +
      +class MinMaxScalerModel(JavaModel):
      +    """
      +    .. note:: Experimental
      +
      +    Model fitted by :py:class:`MinMaxScaler`.
      +    """
      +
      +
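
The rescaling rule quoted in the MinMaxScaler docstring can be checked by hand; the following plain-Python sketch (independent of Spark, illustrative only) applies the same formula to a single column:

```python
# Plain-Python check of the formula quoted above:
# Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min
def rescale(column, new_min=0.0, new_max=1.0):
    e_min, e_max = min(column), max(column)
    if e_max == e_min:  # degenerate case noted in the docstring
        return [0.5 * (new_max + new_min)] * len(column)
    return [(e - e_min) / (e_max - e_min) * (new_max - new_min) + new_min
            for e in column]

print(rescale([0.0, 2.0]))  # [0.0, 1.0], matching the doctest output above
```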
      +@inherit_doc
      +@ignore_unicode_prefix
      +class NGram(JavaTransformer, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
      +    A feature transformer that converts the input array of strings into an array of n-grams. Null
      +    values in the input array are ignored.
      +    It returns an array of n-grams where each n-gram is represented by a space-separated string of
      +    words.
      +    When the input is empty, an empty array is returned.
      +    When the input array length is less than n (number of elements per n-gram), no n-grams are
      +    returned.
      +
      +    >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
      +    >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams")
      +    >>> ngram.transform(df).head()
      +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e'])
      +    >>> # Change n-gram length
      +    >>> ngram.setParams(n=4).transform(df).head()
      +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
      +    >>> # Temporarily modify output column.
      +    >>> ngram.transform(df, {ngram.outputCol: "output"}).head()
      +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e'])
      +    >>> ngram.transform(df).head()
      +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
      +    >>> # Must use keyword arguments to specify params.
      +    >>> ngram.setParams("text")
      +    Traceback (most recent call last):
      +        ...
      +    TypeError: Method setParams forces keyword arguments.
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)")
      +
      +    @keyword_only
      +    def __init__(self, n=2, inputCol=None, outputCol=None):
      +        """
      +        __init__(self, n=2, inputCol=None, outputCol=None)
      +        """
      +        super(NGram, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
      +        self.n = Param(self, "n", "number of elements per n-gram (>=1)")
      +        self._setDefault(n=2)
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, n=2, inputCol=None, outputCol=None):
      +        """
      +        setParams(self, n=2, inputCol=None, outputCol=None)
      +        Sets params for this NGram.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setN(self, value):
      +        """
      +        Sets the value of :py:attr:`n`.
      +        """
      +        self._paramMap[self.n] = value
      +        return self
      +
      +    def getN(self):
      +        """
      +        Gets the value of n or its default value.
      +        """
      +        return self.getOrDefault(self.n)
      +
      +
       @inherit_doc
       class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
            Normalize a vector to have unit norm using the given p-norm.
       
           >>> from pyspark.mllib.linalg import Vectors
      @@ -324,6 +632,8 @@ def getP(self):
       @inherit_doc
       class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           A one-hot encoder that maps a column of category indices to a
           column of binary vectors, with at most a single one-value per row
           that indicates the input category index.
      @@ -396,6 +706,8 @@ def getDropLast(self):
       @inherit_doc
       class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion,
           which is available at `http://en.wikipedia.org/wiki/Polynomial_expansion`, "In mathematics, an
           expansion of a product of sums expresses it as a sum of products by using the fact that
      @@ -454,9 +766,11 @@ def getDegree(self):
       @ignore_unicode_prefix
       class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           A regex based tokenizer that extracts tokens either by using the
           provided regex pattern (in Java dialect) to split the text
      -    (default) or repeatedly matching the regex (if gaps is true).
      +    (default) or repeatedly matching the regex (if gaps is false).
           Optional parameters also allow filtering tokens using a minimal
           length.
           It returns an array of strings that can be empty.
      @@ -548,9 +862,64 @@ def getPattern(self):
               return self.getOrDefault(self.pattern)
       
       
      +@inherit_doc
      +class SQLTransformer(JavaTransformer):
      +    """
      +    .. note:: Experimental
      +
+    Implements the transforms defined by a SQL statement.
      +    Currently we only support SQL syntax like 'SELECT ... FROM __THIS__'
      +    where '__THIS__' represents the underlying table of the input dataset.
      +
      +    >>> df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])
      +    >>> sqlTrans = SQLTransformer(
      +    ...     statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
      +    >>> sqlTrans.transform(df).head()
      +    Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0)
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    statement = Param(Params._dummy(), "statement", "SQL statement")
      +
      +    @keyword_only
      +    def __init__(self, statement=None):
      +        """
      +        __init__(self, statement=None)
      +        """
      +        super(SQLTransformer, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid)
      +        self.statement = Param(self, "statement", "SQL statement")
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, statement=None):
      +        """
      +        setParams(self, statement=None)
      +        Sets params for this SQLTransformer.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setStatement(self, value):
      +        """
      +        Sets the value of :py:attr:`statement`.
      +        """
      +        self._paramMap[self.statement] = value
      +        return self
      +
      +    def getStatement(self):
      +        """
      +        Gets the value of statement or its default value.
      +        """
      +        return self.getOrDefault(self.statement)
      +
      +
       @inherit_doc
       class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           Standardizes features by removing the mean and scaling to unit variance using column summary
           statistics on the samples in the training set.
       
      @@ -558,6 +927,10 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
           >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
           >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled")
           >>> model = standardScaler.fit(df)
      +    >>> model.mean
      +    DenseVector([1.0])
      +    >>> model.std
      +    DenseVector([1.4142])
           >>> model.transform(df).collect()[1].scaled
           DenseVector([1.4142])
           """
      @@ -620,13 +993,31 @@ def _create_model(self, java_model):
       
       class StandardScalerModel(JavaModel):
           """
      +    .. note:: Experimental
      +
           Model fitted by StandardScaler.
           """
       
      +    @property
      +    def std(self):
      +        """
      +        Standard deviation of the StandardScalerModel.
      +        """
      +        return self._call_java("std")
      +
      +    @property
      +    def mean(self):
      +        """
      +        Mean of the StandardScalerModel.
      +        """
      +        return self._call_java("mean")
      +
       
       @inherit_doc
      -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
      +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
           """
      +    .. note:: Experimental
      +
           A label indexer that maps a string column of labels to an ML column of label indices.
           If the input column is numeric, we cast it to string and index the string values.
           The indices are in [0, numLabels), ordered by label frequencies.
      @@ -638,22 +1029,28 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
           >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
           ...     key=lambda x: x[0])
           [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
+    >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels)
      +    >>> itd = inverter.transform(td)
      +    >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
      +    ...     key=lambda x: x[0])
      +    [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
           """
       
           @keyword_only
      -    def __init__(self, inputCol=None, outputCol=None):
      +    def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
               """
      -        __init__(self, inputCol=None, outputCol=None)
      +        __init__(self, inputCol=None, outputCol=None, handleInvalid="error")
               """
               super(StringIndexer, self).__init__()
               self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
      +        self._setDefault(handleInvalid="error")
               kwargs = self.__init__._input_kwargs
               self.setParams(**kwargs)
       
           @keyword_only
      -    def setParams(self, inputCol=None, outputCol=None):
      +    def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
               """
      -        setParams(self, inputCol=None, outputCol=None)
      +        setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
               Sets params for this StringIndexer.
               """
               kwargs = self.setParams._input_kwargs
      @@ -665,14 +1062,147 @@ def _create_model(self, java_model):
       
       class StringIndexerModel(JavaModel):
           """
      +    .. note:: Experimental
      +
           Model fitted by StringIndexer.
           """
      +    @property
      +    def labels(self):
      +        """
      +        Ordered list of labels, corresponding to indices to be assigned.
      +        """
+        return self._call_java("labels")
      +
      +
      +@inherit_doc
      +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
      +    A :py:class:`Transformer` that maps a column of indices back to a new column of
      +    corresponding string values.
      +    The index-string mapping is either from the ML attributes of the input column,
      +    or from user-supplied labels (which take precedence over ML attributes).
      +    See L{StringIndexer} for converting strings into indices.
      +    """
      +
      +    # a placeholder to make the labels show up in generated doc
      +    labels = Param(Params._dummy(), "labels",
      +                   "Optional array of labels specifying index-string mapping." +
      +                   " If not provided or if empty, then metadata from inputCol is used instead.")
      +
      +    @keyword_only
      +    def __init__(self, inputCol=None, outputCol=None, labels=None):
      +        """
      +        __init__(self, inputCol=None, outputCol=None, labels=None)
      +        """
      +        super(IndexToString, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
      +                                            self.uid)
      +        self.labels = Param(self, "labels",
      +                            "Optional array of labels specifying index-string mapping. If not" +
      +                            " provided or if empty, then metadata from inputCol is used instead.")
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, inputCol=None, outputCol=None, labels=None):
      +        """
      +        setParams(self, inputCol=None, outputCol=None, labels=None)
      +        Sets params for this IndexToString.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setLabels(self, value):
      +        """
      +        Sets the value of :py:attr:`labels`.
      +        """
      +        self._paramMap[self.labels] = value
      +        return self
      +
      +    def getLabels(self):
      +        """
      +        Gets the value of :py:attr:`labels` or its default value.
      +        """
      +        return self.getOrDefault(self.labels)
      +
      +
      +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
+    A feature transformer that filters out stop words from the input.
+    Note: null values in the input array are preserved unless null is explicitly
+    added to stopWords.
      +    """
      +    # a placeholder to make the stopwords show up in generated doc
      +    stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
      +    caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
      +                          "comparison over the stop words")
      +
      +    @keyword_only
      +    def __init__(self, inputCol=None, outputCol=None, stopWords=None,
      +                 caseSensitive=False):
      +        """
      +        __init__(self, inputCol=None, outputCol=None, stopWords=None,\
+                 caseSensitive=False)
      +        """
      +        super(StopWordsRemover, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
      +                                            self.uid)
      +        self.stopWords = Param(self, "stopWords", "The words to be filtered out")
      +        self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " +
      +                                   "sensitive comparison over the stop words")
      +        stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
      +        defaultStopWords = stopWordsObj.English()
      +        self._setDefault(stopWords=defaultStopWords)
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, inputCol=None, outputCol=None, stopWords=None,
      +                  caseSensitive=False):
      +        """
      +        setParams(self, inputCol="input", outputCol="output", stopWords=None,\
      +                  caseSensitive=false)
      +        Sets params for this StopWordRemover.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setStopWords(self, value):
      +        """
      +        Specify the stopwords to be filtered.
      +        """
      +        self._paramMap[self.stopWords] = value
      +        return self
      +
      +    def getStopWords(self):
      +        """
      +        Get the stopwords.
      +        """
      +        return self.getOrDefault(self.stopWords)
      +
      +    def setCaseSensitive(self, value):
      +        """
      +        Set whether to do a case sensitive comparison over the stop words
      +        """
      +        self._paramMap[self.caseSensitive] = value
      +        return self
      +
      +    def getCaseSensitive(self):
      +        """
      +        Get whether to do a case sensitive comparison over the stop words.
      +        """
      +        return self.getOrDefault(self.caseSensitive)
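
Unlike most of the new transformers, StopWordsRemover ships here without a doctest, so a short usage sketch may help; the sample tokens are invented, and the expected result assumes the JVM-side `StopWords.English()` default list wired up in `__init__` above:

```python
# Illustrative usage of StopWordsRemover; output depends on the default English stop words.
from pyspark.ml.feature import StopWordsRemover

df = sqlContext.createDataFrame([(["the", "quick", "brown", "fox"],)], ["raw"])
remover = StopWordsRemover(inputCol="raw", outputCol="filtered")
print(remover.transform(df).head().filtered)  # expected: [u'quick', u'brown', u'fox']
```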
       
       
       @inherit_doc
       @ignore_unicode_prefix
       class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           A tokenizer that converts the input string to lowercase and then
           splits it by white spaces.
       
      @@ -718,6 +1248,8 @@ def setParams(self, inputCol=None, outputCol=None):
       @inherit_doc
       class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           A feature transformer that merges multiple columns into a vector column.
       
           >>> df = sqlContext.createDataFrame([(1, 0, 3)], ["a", "b", "c"])
      @@ -754,6 +1286,8 @@ def setParams(self, inputCols=None, outputCol=None):
       @inherit_doc
       class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           Class for indexing categorical feature columns in a dataset of [[Vector]].
       
           This has 2 usage modes:
      @@ -796,6 +1330,10 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
           >>> model = indexer.fit(df)
           >>> model.transform(df).head().indexed
           DenseVector([1.0, 0.0])
      +    >>> model.numFeatures
      +    2
      +    >>> model.categoryMaps
      +    {0: {0.0: 0, -1.0: 1}}
           >>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test
           DenseVector([0.0, 1.0])
           >>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"}
      @@ -853,20 +1391,142 @@ def _create_model(self, java_model):
       
       class VectorIndexerModel(JavaModel):
           """
      +    .. note:: Experimental
      +
           Model fitted by VectorIndexer.
           """
       
      +    @property
      +    def numFeatures(self):
      +        """
      +        Number of features, i.e., length of Vectors which this transforms.
      +        """
      +        return self._call_java("numFeatures")
      +
      +    @property
      +    def categoryMaps(self):
      +        """
      +        Feature value index.  Keys are categorical feature indices (column indices).
+        Values are maps from original feature values to 0-based category indices.
      +        If a feature is not in this map, it is treated as continuous.
      +        """
      +        return self._call_java("javaCategoryMaps")
      +
      +
      +@inherit_doc
      +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
      +    This class takes a feature vector and outputs a new feature vector with a subarray
      +    of the original features.
      +
      +    The subset of features can be specified with either indices (`setIndices()`)
      +    or names (`setNames()`).  At least one feature must be selected. Duplicate features
      +    are not allowed, so there can be no overlap between selected indices and names.
      +
      +    The output vector will order features with the selected indices first (in the order given),
      +    followed by the selected names (in the order given).
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> df = sqlContext.createDataFrame([
      +    ...     (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),),
      +    ...     (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),),
      +    ...     (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"])
      +    >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4])
      +    >>> vs.transform(df).head().sliced
      +    DenseVector([2.3, 1.0])
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    indices = Param(Params._dummy(), "indices", "An array of indices to select features from " +
      +                    "a vector column. There can be no overlap with names.")
      +    names = Param(Params._dummy(), "names", "An array of feature names to select features from " +
      +                  "a vector column. These names must be specified by ML " +
      +                  "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " +
      +                  "indices.")
      +
      +    @keyword_only
      +    def __init__(self, inputCol=None, outputCol=None, indices=None, names=None):
      +        """
      +        __init__(self, inputCol=None, outputCol=None, indices=None, names=None)
      +        """
      +        super(VectorSlicer, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid)
      +        self.indices = Param(self, "indices", "An array of indices to select features from " +
      +                             "a vector column. There can be no overlap with names.")
      +        self.names = Param(self, "names", "An array of feature names to select features from " +
      +                           "a vector column. These names must be specified by ML " +
      +                           "org.apache.spark.ml.attribute.Attribute. There can be no overlap " +
      +                           "with indices.")
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, inputCol=None, outputCol=None, indices=None, names=None):
      +        """
+        setParams(self, inputCol=None, outputCol=None, indices=None, names=None)
      +        Sets params for this VectorSlicer.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setIndices(self, value):
      +        """
      +        Sets the value of :py:attr:`indices`.
      +        """
      +        self._paramMap[self.indices] = value
      +        return self
      +
      +    def getIndices(self):
      +        """
      +        Gets the value of indices or its default value.
      +        """
      +        return self.getOrDefault(self.indices)
      +
      +    def setNames(self, value):
      +        """
      +        Sets the value of :py:attr:`names`.
      +        """
      +        self._paramMap[self.names] = value
      +        return self
      +
      +    def getNames(self):
      +        """
      +        Gets the value of names or its default value.
      +        """
      +        return self.getOrDefault(self.names)
      +
       
       @inherit_doc
       @ignore_unicode_prefix
       class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol):
           """
      +    .. note:: Experimental
      +
           Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further
           natural language processing or machine learning process.
       
           >>> sent = ("a b " * 100 + "a c " * 10).split(" ")
           >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
           >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc)
      +    >>> model.getVectors().show()
      +    +----+--------------------+
      +    |word|              vector|
      +    +----+--------------------+
      +    |   a|[-0.3511952459812...|
      +    |   b|[0.29077222943305...|
      +    |   c|[0.02315592765808...|
      +    +----+--------------------+
      +    ...
      +    >>> model.findSynonyms("a", 2).show()
      +    +----+-------------------+
      +    |word|         similarity|
      +    +----+-------------------+
      +    |   b|0.29255685145799626|
      +    |   c|-0.5414068302988307|
      +    +----+-------------------+
      +    ...
           >>> model.transform(doc).head().model
           DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
           """
      @@ -957,9 +1617,180 @@ def _create_model(self, java_model):
       
       class Word2VecModel(JavaModel):
           """
      +    .. note:: Experimental
      +
           Model fitted by Word2Vec.
           """
       
      +    def getVectors(self):
      +        """
      +        Returns the vector representation of the words as a dataframe
      +        with two fields, word and vector.
      +        """
      +        return self._call_java("getVectors")
      +
      +    def findSynonyms(self, word, num):
      +        """
      +        Find "num" number of words closest in similarity to "word".
      +        word can be a string or vector representation.
      +        Returns a dataframe with two fields word and similarity (which
      +        gives the cosine similarity).
      +        """
      +        if not isinstance(word, basestring):
      +            word = _convert_to_vector(word)
      +        return self._call_java("findSynonyms", word, num)
      +
      +
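
Since `findSynonyms` accepts either a word or a vector, the Word2Vec doctest above could also be queried with a document vector. A hedged continuation reusing that doctest's `model` and `doc` variables; the concrete output depends on the fitted model:

```python
# Sketch only: querying findSynonyms with a vector instead of a word string.
doc_vec = model.transform(doc).head().model   # averaged 5-dimensional document vector
model.findSynonyms(doc_vec, 2).show()         # nearest words by cosine similarity
```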
      +@inherit_doc
      +class PCA(JavaEstimator, HasInputCol, HasOutputCol):
      +    """
      +    .. note:: Experimental
      +
      +    PCA trains a model to project vectors to a low-dimensional space using PCA.
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
      +    ...     (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
      +    ...     (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
      +    >>> df = sqlContext.createDataFrame(data,["features"])
      +    >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features")
      +    >>> model = pca.fit(df)
      +    >>> model.transform(df).collect()[0].pca_features
      +    DenseVector([1.648..., -4.013...])
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    k = Param(Params._dummy(), "k", "the number of principal components")
      +
      +    @keyword_only
      +    def __init__(self, k=None, inputCol=None, outputCol=None):
      +        """
      +        __init__(self, k=None, inputCol=None, outputCol=None)
      +        """
      +        super(PCA, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid)
      +        self.k = Param(self, "k", "the number of principal components")
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, k=None, inputCol=None, outputCol=None):
      +        """
      +        setParams(self, k=None, inputCol=None, outputCol=None)
      +        Set params for this PCA.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setK(self, value):
      +        """
      +        Sets the value of :py:attr:`k`.
      +        """
      +        self._paramMap[self.k] = value
      +        return self
      +
      +    def getK(self):
      +        """
      +        Gets the value of k or its default value.
      +        """
      +        return self.getOrDefault(self.k)
      +
      +    def _create_model(self, java_model):
      +        return PCAModel(java_model)
      +
      +
      +class PCAModel(JavaModel):
      +    """
      +    .. note:: Experimental
      +
      +    Model fitted by PCA.
      +    """
      +
      +
      +@inherit_doc
      +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
      +    """
      +    .. note:: Experimental
      +
      +    Implements the transforms required for fitting a dataset against an
      +    R model formula. Currently we support a limited subset of the R
      +    operators, including '~', '+', '-', and '.'. Also see the R formula
      +    docs:
      +    http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
      +
      +    >>> df = sqlContext.createDataFrame([
      +    ...     (1.0, 1.0, "a"),
      +    ...     (0.0, 2.0, "b"),
      +    ...     (0.0, 0.0, "a")
      +    ... ], ["y", "x", "s"])
      +    >>> rf = RFormula(formula="y ~ x + s")
      +    >>> rf.fit(df).transform(df).show()
      +    +---+---+---+---------+-----+
      +    |  y|  x|  s| features|label|
      +    +---+---+---+---------+-----+
      +    |1.0|1.0|  a|[1.0,1.0]|  1.0|
      +    |0.0|2.0|  b|[2.0,0.0]|  0.0|
      +    |0.0|0.0|  a|[0.0,1.0]|  0.0|
      +    +---+---+---+---------+-----+
      +    ...
      +    >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show()
      +    +---+---+---+--------+-----+
      +    |  y|  x|  s|features|label|
      +    +---+---+---+--------+-----+
      +    |1.0|1.0|  a|   [1.0]|  1.0|
      +    |0.0|2.0|  b|   [2.0]|  0.0|
      +    |0.0|0.0|  a|   [0.0]|  0.0|
      +    +---+---+---+--------+-----+
      +    ...
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    formula = Param(Params._dummy(), "formula", "R model formula")
      +
      +    @keyword_only
      +    def __init__(self, formula=None, featuresCol="features", labelCol="label"):
      +        """
      +        __init__(self, formula=None, featuresCol="features", labelCol="label")
      +        """
      +        super(RFormula, self).__init__()
      +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
      +        self.formula = Param(self, "formula", "R model formula")
      +        kwargs = self.__init__._input_kwargs
      +        self.setParams(**kwargs)
      +
      +    @keyword_only
      +    def setParams(self, formula=None, featuresCol="features", labelCol="label"):
      +        """
      +        setParams(self, formula=None, featuresCol="features", labelCol="label")
      +        Sets params for RFormula.
      +        """
      +        kwargs = self.setParams._input_kwargs
      +        return self._set(**kwargs)
      +
      +    def setFormula(self, value):
      +        """
      +        Sets the value of :py:attr:`formula`.
      +        """
      +        self._paramMap[self.formula] = value
      +        return self
      +
      +    def getFormula(self):
      +        """
      +        Gets the value of :py:attr:`formula`.
      +        """
      +        return self.getOrDefault(self.formula)
      +
      +    def _create_model(self, java_model):
      +        return RFormulaModel(java_model)
      +
      +
      +class RFormulaModel(JavaModel):
      +    """
      +    .. note:: Experimental
      +
      +    Model fitted by :py:class:`RFormula`.
      +    """
      +
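
Because RFormula emits the default "features" and "label" columns, its output can feed an estimator directly. A hedged sketch reusing the `df` from the docstring above; the choice of LogisticRegression and the maxIter value are illustrative, not part of the patch:

```python
# Sketch: chaining RFormula into an estimator via a Pipeline.
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import RFormula

pipeline = Pipeline(stages=[RFormula(formula="y ~ x + s"),
                            LogisticRegression(maxIter=5)])
model = pipeline.fit(df)
model.transform(df).select("features", "label", "prediction").show()
```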
       
       if __name__ == "__main__":
           import doctest
      diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
      index 7845536161e0..eeeac49b2198 100644
      --- a/python/pyspark/ml/param/__init__.py
      +++ b/python/pyspark/ml/param/__init__.py
      @@ -60,14 +60,16 @@ class Params(Identifiable):
       
           __metaclass__ = ABCMeta
       
      -    #: internal param map for user-supplied values param map
      -    _paramMap = {}
      +    def __init__(self):
      +        super(Params, self).__init__()
      +        #: internal param map for user-supplied values param map
      +        self._paramMap = {}
       
      -    #: internal param map for default values
      -    _defaultParamMap = {}
      +        #: internal param map for default values
      +        self._defaultParamMap = {}
       
      -    #: value returned by :py:func:`params`
      -    _params = None
      +        #: value returned by :py:func:`params`
      +        self._params = None
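
The point of moving `_paramMap`, `_defaultParamMap`, and `_params` into `__init__` is that class-level mutable attributes are shared by every instance, so params set on one estimator could leak into another. A standalone illustration of the difference (not Spark code):

```python
# Class-level dicts are shared across instances; instance-level dicts are not.
class Shared(object):
    _paramMap = {}                # one dict owned by the class

class PerInstance(object):
    def __init__(self):
        self._paramMap = {}       # a fresh dict for each instance

a, b = Shared(), Shared()
a._paramMap["k"] = 1
print(b._paramMap)                # {'k': 1} -- leaked into the other instance

c, d = PerInstance(), PerInstance()
c._paramMap["k"] = 1
print(d._paramMap)                # {}       -- isolated, as the patch intends
```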
       
           @property
           def params(self):
      @@ -155,7 +157,7 @@ def getOrDefault(self, param):
               else:
                   return self._defaultParamMap[param]
       
      -    def extractParamMap(self, extra={}):
      +    def extractParamMap(self, extra=None):
               """
               Extracts the embedded default param values and user-supplied
               values, and then merges them with extra values from input into
      @@ -165,12 +167,14 @@ def extractParamMap(self, extra={}):
               :param extra: extra param values
               :return: merged param map
               """
      +        if extra is None:
      +            extra = dict()
               paramMap = self._defaultParamMap.copy()
               paramMap.update(self._paramMap)
               paramMap.update(extra)
               return paramMap
       
      -    def copy(self, extra={}):
      +    def copy(self, extra=None):
               """
               Creates a copy of this instance with the same uid and some
               extra params. The default implementation creates a
      @@ -181,6 +185,8 @@ def copy(self, extra={}):
               :param extra: Extra parameters to copy to the new instance
               :return: Copy of this instance
               """
      +        if extra is None:
      +            extra = dict()
               that = copy.copy(self)
               that._paramMap = self.extractParamMap(extra)
               return that
      @@ -233,7 +239,7 @@ def _setDefault(self, **kwargs):
                   self._defaultParamMap[getattr(self, param)] = value
               return self
       
      -    def _copyValues(self, to, extra={}):
      +    def _copyValues(self, to, extra=None):
               """
               Copies param values from this instance to another instance for
               params shared by them.
      @@ -241,6 +247,8 @@ def _copyValues(self, to, extra={}):
               :param extra: extra params to be copied
               :return: the target instance with param values copied
               """
      +        if extra is None:
      +            extra = dict()
               paramMap = self.extractParamMap(extra)
               for p in self.params:
                   if p in paramMap and to.hasParam(p.name):
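
The companion `extra={}` to `extra=None` changes in `extractParamMap`, `copy`, and `_copyValues` (and `params={}` in `evaluate` earlier) avoid Python's mutable-default-argument pitfall, where a single default object persists across calls. A standalone illustration:

```python
# Default argument objects are created once, at function definition time.
def bad(extra={}):
    extra["calls"] = extra.get("calls", 0) + 1
    return extra

def good(extra=None):
    if extra is None:
        extra = {}
    extra["calls"] = extra.get("calls", 0) + 1
    return extra

print(bad())   # {'calls': 1}
print(bad())   # {'calls': 2} -- the same default dict is reused across calls
print(good())  # {'calls': 1}
print(good())  # {'calls': 1} -- a fresh dict on every call
```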
      diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
      index 69efc424ec4e..5b39e5dd4e25 100644
      --- a/python/pyspark/ml/param/_shared_params_code_gen.py
      +++ b/python/pyspark/ml/param/_shared_params_code_gen.py
      @@ -121,7 +121,19 @@ def get$Name(self):
               ("checkpointInterval", "checkpoint interval (>= 1)", None),
               ("seed", "random seed", "hash(type(self).__name__)"),
               ("tol", "the convergence tolerance for iterative algorithms", None),
      -        ("stepSize", "Step size to be used for each iteration of optimization.", None)]
      +        ("stepSize", "Step size to be used for each iteration of optimization.", None),
      +        ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
      +         "out rows with bad values), or error (which will throw an errror). More options may be " +
      +         "added later.", None),
      +        ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
      +         "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"),
      +        ("fitIntercept", "whether to fit an intercept term.", "True"),
      +        ("standardization", "whether to standardize the training features before fitting the " +
      +         "model.", "True"),
      +        ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
      +         "predicting each class. Array must have length equal to the number of classes, with " +
      +         "values >= 0. The class with largest value p/t is predicted, where p is the original " +
      +         "probability of that class and t is the class' threshold.", None)]
           code = []
           for name, doc, defaultValueStr in shared:
               param_code = _gen_param_header(name, doc, defaultValueStr)
      diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
      index bc088e4c29e2..af1218128602 100644
      --- a/python/pyspark/ml/param/shared.py
      +++ b/python/pyspark/ml/param/shared.py
      @@ -432,6 +432,144 @@ def getStepSize(self):
               return self.getOrDefault(self.stepSize)
       
       
      +class HasHandleInvalid(Params):
      +    """
+    Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later..
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")
      +
      +    def __init__(self):
      +        super(HasHandleInvalid, self).__init__()
+        #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
+        self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")
      +
      +    def setHandleInvalid(self, value):
      +        """
      +        Sets the value of :py:attr:`handleInvalid`.
      +        """
      +        self._paramMap[self.handleInvalid] = value
      +        return self
      +
      +    def getHandleInvalid(self):
      +        """
      +        Gets the value of handleInvalid or its default value.
      +        """
      +        return self.getOrDefault(self.handleInvalid)
      +
      +
      +class HasElasticNetParam(Params):
      +    """
      +    Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty..
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
      +
      +    def __init__(self):
      +        super(HasElasticNetParam, self).__init__()
      +        #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
      +        self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
      +        self._setDefault(elasticNetParam=0.0)
      +
      +    def setElasticNetParam(self, value):
      +        """
      +        Sets the value of :py:attr:`elasticNetParam`.
      +        """
      +        self._paramMap[self.elasticNetParam] = value
      +        return self
      +
      +    def getElasticNetParam(self):
      +        """
      +        Gets the value of elasticNetParam or its default value.
      +        """
      +        return self.getOrDefault(self.elasticNetParam)
      +
      +
      +class HasFitIntercept(Params):
      +    """
      +    Mixin for param fitIntercept: whether to fit an intercept term..
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
      +
      +    def __init__(self):
      +        super(HasFitIntercept, self).__init__()
      +        #: param for whether to fit an intercept term.
      +        self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
      +        self._setDefault(fitIntercept=True)
      +
      +    def setFitIntercept(self, value):
      +        """
      +        Sets the value of :py:attr:`fitIntercept`.
      +        """
      +        self._paramMap[self.fitIntercept] = value
      +        return self
      +
      +    def getFitIntercept(self):
      +        """
      +        Gets the value of fitIntercept or its default value.
      +        """
      +        return self.getOrDefault(self.fitIntercept)
      +
      +
      +class HasStandardization(Params):
      +    """
      +    Mixin for param standardization: whether to standardize the training features before fitting the model..
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.")
      +
      +    def __init__(self):
      +        super(HasStandardization, self).__init__()
      +        #: param for whether to standardize the training features before fitting the model.
      +        self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.")
      +        self._setDefault(standardization=True)
      +
      +    def setStandardization(self, value):
      +        """
      +        Sets the value of :py:attr:`standardization`.
      +        """
      +        self._paramMap[self.standardization] = value
      +        return self
      +
      +    def getStandardization(self):
      +        """
      +        Gets the value of standardization or its default value.
      +        """
      +        return self.getOrDefault(self.standardization)
      +
      +
      +class HasThresholds(Params):
      +    """
      +    Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold..
      +    """
      +
      +    # a placeholder to make it appear in the generated doc
      +    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
      +
      +    def __init__(self):
      +        super(HasThresholds, self).__init__()
      +        #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
      +        self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
      +
      +    def setThresholds(self, value):
      +        """
      +        Sets the value of :py:attr:`thresholds`.
      +        """
      +        self._paramMap[self.thresholds] = value
      +        return self
      +
      +    def getThresholds(self):
      +        """
      +        Gets the value of thresholds or its default value.
      +        """
      +        return self.getOrDefault(self.thresholds)
      +
      +
       class DecisionTreeParams(Params):
           """
           Mixin for Decision Tree parameters.
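Taken together, these generated mixins are what the estimators changed later in this patch (for example `LinearRegression` in the regression.py hunk) inherit from. A rough usage sketch, assuming a live pyspark session with a running SparkContext:

```python
# Hedged sketch: the generated mixins give LinearRegression chained
# setters/getters and defaults without per-class boilerplate.
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(maxIter=10)
lr.setElasticNetParam(0.5).setFitIntercept(False)   # setters return self
print(lr.getElasticNetParam())    # 0.5 (explicitly set)
print(lr.getStandardization())    # True (default from HasStandardization)
```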
      diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
      index a563024b2cdc..13cf2b0f7bbd 100644
      --- a/python/pyspark/ml/pipeline.py
      +++ b/python/pyspark/ml/pipeline.py
      @@ -42,7 +42,7 @@ def _fit(self, dataset):
               """
               raise NotImplementedError()
       
      -    def fit(self, dataset, params={}):
      +    def fit(self, dataset, params=None):
               """
               Fits a model to the input dataset with optional parameters.
       
      @@ -54,6 +54,8 @@ def fit(self, dataset, params={}):
                              list of models.
               :returns: fitted model(s)
               """
      +        if params is None:
      +            params = dict()
               if isinstance(params, (list, tuple)):
                   return [self.fit(dataset, paramMap) for paramMap in params]
               elif isinstance(params, dict):
      @@ -86,7 +88,7 @@ def _transform(self, dataset):
               """
               raise NotImplementedError()
       
      -    def transform(self, dataset, params={}):
      +    def transform(self, dataset, params=None):
               """
               Transforms the input dataset with optional parameters.
       
      @@ -96,6 +98,8 @@ def transform(self, dataset, params={}):
                              params.
               :returns: transformed dataset
               """
      +        if params is None:
      +            params = dict()
               if isinstance(params, dict):
                   if params:
                       return self.copy(params,)._transform(dataset)
      @@ -135,10 +139,12 @@ class Pipeline(Estimator):
           """
       
           @keyword_only
      -    def __init__(self, stages=[]):
      +    def __init__(self, stages=None):
               """
      -        __init__(self, stages=[])
      +        __init__(self, stages=None)
               """
      +        if stages is None:
      +            stages = []
               super(Pipeline, self).__init__()
               #: Param for pipeline stages.
               self.stages = Param(self, "stages", "pipeline stages")
      @@ -162,11 +168,13 @@ def getStages(self):
                   return self._paramMap[self.stages]
       
           @keyword_only
      -    def setParams(self, stages=[]):
      +    def setParams(self, stages=None):
               """
      -        setParams(self, stages=[])
      +        setParams(self, stages=None)
               Sets params for Pipeline.
               """
      +        if stages is None:
      +            stages = []
               kwargs = self.setParams._input_kwargs
               return self._set(**kwargs)
       
      @@ -195,7 +203,9 @@ def _fit(self, dataset):
                       transformers.append(stage)
               return PipelineModel(transformers)
       
      -    def copy(self, extra={}):
      +    def copy(self, extra=None):
      +        if extra is None:
      +            extra = dict()
               that = Params.copy(self, extra)
               stages = [stage.copy(extra) for stage in that.getStages()]
               return that.setStages(stages)
      @@ -216,6 +226,8 @@ def _transform(self, dataset):
                   dataset = t.transform(dataset)
               return dataset
       
      -    def copy(self, extra={}):
      +    def copy(self, extra=None):
      +        if extra is None:
      +            extra = dict()
               stages = [stage.copy(extra) for stage in self.stages]
               return PipelineModel(stages)
      diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
      index b139e27372d8..a9503608b7f2 100644
      --- a/python/pyspark/ml/regression.py
      +++ b/python/pyspark/ml/regression.py
      @@ -28,7 +28,8 @@
       
       @inherit_doc
       class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
      -                       HasRegParam, HasTol):
      +                       HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
      +                       HasStandardization):
           """
           Linear regression.
       
      @@ -63,38 +64,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
           TypeError: Method setParams forces keyword arguments.
           """
       
      -    # a placeholder to make it appear in the generated doc
      -    elasticNetParam = \
      -        Param(Params._dummy(), "elasticNetParam",
      -              "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
      -              "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
      -
           @keyword_only
           def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      -                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6):
      +                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
      +                 standardization=True):
               """
               __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      -                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
      +                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
      +                 standardization=True)
               """
               super(LinearRegression, self).__init__()
               self._java_obj = self._new_java_obj(
                   "org.apache.spark.ml.regression.LinearRegression", self.uid)
      -        #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
      -        #  is an L2 penalty. For alpha = 1, it is an L1 penalty.
      -        self.elasticNetParam = \
      -            Param(self, "elasticNetParam",
      -                  "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " +
      -                  "is an L2 penalty. For alpha = 1, it is an L1 penalty.")
      -        self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
      +        self._setDefault(maxIter=100, regParam=0.0, tol=1e-6)
               kwargs = self.__init__._input_kwargs
               self.setParams(**kwargs)
       
           @keyword_only
           def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
      -                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6):
      +                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
      +                  standardization=True):
               """
               setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
      -                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
      +                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
      +                  standardization=True)
               Sets params for linear regression.
               """
               kwargs = self.setParams._input_kwargs
      @@ -103,19 +96,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
           def _create_model(self, java_model):
               return LinearRegressionModel(java_model)
       
      -    def setElasticNetParam(self, value):
      -        """
      -        Sets the value of :py:attr:`elasticNetParam`.
      -        """
      -        self._paramMap[self.elasticNetParam] = value
      -        return self
      -
      -    def getElasticNetParam(self):
      -        """
      -        Gets the value of elasticNetParam or its default value.
      -        """
      -        return self.getOrDefault(self.elasticNetParam)
      -
       
       class LinearRegressionModel(JavaModel):
           """
      @@ -172,6 +152,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
           ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
           >>> dt = DecisionTreeRegressor(maxDepth=2)
           >>> model = dt.fit(df)
      +    >>> model.depth
      +    1
      +    >>> model.numNodes
      +    3
           >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
           >>> model.transform(test0).head().prediction
           0.0
      @@ -239,7 +223,37 @@ def getImpurity(self):
               return self.getOrDefault(self.impurity)
       
       
      -class DecisionTreeRegressionModel(JavaModel):
      +@inherit_doc
      +class DecisionTreeModel(JavaModel):
      +
      +    @property
      +    def numNodes(self):
      +        """Return number of nodes of the decision tree."""
      +        return self._call_java("numNodes")
      +
      +    @property
      +    def depth(self):
      +        """Return depth of the decision tree."""
      +        return self._call_java("depth")
      +
      +    def __repr__(self):
      +        return self._call_java("toString")
      +
      +
      +@inherit_doc
      +class TreeEnsembleModels(JavaModel):
      +
      +    @property
      +    def treeWeights(self):
      +        """Return the weights for each tree"""
      +        return list(self._call_java("javaTreeWeights"))
      +
      +    def __repr__(self):
      +        return self._call_java("toString")
      +
      +
      +@inherit_doc
      +class DecisionTreeRegressionModel(DecisionTreeModel):
           """
           Model fitted by DecisionTreeRegressor.
           """
      @@ -253,12 +267,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
           learning algorithm for regression.
           It supports both continuous and categorical features.
       
      +    >>> from numpy import allclose
           >>> from pyspark.mllib.linalg import Vectors
           >>> df = sqlContext.createDataFrame([
           ...     (1.0, Vectors.dense(1.0)),
           ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
           >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
           >>> model = rf.fit(df)
      +    >>> allclose(model.treeWeights, [1.0, 1.0])
      +    True
           >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
           >>> model.transform(test0).head().prediction
           0.0
      @@ -389,7 +406,7 @@ def getFeatureSubsetStrategy(self):
               return self.getOrDefault(self.featureSubsetStrategy)
       
       
      -class RandomForestRegressionModel(JavaModel):
      +class RandomForestRegressionModel(TreeEnsembleModels):
           """
           Model fitted by RandomForestRegressor.
           """
      @@ -403,12 +420,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
           learning algorithm for regression.
           It supports both continuous and categorical features.
       
      +    >>> from numpy import allclose
           >>> from pyspark.mllib.linalg import Vectors
           >>> df = sqlContext.createDataFrame([
           ...     (1.0, Vectors.dense(1.0)),
           ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
           >>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
           >>> model = gbt.fit(df)
      +    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
      +    True
           >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
           >>> model.transform(test0).head().prediction
           0.0
      @@ -518,7 +538,7 @@ def getStepSize(self):
               return self.getOrDefault(self.stepSize)
       
       
      -class GBTRegressionModel(JavaModel):
      +class GBTRegressionModel(TreeEnsembleModels):
           """
           Model fitted by GBTRegressor.
           """
      diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
      index 6adbf166f34a..b892318f50bd 100644
      --- a/python/pyspark/ml/tests.py
      +++ b/python/pyspark/ml/tests.py
      @@ -31,12 +31,15 @@
           import unittest
       
       from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
      -from pyspark.sql import DataFrame, SQLContext
      +from pyspark.sql import DataFrame, SQLContext, Row
      +from pyspark.sql.functions import rand
      +from pyspark.ml.evaluation import RegressionEvaluator
       from pyspark.ml.param import Param, Params
       from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
       from pyspark.ml.util import keyword_only
       from pyspark.ml import Estimator, Model, Pipeline, Transformer
       from pyspark.ml.feature import *
      +from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
       from pyspark.mllib.linalg import DenseVector
       
       
      @@ -252,6 +255,117 @@ def test_idf(self):
               output = idf0m.transform(dataset)
               self.assertIsNotNone(output.head().idf)
       
      +    def test_ngram(self):
      +        sqlContext = SQLContext(self.sc)
      +        dataset = sqlContext.createDataFrame([
      +            Row(input=["a", "b", "c", "d", "e"])])
      +        ngram0 = NGram(n=4, inputCol="input", outputCol="output")
      +        self.assertEqual(ngram0.getN(), 4)
      +        self.assertEqual(ngram0.getInputCol(), "input")
      +        self.assertEqual(ngram0.getOutputCol(), "output")
      +        transformedDF = ngram0.transform(dataset)
      +        self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
      +
      +    def test_stopwordsremover(self):
      +        sqlContext = SQLContext(self.sc)
      +        dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
      +        stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
      +        # Default
      +        self.assertEquals(stopWordRemover.getInputCol(), "input")
      +        transformedDF = stopWordRemover.transform(dataset)
      +        self.assertEquals(transformedDF.head().output, ["panda"])
      +        # Custom
      +        stopwords = ["panda"]
      +        stopWordRemover.setStopWords(stopwords)
      +        self.assertEquals(stopWordRemover.getInputCol(), "input")
      +        self.assertEquals(stopWordRemover.getStopWords(), stopwords)
      +        transformedDF = stopWordRemover.transform(dataset)
      +        self.assertEquals(transformedDF.head().output, ["a"])
      +
      +
      +class HasInducedError(Params):
      +
      +    def __init__(self):
      +        super(HasInducedError, self).__init__()
      +        self.inducedError = Param(self, "inducedError",
      +                                  "Uniformly-distributed error added to feature")
      +
      +    def getInducedError(self):
      +        return self.getOrDefault(self.inducedError)
      +
      +
      +class InducedErrorModel(Model, HasInducedError):
      +
      +    def __init__(self):
      +        super(InducedErrorModel, self).__init__()
      +
      +    def _transform(self, dataset):
      +        return dataset.withColumn("prediction",
      +                                  dataset.feature + (rand(0) * self.getInducedError()))
      +
      +
      +class InducedErrorEstimator(Estimator, HasInducedError):
      +
      +    def __init__(self, inducedError=1.0):
      +        super(InducedErrorEstimator, self).__init__()
      +        self._set(inducedError=inducedError)
      +
      +    def _fit(self, dataset):
      +        model = InducedErrorModel()
      +        self._copyValues(model)
      +        return model
      +
      +
      +class CrossValidatorTests(PySparkTestCase):
      +
      +    def test_fit_minimize_metric(self):
      +        sqlContext = SQLContext(self.sc)
      +        dataset = sqlContext.createDataFrame([
      +            (10, 10.0),
      +            (50, 50.0),
      +            (100, 100.0),
      +            (500, 500.0)] * 10,
      +            ["feature", "label"])
      +
      +        iee = InducedErrorEstimator()
      +        evaluator = RegressionEvaluator(metricName="rmse")
      +
      +        grid = (ParamGridBuilder()
      +                .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
      +                .build())
      +        cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
      +        cvModel = cv.fit(dataset)
      +        bestModel = cvModel.bestModel
      +        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
      +
      +        self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
      +                         "Best model should have zero induced error")
      +        self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
      +
      +    def test_fit_maximize_metric(self):
      +        sqlContext = SQLContext(self.sc)
      +        dataset = sqlContext.createDataFrame([
      +            (10, 10.0),
      +            (50, 50.0),
      +            (100, 100.0),
      +            (500, 500.0)] * 10,
      +            ["feature", "label"])
      +
      +        iee = InducedErrorEstimator()
      +        evaluator = RegressionEvaluator(metricName="r2")
      +
      +        grid = (ParamGridBuilder()
      +                .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
      +                .build())
      +        cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
      +        cvModel = cv.fit(dataset)
      +        bestModel = cvModel.bestModel
      +        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
      +
      +        self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
      +                         "Best model should have zero induced error")
      +        self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
      +
       
       if __name__ == "__main__":
           unittest.main()
      diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
      index 0bf988fd72f1..cae778869e9c 100644
      --- a/python/pyspark/ml/tuning.py
      +++ b/python/pyspark/ml/tuning.py
      @@ -223,11 +223,17 @@ def _fit(self, dataset):
                       # TODO: duplicate evaluator to take extra params from input
                       metric = eva.evaluate(model.transform(validation, epm[j]))
                       metrics[j] += metric
      -        bestIndex = np.argmax(metrics)
      +
      +        if eva.isLargerBetter():
      +            bestIndex = np.argmax(metrics)
      +        else:
      +            bestIndex = np.argmin(metrics)
               bestModel = est.fit(dataset, epm[bestIndex])
               return CrossValidatorModel(bestModel)
       
      -    def copy(self, extra={}):
      +    def copy(self, extra=None):
      +        if extra is None:
      +            extra = dict()
               newCV = Params.copy(self, extra)
               if self.isSet(self.estimator):
                   newCV.setEstimator(self.getEstimator().copy(extra))
      @@ -250,7 +256,7 @@ def __init__(self, bestModel):
           def _transform(self, dataset):
               return self.bestModel.transform(dataset)
       
      -    def copy(self, extra={}):
      +    def copy(self, extra=None):
               """
               Creates a copy of this instance with a randomly generated uid
               and some extra params. This copies the underlying bestModel,
      @@ -259,6 +265,8 @@ def copy(self, extra={}):
               :param extra: Extra parameters to copy to the new instance
               :return: Copy of this instance
               """
      +        if extra is None:
      +            extra = dict()
               return CrossValidatorModel(self.bestModel.copy(extra))
       
       
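The core of this change is the metric direction: previously the best parameter map was always `argmax(metrics)`, which is wrong for error metrics such as RMSE. A self-contained sketch of the new selection rule (the metric values are made up):

```python
# Standalone sketch: pick argmax when larger is better (e.g. r2),
# argmin otherwise (e.g. rmse), mirroring the logic added above.
import numpy as np

metrics = np.array([1.2, 0.3, 4.5])   # hypothetical averaged CV metrics

def best_index(metrics, larger_is_better):
    return int(np.argmax(metrics)) if larger_is_better else int(np.argmin(metrics))

print(best_index(metrics, larger_is_better=False))  # 1 -> lowest RMSE wins
print(best_index(metrics, larger_is_better=True))   # 2 -> highest r2 wins
```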
      diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
      index 7b0893e2cdad..8218c7c5f801 100644
      --- a/python/pyspark/ml/wrapper.py
      +++ b/python/pyspark/ml/wrapper.py
      @@ -136,7 +136,8 @@ def _fit(self, dataset):
       class JavaTransformer(Transformer, JavaWrapper):
           """
           Base class for :py:class:`Transformer`s that wrap Java/Scala
      -    implementations.
      +    implementations. Subclasses should ensure they have the transformer Java object
      +    available as _java_obj.
           """
       
           __metaclass__ = ABCMeta
      @@ -166,7 +167,7 @@ def __init__(self, java_model):
               self._java_obj = java_model
               self.uid = java_model.uid()
       
      -    def copy(self, extra={}):
      +    def copy(self, extra=None):
               """
               Creates a copy of this instance with the same uid and some
               extra params. This implementation first calls Params.copy and
      @@ -175,6 +176,8 @@ def copy(self, extra={}):
               :param extra: Extra parameters to copy to the new instance
               :return: Copy of this instance
               """
      +        if extra is None:
      +            extra = dict()
               that = super(JavaModel, self).copy(extra)
               that._java_obj = self._java_obj.copy(self._empty_java_param_map())
               that._transfer_params_to_java()
      diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
      index 42e41397bf4b..cb4ee8367808 100644
      --- a/python/pyspark/mllib/classification.py
      +++ b/python/pyspark/mllib/classification.py
      @@ -21,14 +21,18 @@
       from numpy import array
       
       from pyspark import RDD
      +from pyspark.streaming import DStream
       from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
       from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
      -from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
      +from pyspark.mllib.regression import (
      +    LabeledPoint, LinearModel, _regression_train_wrapper,
      +    StreamingLinearAlgorithm)
       from pyspark.mllib.util import Saveable, Loader, inherit_doc
       
       
       __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
      -           'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
      +           'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes',
      +           'StreamingLogisticRegressionWithSGD']
       
       
       class LinearClassificationModel(LinearModel):
      @@ -135,8 +139,9 @@ class LogisticRegressionModel(LinearClassificationModel):
           1
           >>> sameModel.predict(SparseVector(2, {0: 1.0}))
           0
      +    >>> from shutil import rmtree
           >>> try:
      -    ...    os.removedirs(path)
      +    ...    rmtree(path)
           ... except:
           ...    pass
           >>> multi_class_data = [
      @@ -236,7 +241,7 @@ class LogisticRegressionWithSGD(object):
           @classmethod
           def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
                     initialWeights=None, regParam=0.01, regType="l2", intercept=False,
      -              validateData=True):
      +              validateData=True, convergenceTol=0.001):
               """
               Train a logistic regression model on the given data.
       
      @@ -269,11 +274,13 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
               :param validateData:      Boolean parameter which indicates if
                                         the algorithm should validate data
                                         before training. (default: True)
      +        :param convergenceTol:    A condition which decides iteration termination.
      +                                  (default: 0.001)
               """
               def train(rdd, i):
                   return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
                                        float(step), float(miniBatchFraction), i, float(regParam), regType,
      -                                 bool(intercept), bool(validateData))
      +                                 bool(intercept), bool(validateData), float(convergenceTol))
       
               return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
       
      @@ -387,8 +394,9 @@ class SVMModel(LinearClassificationModel):
           1
           >>> sameModel.predict(SparseVector(2, {0: -1.0}))
           0
      +    >>> from shutil import rmtree
           >>> try:
      -    ...    os.removedirs(path)
      +    ...    rmtree(path)
           ... except:
           ...    pass
           """
      @@ -433,7 +441,7 @@ class SVMWithSGD(object):
           @classmethod
           def train(cls, data, iterations=100, step=1.0, regParam=0.01,
                     miniBatchFraction=1.0, initialWeights=None, regType="l2",
      -              intercept=False, validateData=True):
      +              intercept=False, validateData=True, convergenceTol=0.001):
               """
               Train a support vector machine on the given data.
       
      @@ -466,11 +474,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
               :param validateData:      Boolean parameter which indicates if
                                         the algorithm should validate data
                                         before training. (default: True)
      +        :param convergenceTol:    A condition which decides iteration termination.
      +                                  (default: 0.001)
               """
               def train(rdd, i):
                   return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
                                        float(regParam), float(miniBatchFraction), i, regType,
      -                                 bool(intercept), bool(validateData))
      +                                 bool(intercept), bool(validateData), float(convergenceTol))
       
               return _regression_train_wrapper(train, SVMModel, data, initialWeights)
       
      @@ -515,8 +525,9 @@ class NaiveBayesModel(Saveable, Loader):
           >>> sameModel = NaiveBayesModel.load(sc, path)
           >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
           True
      +    >>> from shutil import rmtree
           >>> try:
      -    ...     os.removedirs(path)
      +    ...     rmtree(path)
           ... except OSError:
           ...     pass
           """
      @@ -576,10 +587,63 @@ def train(cls, data, lambda_=1.0):
               first = data.first()
               if not isinstance(first, LabeledPoint):
                   raise ValueError("`data` should be an RDD of LabeledPoint")
      -        labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_)
      +        labels, pi, theta = callMLlibFunc("trainNaiveBayesModel", data, lambda_)
               return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
       
       
      +@inherit_doc
      +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm):
      +    """
      +    Run LogisticRegression with SGD on a batch of data.
      +
      +    The weights obtained at the end of training a stream are used as initial
      +    weights for the next batch.
      +
      +    :param stepSize: Step size for each iteration of gradient descent.
      +    :param numIterations: Number of iterations run for each batch of data.
      +    :param miniBatchFraction: Fraction of data on which SGD is run for each
      +                              iteration.
      +    :param regParam: L2 Regularization parameter.
      +    :param convergenceTol: A condition which decides iteration termination.
      +    """
      +    def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01,
      +                 convergenceTol=0.001):
      +        self.stepSize = stepSize
      +        self.numIterations = numIterations
      +        self.regParam = regParam
      +        self.miniBatchFraction = miniBatchFraction
      +        self.convergenceTol = convergenceTol
      +        self._model = None
      +        super(StreamingLogisticRegressionWithSGD, self).__init__(
      +            model=self._model)
      +
      +    def setInitialWeights(self, initialWeights):
      +        """
      +        Set the initial value of weights.
      +
      +        This must be set before running trainOn and predictOn.
      +        """
      +        initialWeights = _convert_to_vector(initialWeights)
      +
      +        # LogisticRegressionWithSGD does only binary classification.
      +        self._model = LogisticRegressionModel(
      +            initialWeights, 0, initialWeights.size, 2)
      +        return self
      +
      +    def trainOn(self, dstream):
      +        """Train the model on the incoming dstream."""
      +        self._validate(dstream)
      +
      +        def update(rdd):
      +            # LogisticRegressionWithSGD.train raises an error for an empty RDD.
      +            if not rdd.isEmpty():
      +                self._model = LogisticRegressionWithSGD.train(
      +                    rdd, self.numIterations, self.stepSize,
      +                    self.miniBatchFraction, self._model.weights)
      +
      +        dstream.foreachRDD(update)
      +
      +
       def _test():
           import doctest
           from pyspark import SparkContext
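The new `StreamingLogisticRegressionWithSGD` has no usage example in this hunk, so here is a hedged sketch. It assumes an existing SparkContext `sc`, and uses `queueStream` only to fake a stream of training batches; any DStream of LabeledPoints would do.

```python
# Hedged usage sketch of the streaming logistic regression added above.
from pyspark.streaming import StreamingContext
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

ssc = StreamingContext(sc, batchDuration=1)
batches = [sc.parallelize([LabeledPoint(1.0, [1.0]), LabeledPoint(0.0, [-1.0])])]
trainingStream = ssc.queueStream(batches)

model = StreamingLogisticRegressionWithSGD(stepSize=0.1, numIterations=25)
model.setInitialWeights([0.0])    # must be set before trainOn/predictOn
model.trainOn(trainingStream)     # weights carry over between batches

ssc.start()
ssc.awaitTermination(5)           # let a batch or two run
ssc.stop(stopSparkContext=False)
```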
      diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
      index c38229864d3b..900ade248c38 100644
      --- a/python/pyspark/mllib/clustering.py
      +++ b/python/pyspark/mllib/clustering.py
      @@ -20,21 +20,27 @@
       
       if sys.version > '3':
           xrange = range
      +    basestring = str
       
       from math import exp, log
       
       from numpy import array, random, tile
       
      +from collections import namedtuple
      +
       from pyspark import SparkContext
       from pyspark.rdd import RDD, ignore_unicode_prefix
      -from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
      +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py
       from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
      +from pyspark.mllib.regression import LabeledPoint
       from pyspark.mllib.stat.distribution import MultivariateGaussian
      -from pyspark.mllib.util import Saveable, Loader, inherit_doc
      +from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable
       from pyspark.streaming import DStream
       
       __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
      -           'StreamingKMeans', 'StreamingKMeansModel']
      +           'PowerIterationClusteringModel', 'PowerIterationClustering',
      +           'StreamingKMeans', 'StreamingKMeansModel',
      +           'LDA', 'LDAModel']
       
       
       @inherit_doc
      @@ -79,8 +85,9 @@ class KMeansModel(Saveable, Loader):
           >>> sameModel = KMeansModel.load(sc, path)
           >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
           True
      +    >>> from shutil import rmtree
           >>> try:
      -    ...     os.removedirs(path)
      +    ...     rmtree(path)
           ... except OSError:
           ...     pass
           """
      @@ -145,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"
               return KMeansModel([c.toArray() for c in centers])
       
       
      -class GaussianMixtureModel(object):
      +@inherit_doc
      +class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
       
      -    """A clustering model derived from the Gaussian Mixture Model method.
      +    """
      +    .. note:: Experimental
      +
      +    A clustering model derived from the Gaussian Mixture Model method.
       
           >>> from pyspark.mllib.linalg import Vectors, DenseMatrix
      +    >>> from numpy.testing import assert_equal
      +    >>> from shutil import rmtree
      +    >>> import os, tempfile
      +
           >>> clusterdata_1 =  sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
           ...                                         0.9,0.8,0.75,0.935,
           ...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
      @@ -162,6 +177,25 @@ class GaussianMixtureModel(object):
           True
           >>> labels[4]==labels[5]
           True
      +
      +    >>> path = tempfile.mkdtemp()
      +    >>> model.save(sc, path)
      +    >>> sameModel = GaussianMixtureModel.load(sc, path)
      +    >>> assert_equal(model.weights, sameModel.weights)
      +    >>> mus, sigmas = list(
      +    ...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
      +    >>> sameMus, sameSigmas = list(
      +    ...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
      +    >>> mus == sameMus
      +    True
      +    >>> sigmas == sameSigmas
      +    True
      +    >>> from shutil import rmtree
      +    >>> try:
      +    ...     rmtree(path)
      +    ... except OSError:
      +    ...     pass
      +
           >>> data =  array([-5.1971, -2.5359, -3.8220,
           ...                -5.2211, -5.0602,  4.7118,
           ...                 6.8989, 3.4592,  4.6322,
      @@ -175,25 +209,15 @@ class GaussianMixtureModel(object):
           True
           >>> labels[3]==labels[4]
           True
      -    >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
      -    >>> im = GaussianMixtureModel([0.5, 0.5],
      -    ...      [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])),
      -    ...      MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))])
      -    >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
           """
       
      -    def __init__(self, weights, gaussians):
      -        self._weights = weights
      -        self._gaussians = gaussians
      -        self._k = len(self._weights)
      -
           @property
           def weights(self):
               """
               Weights for each Gaussian distribution in the mixture, where weights[i] is
               the weight for Gaussian i, and weights.sum == 1.
               """
      -        return self._weights
      +        return array(self.call("weights"))
       
           @property
           def gaussians(self):
      @@ -201,12 +225,14 @@ def gaussians(self):
               Array of MultivariateGaussian where gaussians[i] represents
               the Multivariate Gaussian (Normal) Distribution for Gaussian i.
               """
      -        return self._gaussians
      +        return [
      +            MultivariateGaussian(gaussian[0], gaussian[1])
      +            for gaussian in zip(*self.call("gaussians"))]
       
           @property
           def k(self):
               """Number of gaussians in mixture."""
      -        return self._k
      +        return len(self.weights)
       
           def predict(self, x):
               """
      @@ -231,17 +257,30 @@ def predictSoft(self, x):
               :return:     membership_matrix. RDD of array of double values.
               """
               if isinstance(x, RDD):
      -            means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians])
      +            means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
                   membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
      -                                              _convert_to_vector(self._weights), means, sigmas)
      +                                              _convert_to_vector(self.weights), means, sigmas)
                   return membership_matrix.map(lambda x: pyarray.array('d', x))
               else:
                   raise TypeError("x should be represented by an RDD, "
                                   "but got %s." % type(x))
       
      +    @classmethod
      +    def load(cls, sc, path):
      +        """Load the GaussianMixtureModel from disk.
      +
      +        :param sc: SparkContext
      +        :param path: str, path to where the model is stored.
      +        """
      +        model = cls._load_java(sc, path)
      +        wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
      +        return cls(wrapper)
      +
       
       class GaussianMixture(object):
           """
      +    .. note:: Experimental
      +
           Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm.
       
           :param data:            RDD of data points
      @@ -264,11 +303,106 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
                   initialModelWeights = initialModel.weights
                   initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
                   initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
      -        weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k,
      -                                          convergenceTol, maxIterations, seed, initialModelWeights,
      -                                          initialModelMu, initialModelSigma)
      -        mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
      -        return GaussianMixtureModel(weight, mvg_obj)
      +        java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
      +                                   k, convergenceTol, maxIterations, seed,
      +                                   initialModelWeights, initialModelMu, initialModelSigma)
      +        return GaussianMixtureModel(java_model)
      +
      +
      +class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
      +
      +    """
      +    .. note:: Experimental
      +
      +    Model produced by [[PowerIterationClustering]].
      +
      +    >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0),
      +    ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0),
      +    ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0),
      +    ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)]
      +    >>> rdd = sc.parallelize(data, 2)
      +    >>> model = PowerIterationClustering.train(rdd, 2, 100)
      +    >>> model.k
      +    2
      +    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
      +    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
      +    True
      +    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
      +    True
      +    >>> import os, tempfile
      +    >>> path = tempfile.mkdtemp()
      +    >>> model.save(sc, path)
      +    >>> sameModel = PowerIterationClusteringModel.load(sc, path)
      +    >>> sameModel.k
      +    2
      +    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
      +    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
      +    True
      +    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
      +    True
      +    >>> from shutil import rmtree
      +    >>> try:
      +    ...     rmtree(path)
      +    ... except OSError:
      +    ...     pass
      +    """
      +
      +    @property
      +    def k(self):
      +        """
      +        Returns the number of clusters.
      +        """
      +        return self.call("k")
      +
      +    def assignments(self):
      +        """
      +        Returns the cluster assignments of this model.
      +        """
      +        return self.call("getAssignments").map(
      +            lambda x: (PowerIterationClustering.Assignment(*x)))
      +
      +    @classmethod
      +    def load(cls, sc, path):
      +        model = cls._load_java(sc, path)
      +        wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model)
      +        return PowerIterationClusteringModel(wrapper)
      +
      +
      +class PowerIterationClustering(object):
      +    """
      +    .. note:: Experimental
      +
      +    Power Iteration Clustering (PIC), a scalable graph clustering algorithm
      +    developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]].
      +    From the abstract: PIC finds a very low-dimensional embedding of a
      +    dataset using truncated power iteration on a normalized pair-wise
      +    similarity matrix of the data.
      +    """
      +
      +    @classmethod
      +    def train(cls, rdd, k, maxIterations=100, initMode="random"):
      +        """
      +        :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the
      +            affinity matrix, which is the matrix A in the PIC paper.
      +            The similarity s,,ij,, must be nonnegative.
      +            This is a symmetric matrix and hence s,,ij,, = s,,ji,,.
      +            For any (i, j) with nonzero similarity, there should be
      +            either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input.
      +            Tuples with i = j are ignored, because we assume
      +            s,,ij,, = 0.0.
      +        :param k: Number of clusters.
      +        :param maxIterations: Maximum number of iterations of the
      +            PIC algorithm.
      +        :param initMode: Initialization mode.
      +        """
      +        model = callMLlibFunc("trainPowerIterationClusteringModel",
      +                              rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode)
      +        return PowerIterationClusteringModel(model)
      +
      +    class Assignment(namedtuple("Assignment", ["id", "cluster"])):
      +        """
      +        Represents an (id, cluster) tuple.
      +        """
       
       
       class StreamingKMeansModel(KMeansModel):
      @@ -463,9 +597,112 @@ def predictOnValues(self, dstream):
               return dstream.mapValues(lambda x: self._model.predict(x))
       
       
      +class LDAModel(JavaModelWrapper):
      +
      +    """ A clustering model derived from the LDA method.
      +
      +    Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
      +    Terminology
      +    - "word" = "term": an element of the vocabulary
      +    - "token": instance of a term appearing in a document
      +    - "topic": multinomial distribution over words representing some concept
      +    References:
      +    - Original LDA paper (journal version):
      +    Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.
      +
      +    >>> from pyspark.mllib.linalg import Vectors
      +    >>> from numpy.testing import assert_almost_equal, assert_equal
      +    >>> data = [
      +    ...     [1, Vectors.dense([0.0, 1.0])],
      +    ...     [2, SparseVector(2, {0: 1.0})],
      +    ... ]
      +    >>> rdd =  sc.parallelize(data)
      +    >>> model = LDA.train(rdd, k=2)
      +    >>> model.vocabSize()
      +    2
      +    >>> topics = model.topicsMatrix()
      +    >>> topics_expect = array([[0.5,  0.5], [0.5, 0.5]])
      +    >>> assert_almost_equal(topics, topics_expect, 1)
      +
      +    >>> import os, tempfile
      +    >>> from shutil import rmtree
      +    >>> path = tempfile.mkdtemp()
      +    >>> model.save(sc, path)
      +    >>> sameModel = LDAModel.load(sc, path)
      +    >>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix())
      +    >>> sameModel.vocabSize() == model.vocabSize()
      +    True
      +    >>> try:
      +    ...     rmtree(path)
      +    ... except OSError:
      +    ...     pass
      +    """
      +
      +    def topicsMatrix(self):
      +        """Inferred topics, where each topic is represented by a distribution over terms."""
      +        return self.call("topicsMatrix").toArray()
      +
      +    def vocabSize(self):
       +        """Vocabulary size (number of terms in the vocabulary)"""
      +        return self.call("vocabSize")
      +
      +    def save(self, sc, path):
      +        """Save the LDAModel on to disk.
      +
      +        :param sc: SparkContext
      +        :param path: str, path to where the model needs to be stored.
      +        """
      +        if not isinstance(sc, SparkContext):
      +            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
      +        if not isinstance(path, basestring):
      +            raise TypeError("path should be a basestring, got type %s" % type(path))
      +        self._java_model.save(sc._jsc.sc(), path)
      +
      +    @classmethod
      +    def load(cls, sc, path):
      +        """Load the LDAModel from disk.
      +
      +        :param sc: SparkContext
      +        :param path: str, path to where the model is stored.
      +        """
      +        if not isinstance(sc, SparkContext):
      +            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
      +        if not isinstance(path, basestring):
      +            raise TypeError("path should be a basestring, got type %s" % type(path))
      +        java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
      +            sc._jsc.sc(), path)
      +        return cls(java_model)
      +
      +
      +class LDA(object):
      +
      +    @classmethod
      +    def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
      +              topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
      +        """Train a LDA model.
      +
      +        :param rdd:                 RDD of data points
      +        :param k:                   Number of clusters you want
      +        :param maxIterations:       Number of iterations. Default to 20
      +        :param docConcentration:    Concentration parameter (commonly named "alpha")
      +            for the prior placed on documents' distributions over topics ("theta").
      +        :param topicConcentration:  Concentration parameter (commonly named "beta" or "eta")
      +            for the prior placed on topics' distributions over terms.
      +        :param seed:                Random Seed
      +        :param checkpointInterval:  Period (in iterations) between checkpoints.
      +        :param optimizer:           LDAOptimizer used to perform the actual calculation.
      +            Currently "em", "online" are supported. Default to "em".
      +        """
      +        model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
      +                              docConcentration, topicConcentration, seed,
      +                              checkpointInterval, optimizer)
      +        return LDAModel(model)
      +
      +
       def _test():
           import doctest
      -    globs = globals().copy()
      +    import pyspark.mllib.clustering
      +    globs = pyspark.mllib.clustering.__dict__.copy()
           globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
           (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
           globs['sc'].stop()
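For the new LDA entry point, a compact usage sketch that restates what the doctest above exercises with a slightly more explicit corpus; `sc` is assumed to be a live SparkContext, and the term counts are made up.

```python
# Hedged sketch: each document is an (id, term-count vector) pair over a
# fixed vocabulary of 3 terms.
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.clustering import LDA

corpus = sc.parallelize([
    [0, Vectors.dense([2.0, 1.0, 0.0])],
    [1, Vectors.dense([0.0, 1.0, 3.0])],
])
model = LDA.train(corpus, k=2, maxIterations=20)
print(model.vocabSize())      # 3
print(model.topicsMatrix())   # terms x topics matrix of term weights
```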
      diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
      index 855e85f57155..a439a488de5c 100644
      --- a/python/pyspark/mllib/common.py
      +++ b/python/pyspark/mllib/common.py
      @@ -73,6 +73,8 @@ def _py2java(sc, obj):
           """ Convert Python object into Java """
           if isinstance(obj, RDD):
               obj = _to_java_object_rdd(obj)
      +    elif isinstance(obj, DataFrame):
      +        obj = obj._jdf
           elif isinstance(obj, SparkContext):
               obj = obj._jsc
           elif isinstance(obj, list):
      diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
      index c5cf3a4e7ff2..4398ca86f2ec 100644
      --- a/python/pyspark/mllib/evaluation.py
      +++ b/python/pyspark/mllib/evaluation.py
      @@ -82,7 +82,7 @@ class RegressionMetrics(JavaModelWrapper):
           ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
           >>> metrics = RegressionMetrics(predictionAndObservations)
           >>> metrics.explainedVariance
      -    0.95...
      +    8.859...
           >>> metrics.meanAbsoluteError
           0.5...
           >>> metrics.meanSquaredError
      @@ -152,6 +152,10 @@ class MulticlassMetrics(JavaModelWrapper):
           >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
           ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
           >>> metrics = MulticlassMetrics(predictionAndLabels)
      +    >>> metrics.confusionMatrix().toArray()
      +    array([[ 2.,  1.,  1.],
      +           [ 1.,  3.,  0.],
      +           [ 0.,  0.,  1.]])
           >>> metrics.falsePositiveRate(0.0)
           0.2...
           >>> metrics.precision(1.0)
      @@ -186,6 +190,13 @@ def __init__(self, predictionAndLabels):
               java_model = java_class(df._jdf)
               super(MulticlassMetrics, self).__init__(java_model)
       
      +    def confusionMatrix(self):
      +        """
      +        Returns confusion matrix: predicted classes are in columns,
      +        they are ordered by class label ascending, as in "labels".
      +        """
      +        return self.call("confusionMatrix")
      +
           def truePositiveRate(self, label):
               """
               Returns true positive rate for a given label (category).
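The layout convention documented for `confusionMatrix` (predicted classes in columns, true labels in rows, both ordered by label ascending) makes it easy to recover per-class metrics directly. A plain numpy sketch using the matrix from the doctest above:

```python
# Standalone sketch: per-class precision/recall from the confusion matrix.
import numpy as np

cm = np.array([[2., 1., 1.],
               [1., 3., 0.],
               [0., 0., 1.]])

precision = np.diag(cm) / cm.sum(axis=0)   # per predicted class (columns)
recall = np.diag(cm) / cm.sum(axis=1)      # per true label (rows)
print(precision)   # ~[0.667, 0.75, 0.5]; class 1.0 matches precision(1.0)
print(recall)      # ~[0.5, 0.75, 1.0]
```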
      diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
      index 334f5b86cd39..7b077b058c3f 100644
      --- a/python/pyspark/mllib/feature.py
      +++ b/python/pyspark/mllib/feature.py
      @@ -30,12 +30,13 @@
       
       from py4j.protocol import Py4JJavaError
       
      -from pyspark import SparkContext
      +from pyspark import SparkContext, since
       from pyspark.rdd import RDD, ignore_unicode_prefix
       from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
       from pyspark.mllib.linalg import (
           Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
       from pyspark.mllib.regression import LabeledPoint
      +from pyspark.mllib.util import JavaLoader, JavaSaveable
       
       __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
                  'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
      @@ -83,11 +84,14 @@ class Normalizer(VectorTransformer):
           >>> nor2 = Normalizer(float("inf"))
           >>> nor2.transform(v)
           DenseVector([0.0, 0.5, 1.0])
      +
      +    .. versionadded:: 1.2.0
           """
           def __init__(self, p=2.0):
               assert p >= 1.0, "p should be greater than 1.0"
               self.p = float(p)
       
      +    @since('1.2.0')
           def transform(self, vector):
               """
               Applies unit length normalization on a vector.
      @@ -111,6 +115,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
           """
       
           def transform(self, vector):
      +        """
      +        Applies transformation on a vector or an RDD[Vector].
      +
      +        Note: In Python, transform cannot currently be used within
      +              an RDD transformation or action.
      +              Call transform directly on the RDD instead.
      +
      +        :param vector: Vector or RDD of Vector to be transformed.
      +        """
               if isinstance(vector, RDD):
                   vector = vector.map(_convert_to_vector)
               else:
      @@ -123,7 +136,11 @@ class StandardScalerModel(JavaVectorTransformer):
           .. note:: Experimental
       
           Represents a StandardScaler model that can transform vectors.
      +
      +    .. versionadded:: 1.2.0
           """
      +
      +    @since('1.2.0')
           def transform(self, vector):
               """
               Applies standardization transformation on a vector.
      @@ -139,6 +156,7 @@ def transform(self, vector):
               """
               return JavaVectorTransformer.transform(self, vector)
       
      +    @since('1.4.0')
           def setWithMean(self, withMean):
               """
               Setter of the boolean which decides
      @@ -147,6 +165,7 @@ def setWithMean(self, withMean):
               self.call("setWithMean", withMean)
               return self
       
      +    @since('1.4.0')
           def setWithStd(self, withStd):
               """
               Setter of the boolean which decides
      @@ -179,6 +198,8 @@ class StandardScaler(object):
           >>> for r in result.collect(): r
           DenseVector([-0.7071, 0.7071, -0.7071])
           DenseVector([0.7071, -0.7071, 0.7071])
      +
      +    .. versionadded:: 1.2.0
           """
           def __init__(self, withMean=False, withStd=True):
               if not (withMean or withStd):
      @@ -186,12 +207,13 @@ def __init__(self, withMean=False, withStd=True):
               self.withMean = withMean
               self.withStd = withStd
       
      +    @since('1.2.0')
           def fit(self, dataset):
               """
               Computes the mean and variance and stores as a model to be used
               for later scaling.
       
      -        :param data: The data used to compute the mean and variance
      +        :param dataset: The data used to compute the mean and variance
                            to build the transformation model.
               :return: a StandardScalerModel
               """
      @@ -205,7 +227,11 @@ class ChiSqSelectorModel(JavaVectorTransformer):
           .. note:: Experimental
       
           Represents a Chi Squared selector model.
      +
      +    .. versionadded:: 1.4.0
           """
      +
      +    @since('1.4.0')
           def transform(self, vector):
               """
               Applies transformation on a vector.
      @@ -235,10 +261,13 @@ class ChiSqSelector(object):
           SparseVector(1, {0: 6.0})
           >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
           DenseVector([5.0])
      +
      +    .. versionadded:: 1.4.0
           """
           def __init__(self, numTopFeatures):
               self.numTopFeatures = int(numTopFeatures)
       
      +    @since('1.4.0')
           def fit(self, data):
               """
               Returns a ChiSquared feature selector.
      @@ -255,6 +284,8 @@ def fit(self, data):
       class PCAModel(JavaVectorTransformer):
           """
           Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA.
      +
      +    .. versionadded:: 1.5.0
           """
       
       
      @@ -271,6 +302,8 @@ class PCA(object):
           1.648...
           >>> pcArray[1]
           -4.013...
      +
      +    .. versionadded:: 1.5.0
           """
           def __init__(self, k):
               """
      @@ -278,6 +311,7 @@ def __init__(self, k):
               """
               self.k = int(k)
       
      +    @since('1.5.0')
           def fit(self, data):
               """
               Computes a [[PCAModel]] that contains the principal components of the input vectors.
      @@ -302,14 +336,18 @@ class HashingTF(object):
           >>> doc = "a a b b c d".split(" ")
           >>> htf.transform(doc)
           SparseVector(100, {...})
      +
      +    .. versionadded:: 1.2.0
           """
           def __init__(self, numFeatures=1 << 20):
               self.numFeatures = numFeatures
       
      +    @since('1.2.0')
           def indexOf(self, term):
               """ Returns the index of the input term. """
               return hash(term) % self.numFeatures
       
      +    @since('1.2.0')
           def transform(self, document):
               """
               Transforms the input document (list of terms) to term frequency
      @@ -329,7 +367,10 @@ def transform(self, document):
       class IDFModel(JavaVectorTransformer):
           """
           Represents an IDF model that can transform term frequency vectors.
      +
      +    .. versionadded:: 1.2.0
           """
      +    @since('1.2.0')
           def transform(self, x):
               """
               Transforms term frequency (TF) vectors to TF-IDF vectors.
      @@ -346,12 +387,9 @@ def transform(self, x):
                         vector
               :return: an RDD of TF-IDF vectors or a TF-IDF vector
               """
      -        if isinstance(x, RDD):
      -            return JavaVectorTransformer.transform(self, x)
      -
      -        x = _convert_to_vector(x)
               return JavaVectorTransformer.transform(self, x)
       
      +    @since('1.4.0')
           def idf(self):
               """
               Returns the current IDF vector.
      @@ -395,10 +433,13 @@ class IDF(object):
           DenseVector([0.0, 0.0, 1.3863, 0.863])
           >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0)))
           SparseVector(4, {1: 0.0, 3: 0.5754})
      +
      +    .. versionadded:: 1.2.0
           """
           def __init__(self, minDocFreq=0):
               self.minDocFreq = minDocFreq
       
      +    @since('1.2.0')
           def fit(self, dataset):
               """
               Computes the inverse document frequency.
      @@ -411,10 +452,13 @@ def fit(self, dataset):
               return IDFModel(jmodel)
       
       
      -class Word2VecModel(JavaVectorTransformer):
      +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
           """
           class for Word2Vec model
      +
      +    .. versionadded:: 1.2.0
           """
      +    @since('1.2.0')
           def transform(self, word):
               """
               Transforms a word to its vector representation
      @@ -429,6 +473,7 @@ def transform(self, word):
               except Py4JJavaError:
                   raise ValueError("%s not found" % word)
       
      +    @since('1.2.0')
           def findSynonyms(self, word, num):
               """
               Find synonyms of a word
      @@ -444,12 +489,23 @@ def findSynonyms(self, word, num):
               words, similarity = self.call("findSynonyms", word, num)
               return zip(words, similarity)
       
      +    @since('1.4.0')
           def getVectors(self):
               """
               Returns a map of words to their vector representations.
               """
               return self.call("getVectors")
       
      +    @classmethod
      +    @since('1.5.0')
      +    def load(cls, sc, path):
      +        """
      +        Load a model from the given path.
      +        """
      +        jmodel = sc._jvm.org.apache.spark.mllib.feature \
      +            .Word2VecModel.load(sc._jsc.sc(), path)
      +        return Word2VecModel(jmodel)
      +
       
       @ignore_unicode_prefix
       class Word2Vec(object):
      @@ -483,6 +539,20 @@ class Word2Vec(object):
           >>> syms = model.findSynonyms(vec, 2)
           >>> [s[0] for s in syms]
           [u'b', u'c']
      +
      +    >>> import os, tempfile
      +    >>> path = tempfile.mkdtemp()
      +    >>> model.save(sc, path)
      +    >>> sameModel = Word2VecModel.load(sc, path)
      +    >>> model.transform("a") == sameModel.transform("a")
      +    True
      +    >>> from shutil import rmtree
      +    >>> try:
      +    ...     rmtree(path)
      +    ... except OSError:
      +    ...     pass
      +
      +    .. versionadded:: 1.2.0
           """
           def __init__(self):
               """
      @@ -495,6 +565,7 @@ def __init__(self):
               self.seed = random.randint(0, sys.maxsize)
               self.minCount = 5
       
      +    @since('1.2.0')
           def setVectorSize(self, vectorSize):
               """
               Sets vector size (default: 100).
      @@ -502,6 +573,7 @@ def setVectorSize(self, vectorSize):
               self.vectorSize = vectorSize
               return self
       
      +    @since('1.2.0')
           def setLearningRate(self, learningRate):
               """
               Sets initial learning rate (default: 0.025).
      @@ -509,6 +581,7 @@ def setLearningRate(self, learningRate):
               self.learningRate = learningRate
               return self
       
      +    @since('1.2.0')
           def setNumPartitions(self, numPartitions):
               """
               Sets number of partitions (default: 1). Use a small number for
      @@ -517,6 +590,7 @@ def setNumPartitions(self, numPartitions):
               self.numPartitions = numPartitions
               return self
       
      +    @since('1.2.0')
           def setNumIterations(self, numIterations):
               """
               Sets number of iterations (default: 1), which should be smaller
      @@ -525,6 +599,7 @@ def setNumIterations(self, numIterations):
               self.numIterations = numIterations
               return self
       
      +    @since('1.2.0')
           def setSeed(self, seed):
               """
               Sets random seed.
      @@ -532,6 +607,7 @@ def setSeed(self, seed):
               self.seed = seed
               return self
       
      +    @since('1.4.0')
           def setMinCount(self, minCount):
               """
               Sets minCount, the minimum number of times a token must appear
      @@ -540,6 +616,7 @@ def setMinCount(self, minCount):
               self.minCount = minCount
               return self
       
      +    @since('1.2.0')
           def fit(self, data):
               """
               Computes the vector representation of each word in vocabulary.
      @@ -549,7 +626,7 @@ def fit(self, data):
               """
               if not isinstance(data, RDD):
                   raise TypeError("data should be an RDD of list of string")
      -        jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
      +        jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize),
                                      float(self.learningRate), int(self.numPartitions),
                                      int(self.numIterations), int(self.seed),
                                      int(self.minCount))
      @@ -572,10 +649,13 @@ class ElementwiseProduct(VectorTransformer):
           >>> rdd = sc.parallelize([a, b])
           >>> eprod.transform(rdd).collect()
           [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])]
      +
      +    .. versionadded:: 1.5.0
           """
           def __init__(self, scalingVector):
               self.scalingVector = _convert_to_vector(scalingVector)
       
      +    @since('1.5.0')
           def transform(self, vector):
               """
               Computes the Hadamard product of the vector.
      diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py
      similarity index 78%
      rename from python/pyspark/mllib/linalg.py
      rename to python/pyspark/mllib/linalg/__init__.py
      index e96c5ef87df8..4829acb16ed8 100644
      --- a/python/pyspark/mllib/linalg.py
      +++ b/python/pyspark/mllib/linalg/__init__.py
      @@ -25,12 +25,15 @@
       
       import sys
       import array
      +import struct
       
       if sys.version >= '3':
           basestring = str
           xrange = range
           import copyreg as copy_reg
      +    long = int
       else:
      +    from itertools import izip as zip
           import copy_reg
       
       import numpy as np
      @@ -116,6 +119,17 @@ def _format_float(f, digits=4):
           return s
       
       
      +def _format_float_list(l):
      +    return [_format_float(x) for x in l]
      +
      +
      +def _double_to_long_bits(value):
      +    if np.isnan(value):
      +        value = float('nan')
      +    # pack double into 64 bits, then unpack as long int
      +    return struct.unpack('Q', struct.pack('d', value))[0]
      +
      +
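A quick check (not part of the patch) of what `_double_to_long_bits` computes: it reinterprets the eight IEEE-754 bytes of a double as an unsigned 64-bit integer, roughly mirroring Java's `Double.doubleToLongBits`, so the hash functions below can mix the exact bit pattern of each value:

    import struct

    def double_to_long_bits(value):
        return struct.unpack('Q', struct.pack('d', value))[0]

    double_to_long_bits(1.0)    # 4607182418800017408 == 0x3ff0000000000000
    double_to_long_bits(0.0)    # 0
    double_to_long_bits(-0.0)   # 0x8000000000000000 (only the sign bit set)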
       class VectorUDT(UserDefinedType):
           """
           SQL user-defined type (UDT) for Vector.
      @@ -385,6 +399,10 @@ def squared_distance(self, other):
           def toArray(self):
               return self.array
       
      +    @property
      +    def values(self):
      +        return self.array
      +
           def __getitem__(self, item):
               return self.array[item]
       
      @@ -398,11 +416,31 @@ def __repr__(self):
               return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))
       
           def __eq__(self, other):
      -        return isinstance(other, DenseVector) and np.array_equal(self.array, other.array)
      +        if isinstance(other, DenseVector):
      +            return np.array_equal(self.array, other.array)
      +        elif isinstance(other, SparseVector):
      +            if len(self) != other.size:
      +                return False
      +            return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values)
      +        return False
       
           def __ne__(self, other):
               return not self == other
       
      +    def __hash__(self):
      +        size = len(self)
      +        result = 31 + size
      +        nnz = 0
      +        i = 0
      +        while i < size and nnz < 128:
      +            if self.array[i] != 0:
      +                result = 31 * result + i
      +                bits = _double_to_long_bits(self.array[i])
      +                result = 31 * result + (bits ^ (bits >> 32))
      +                nnz += 1
      +            i += 1
      +        return result
      +
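A sketch (not part of the patch) of the behaviour the new `__eq__`/`__hash__` pair is aiming for: dense and sparse vectors representing the same values compare equal and hash identically, so they can be mixed freely as dict keys or set members:

    from pyspark.mllib.linalg import DenseVector, SparseVector

    dv = DenseVector([0.0, 1.0, 0.0, 5.5])
    sv = SparseVector(4, {1: 1.0, 3: 5.5})
    dv == sv              # True: element-wise equality across representations
    hash(dv) == hash(sv)  # True: both hash only the leading nonzero entries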
           def __getattr__(self, item):
               return getattr(self.array, item)
       
      @@ -440,8 +478,10 @@ def __init__(self, size, *args):
               values (sorted by index).
       
               :param size: Size of the vector.
      -        :param args: Non-zero entries, as a dictionary, list of tupes,
      -               or two sorted lists containing indices and values.
      +        :param args: Active entries, as a dictionary {index: value, ...},
      +          a list of tuples [(index, value), ...], or a list of strictly
      +          increasing indices and a list of corresponding values [index, ...],
      +          [value, ...]. Inactive entries are treated as zeros.
       
               >>> SparseVector(4, {1: 1.0, 3: 5.5})
               SparseVector(4, {1: 1.0, 3: 5.5})
      @@ -451,6 +491,7 @@ def __init__(self, size, *args):
               SparseVector(4, {1: 1.0, 3: 5.5})
               """
               self.size = int(size)
      +        """ Size of the vector. """
               assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
               if len(args) == 1:
                   pairs = args[0]
      @@ -458,7 +499,9 @@ def __init__(self, size, *args):
                       pairs = pairs.items()
                   pairs = sorted(pairs)
                   self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
      +            """ A list of indices corresponding to active entries. """
                   self.values = np.array([p[1] for p in pairs], dtype=np.float64)
      +            """ A list of values corresponding to active entries. """
               else:
                   if isinstance(args[0], bytes):
                       assert isinstance(args[1], bytes), "values should be string too"
      @@ -555,7 +598,7 @@ def dot(self, other):
               25.0
               >>> a.dot(array.array('d', [1., 2., 3., 4.]))
               22.0
      -        >>> b = SparseVector(4, [2, 4], [1.0, 2.0])
      +        >>> b = SparseVector(4, [2], [1.0])
               >>> a.dot(b)
               0.0
               >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
      @@ -577,34 +620,27 @@ def dot(self, other):
                   ...
               AssertionError: dimension mismatch
               """
      -        if type(other) == np.ndarray:
      -            if other.ndim == 2:
      -                results = [self.dot(other[:, i]) for i in xrange(other.shape[1])]
      -                return np.array(results)
      -            elif other.ndim > 2:
      +
      +        if isinstance(other, np.ndarray):
      +            if other.ndim not in [2, 1]:
                       raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim)
      +            assert len(self) == other.shape[0], "dimension mismatch"
      +            return np.dot(self.values, other[self.indices])
       
               assert len(self) == _vector_size(other), "dimension mismatch"
       
      -        if type(other) in (np.ndarray, array.array, DenseVector):
      -            result = 0.0
      -            for i in xrange(len(self.indices)):
      -                result += self.values[i] * other[self.indices[i]]
      -            return result
      +        if isinstance(other, DenseVector):
      +            return np.dot(other.array[self.indices], self.values)
       
      -        elif type(other) is SparseVector:
      -            result = 0.0
      -            i, j = 0, 0
      -            while i < len(self.indices) and j < len(other.indices):
      -                if self.indices[i] == other.indices[j]:
      -                    result += self.values[i] * other.values[j]
      -                    i += 1
      -                    j += 1
      -                elif self.indices[i] < other.indices[j]:
      -                    i += 1
      -                else:
      -                    j += 1
      -            return result
      +        elif isinstance(other, SparseVector):
      +            # Find out common indices.
      +            self_cmind = np.in1d(self.indices, other.indices, assume_unique=True)
      +            self_values = self.values[self_cmind]
      +            if self_values.size == 0:
      +                return 0.0
      +            else:
      +                other_cmind = np.in1d(other.indices, self.indices, assume_unique=True)
      +                return np.dot(self_values, other.values[other_cmind])
       
               else:
                   return self.dot(_convert_to_vector(other))
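A standalone sketch (not part of the patch) of the vectorised sparse-sparse dot product above: `np.in1d` masks the indices the two vectors share, and because both index arrays are kept sorted the masked value arrays line up pairwise:

    import numpy as np

    # self = SparseVector(4, [1, 3], [3.0, 4.0]), other = SparseVector(4, [2, 3], [1.0, 2.0])
    s_ind, s_val = np.array([1, 3]), np.array([3.0, 4.0])
    o_ind, o_val = np.array([2, 3]), np.array([1.0, 2.0])
    s_mask = np.in1d(s_ind, o_ind, assume_unique=True)   # [False, True]
    o_mask = np.in1d(o_ind, s_ind, assume_unique=True)   # [False, True]
    np.dot(s_val[s_mask], o_val[o_mask])                 # 4.0 * 2.0 = 8.0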
      @@ -620,11 +656,11 @@ def squared_distance(self, other):
               11.0
               >>> a.squared_distance(np.array([1., 2., 3., 4.]))
               11.0
      -        >>> b = SparseVector(4, [2, 4], [1.0, 2.0])
      +        >>> b = SparseVector(4, [2], [1.0])
               >>> a.squared_distance(b)
      -        30.0
      +        26.0
               >>> b.squared_distance(a)
      -        30.0
      +        26.0
               >>> b.squared_distance([1., 2.])
               Traceback (most recent call last):
                   ...
      @@ -635,22 +671,23 @@ def squared_distance(self, other):
               AssertionError: dimension mismatch
               """
               assert len(self) == _vector_size(other), "dimension mismatch"
      -        if type(other) in (list, array.array, DenseVector, np.array, np.ndarray):
      -            if type(other) is np.array and other.ndim != 1:
      +
      +        if isinstance(other, np.ndarray) or isinstance(other, DenseVector):
      +            if isinstance(other, np.ndarray) and other.ndim != 1:
                       raise Exception("Cannot call squared_distance with %d-dimensional array" %
                                       other.ndim)
      -            result = 0.0
      -            j = 0   # index into our own array
      -            for i in xrange(len(other)):
      -                if j < len(self.indices) and self.indices[j] == i:
      -                    diff = self.values[j] - other[i]
      -                    result += diff * diff
      -                    j += 1
      -                else:
      -                    result += other[i] * other[i]
      +            if isinstance(other, DenseVector):
      +                other = other.array
      +            sparse_ind = np.zeros(other.size, dtype=bool)
      +            sparse_ind[self.indices] = True
      +            dist = other[sparse_ind] - self.values
      +            result = np.dot(dist, dist)
      +
      +            other_ind = other[~sparse_ind]
      +            result += np.dot(other_ind, other_ind)
                   return result
       
      -        elif type(other) is SparseVector:
      +        elif isinstance(other, SparseVector):
                   result = 0.0
                   i, j = 0, 0
                   while i < len(self.indices) and j < len(other.indices):
      @@ -699,20 +736,14 @@ def __repr__(self):
               return "SparseVector({0}, {{{1}}})".format(self.size, entries)
       
           def __eq__(self, other):
      -        """
      -        Test SparseVectors for equality.
      -
      -        >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)])
      -        >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
      -        >>> v1 == v2
      -        True
      -        >>> v1 != v2
      -        False
      -        """
      -        return (isinstance(other, self.__class__)
      -                and other.size == self.size
      -                and np.array_equal(other.indices, self.indices)
      -                and np.array_equal(other.values, self.values))
      +        if isinstance(other, SparseVector):
      +            return other.size == self.size and np.array_equal(other.indices, self.indices) \
      +                and np.array_equal(other.values, self.values)
      +        elif isinstance(other, DenseVector):
      +            if self.size != len(other):
      +                return False
      +            return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array)
      +        return False
       
           def __getitem__(self, index):
               inds = self.indices
      @@ -734,6 +765,19 @@ def __getitem__(self, index):
           def __ne__(self, other):
               return not self.__eq__(other)
       
      +    def __hash__(self):
      +        result = 31 + self.size
      +        nnz = 0
      +        i = 0
      +        while i < len(self.values) and nnz < 128:
      +            if self.values[i] != 0:
      +                result = 31 * result + int(self.indices[i])
      +                bits = _double_to_long_bits(self.values[i])
      +                result = 31 * result + (bits ^ (bits >> 32))
      +                nnz += 1
      +            i += 1
      +        return result
      +
       
       class Vectors(object):
       
      @@ -766,14 +810,18 @@ def sparse(size, *args):
               return SparseVector(size, *args)
       
           @staticmethod
      -    def dense(elements):
      +    def dense(*elements):
               """
      -        Create a dense vector of 64-bit floats from a Python list. Always
      -        returns a NumPy array.
      +        Create a dense vector of 64-bit floats from a Python list or numbers.
       
               >>> Vectors.dense([1, 2, 3])
               DenseVector([1.0, 2.0, 3.0])
      +        >>> Vectors.dense(1.0, 2.0)
      +        DenseVector([1.0, 2.0])
               """
      +        if len(elements) == 1 and not isinstance(elements[0], (float, int, long)):
      +            # it's list, numpy.array or other iterable object.
      +            elements = elements[0]
               return DenseVector(elements)
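For illustration (not part of the patch), the varargs signature above keeps the old list form working while also accepting plain numbers; a single numeric argument becomes a one-element vector:

    from pyspark.mllib.linalg import Vectors

    Vectors.dense([1, 2, 3])      # DenseVector([1.0, 2.0, 3.0])  -- single iterable
    Vectors.dense(1.0, 2.0, 3.0)  # DenseVector([1.0, 2.0, 3.0])  -- numbers as varargs
    Vectors.dense(3.0)            # DenseVector([3.0])            -- single number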
       
           @staticmethod
      @@ -832,6 +880,31 @@ def parse(s):
           def zeros(size):
               return DenseVector(np.zeros(size))
       
      +    @staticmethod
      +    def _equals(v1_indices, v1_values, v2_indices, v2_values):
      +        """
      +        Check equality between sparse/dense vectors,
      +        v1_indices and v2_indices are assumed to be strictly increasing.
      +        """
      +        v1_size = len(v1_values)
      +        v2_size = len(v2_values)
      +        k1 = 0
      +        k2 = 0
      +        all_equal = True
      +        while all_equal:
      +            while k1 < v1_size and v1_values[k1] == 0:
      +                k1 += 1
      +            while k2 < v2_size and v2_values[k2] == 0:
      +                k2 += 1
      +
      +            if k1 >= v1_size or k2 >= v2_size:
      +                return k1 >= v1_size and k2 >= v2_size
      +
      +            all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2]
      +            k1 += 1
      +            k2 += 1
      +        return all_equal
      +
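A behavioural note (not part of the patch) on `_equals`: explicitly stored zeros are skipped on both sides, so a sparse vector that happens to carry a zero value still compares equal to the corresponding dense vector:

    from pyspark.mllib.linalg import DenseVector, SparseVector

    sv = SparseVector(3, [0, 2], [1.0, 0.0])  # index 2 is stored, but with value 0.0
    dv = DenseVector([1.0, 0.0, 0.0])
    sv == dv                                  # True: stored zeros count as inactive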
       
       class Matrix(object):
       
      @@ -876,6 +949,50 @@ def __reduce__(self):
                   self.numRows, self.numCols, self.values.tostring(),
                   int(self.isTransposed))
       
      +    def __str__(self):
      +        """
      +        Pretty printing of a DenseMatrix
      +
      +        >>> dm = DenseMatrix(2, 2, range(4))
      +        >>> print(dm)
      +        DenseMatrix([[ 0.,  2.],
      +                     [ 1.,  3.]])
      +        >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True)
      +        >>> print(dm)
      +        DenseMatrix([[ 0.,  1.],
      +                     [ 2.,  3.]])
      +        """
      +        # Inspired by __repr__ in scipy matrices.
      +        array_lines = repr(self.toArray()).splitlines()
      +
      +        # We need to adjust six spaces which is the difference in number
      +        # of letters between "DenseMatrix" and "array"
      +        x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]])
      +        return array_lines[0].replace("array", "DenseMatrix") + "\n" + x
      +
      +    def __repr__(self):
      +        """
      +        Representation of a DenseMatrix
      +
      +        >>> dm = DenseMatrix(2, 2, range(4))
      +        >>> dm
      +        DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False)
      +        """
      +        # If there are fewer than 17 values, return them all.
      +        # Otherwise return the first eight and the last eight values.
      +        if len(self.values) < 17:
      +            entries = _format_float_list(self.values)
      +        else:
      +            entries = (
      +                _format_float_list(self.values[:8]) +
      +                ["..."] +
      +                _format_float_list(self.values[-8:])
      +            )
      +
      +        entries = ", ".join(entries)
      +        return "DenseMatrix({0}, {1}, [{2}], {3})".format(
      +            self.numRows, self.numCols, entries, self.isTransposed)
      +
           def toArray(self):
               """
               Return a numpy.ndarray
      @@ -952,6 +1069,84 @@ def __init__(self, numRows, numCols, colPtrs, rowIndices, values,
                   raise ValueError("Expected rowIndices of length %d, got %d."
                                    % (self.rowIndices.size, self.values.size))
       
      +    def __str__(self):
      +        """
      +        Pretty printing of a SparseMatrix
      +
      +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
      +        >>> print(sm1)
      +        2 X 2 CSCMatrix
      +        (0,0) 2.0
      +        (1,0) 3.0
      +        (1,1) 4.0
      +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
      +        >>> print(sm1)
      +        2 X 2 CSRMatrix
      +        (0,0) 2.0
      +        (0,1) 3.0
      +        (1,1) 4.0
      +        """
      +        spstr = "{0} X {1} ".format(self.numRows, self.numCols)
      +        if self.isTransposed:
      +            spstr += "CSRMatrix\n"
      +        else:
      +            spstr += "CSCMatrix\n"
      +
      +        cur_col = 0
      +        smlist = []
      +
      +        # Display first 16 values.
      +        if len(self.values) <= 16:
      +            zipindval = zip(self.rowIndices, self.values)
      +        else:
      +            zipindval = zip(self.rowIndices[:16], self.values[:16])
      +        for i, (rowInd, value) in enumerate(zipindval):
      +            if self.colPtrs[cur_col + 1] <= i:
      +                cur_col += 1
      +            if self.isTransposed:
      +                smlist.append('({0},{1}) {2}'.format(
      +                    cur_col, rowInd, _format_float(value)))
      +            else:
      +                smlist.append('({0},{1}) {2}'.format(
      +                    rowInd, cur_col, _format_float(value)))
      +        spstr += "\n".join(smlist)
      +
      +        if len(self.values) > 16:
      +            spstr += "\n.." * 2
      +        return spstr
      +
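For reference (not part of the patch), a tiny decoder for the CSC layout that `__str__` walks above: `colPtrs[j]:colPtrs[j+1]` delimits the entries of column j within `rowIndices`/`values` (for the transposed CSR case, swap the roles of rows and columns):

    import numpy as np

    numRows, numCols = 2, 2
    colPtrs, rowIndices, values = [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0]
    dense = np.zeros((numRows, numCols))
    for j in range(numCols):                        # column-major (CSC) walk
        for k in range(colPtrs[j], colPtrs[j + 1]):
            dense[rowIndices[k], j] = values[k]
    # dense -> [[2., 0.],
    #           [3., 4.]]  matching the (0,0) 2.0 / (1,0) 3.0 / (1,1) 4.0 listing above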
      +    def __repr__(self):
      +        """
      +        Representation of a SparseMatrix
      +
      +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
      +        >>> sm1
      +        SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False)
      +        """
      +        rowIndices = list(self.rowIndices)
      +        colPtrs = list(self.colPtrs)
      +
      +        if len(self.values) <= 16:
      +            values = _format_float_list(self.values)
      +
      +        else:
      +            values = (
      +                _format_float_list(self.values[:8]) +
      +                ["..."] +
      +                _format_float_list(self.values[-8:])
      +            )
      +            rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:]
      +
      +        if len(self.colPtrs) > 16:
      +            colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:]
      +
      +        values = ", ".join(values)
      +        rowIndices = ", ".join([str(ind) for ind in rowIndices])
      +        colPtrs = ", ".join([str(ptr) for ptr in colPtrs])
      +        return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
      +            self.numRows, self.numCols, colPtrs, rowIndices,
      +            values, self.isTransposed)
      +
           def __reduce__(self):
               return SparseMatrix, (
                   self.numRows, self.numCols, self.colPtrs.tostring(),
      diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
      new file mode 100644
      index 000000000000..aec407de90aa
      --- /dev/null
      +++ b/python/pyspark/mllib/linalg/distributed.py
      @@ -0,0 +1,853 @@
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +"""
      +Package for distributed linear algebra.
      +"""
      +
      +import sys
      +
      +if sys.version >= '3':
      +    long = int
      +
      +from py4j.java_gateway import JavaObject
      +
      +from pyspark import RDD
      +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
      +from pyspark.mllib.linalg import _convert_to_vector, Matrix
      +
      +
      +__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow',
      +           'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix',
      +           'BlockMatrix']
      +
      +
      +class DistributedMatrix(object):
      +    """
      +    .. note:: Experimental
      +
      +    Represents a distributively stored matrix backed by one or
      +    more RDDs.
      +
      +    """
      +    def numRows(self):
      +        """Get or compute the number of rows."""
      +        raise NotImplementedError
      +
      +    def numCols(self):
      +        """Get or compute the number of cols."""
      +        raise NotImplementedError
      +
      +
      +class RowMatrix(DistributedMatrix):
      +    """
      +    .. note:: Experimental
      +
      +    Represents a row-oriented distributed Matrix with no meaningful
      +    row indices.
      +
      +    :param rows: An RDD of vectors.
      +    :param numRows: Number of rows in the matrix. A non-positive
      +                    value means unknown, at which point the number
      +                    of rows will be determined by the number of
      +                    records in the `rows` RDD.
      +    :param numCols: Number of columns in the matrix. A non-positive
      +                    value means unknown, at which point the number
      +                    of columns will be determined by the size of
      +                    the first row.
      +    """
      +    def __init__(self, rows, numRows=0, numCols=0):
      +        """
      +        Note: This docstring is not shown publicly.
      +
      +        Create a wrapper over a Java RowMatrix.
      +
      +        Publicly, we require that `rows` be an RDD.  However, for
      +        internal usage, `rows` can also be a Java RowMatrix
      +        object, in which case we can wrap it directly.  This
      +        assists in clean matrix conversions.
      +
      +        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]])
      +        >>> mat = RowMatrix(rows)
      +
      +        >>> mat_diff = RowMatrix(rows)
      +        >>> (mat_diff._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        False
      +
      +        >>> mat_same = RowMatrix(mat._java_matrix_wrapper._java_model)
      +        >>> (mat_same._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        True
      +        """
      +        if isinstance(rows, RDD):
      +            rows = rows.map(_convert_to_vector)
      +            java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
      +        elif (isinstance(rows, JavaObject)
      +              and rows.getClass().getSimpleName() == "RowMatrix"):
      +            java_matrix = rows
      +        else:
      +            raise TypeError("rows should be an RDD of vectors, got %s" % type(rows))
      +
      +        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)
      +
      +    @property
      +    def rows(self):
      +        """
      +        Rows of the RowMatrix stored as an RDD of vectors.
      +
      +        >>> mat = RowMatrix(sc.parallelize([[1, 2, 3], [4, 5, 6]]))
      +        >>> rows = mat.rows
      +        >>> rows.first()
      +        DenseVector([1.0, 2.0, 3.0])
      +        """
      +        return self._java_matrix_wrapper.call("rows")
      +
      +    def numRows(self):
      +        """
      +        Get or compute the number of rows.
      +
      +        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6],
      +        ...                        [7, 8, 9], [10, 11, 12]])
      +
      +        >>> mat = RowMatrix(rows)
      +        >>> print(mat.numRows())
      +        4
      +
      +        >>> mat = RowMatrix(rows, 7, 6)
      +        >>> print(mat.numRows())
      +        7
      +        """
      +        return self._java_matrix_wrapper.call("numRows")
      +
      +    def numCols(self):
      +        """
      +        Get or compute the number of cols.
      +
      +        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6],
      +        ...                        [7, 8, 9], [10, 11, 12]])
      +
      +        >>> mat = RowMatrix(rows)
      +        >>> print(mat.numCols())
      +        3
      +
      +        >>> mat = RowMatrix(rows, 7, 6)
      +        >>> print(mat.numCols())
      +        6
      +        """
      +        return self._java_matrix_wrapper.call("numCols")
      +
      +
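A minimal usage sketch (not part of the patch) for the class above, assuming an existing SparkContext `sc` as the doctests do; a RowMatrix is simply an RDD of vectors plus optional explicit dimensions:

    import numpy as np
    from pyspark.mllib.linalg.distributed import RowMatrix

    local = np.arange(6.0).reshape(2, 3)             # 2 x 3 local matrix
    mat = RowMatrix(sc.parallelize(local.tolist()))  # one RDD element per row
    mat.numRows(), mat.numCols()                     # (2, 3), computed from the RDD
    [v.toArray() for v in mat.rows.collect()]        # back to local numpy rows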
      +class IndexedRow(object):
      +    """
      +    .. note:: Experimental
      +
      +    Represents a row of an IndexedRowMatrix.
      +
      +    Just a wrapper over a (long, vector) tuple.
      +
      +    :param index: The index for the given row.
      +    :param vector: The row in the matrix at the given index.
      +    """
      +    def __init__(self, index, vector):
      +        self.index = long(index)
      +        self.vector = _convert_to_vector(vector)
      +
      +    def __repr__(self):
      +        return "IndexedRow(%s, %s)" % (self.index, self.vector)
      +
      +
      +def _convert_to_indexed_row(row):
      +    if isinstance(row, IndexedRow):
      +        return row
      +    elif isinstance(row, tuple) and len(row) == 2:
      +        return IndexedRow(*row)
      +    else:
      +        raise TypeError("Cannot convert type %s into IndexedRow" % type(row))
      +
      +
      +class IndexedRowMatrix(DistributedMatrix):
      +    """
      +    .. note:: Experimental
      +
      +    Represents a row-oriented distributed Matrix with indexed rows.
      +
      +    :param rows: An RDD of IndexedRows or (long, vector) tuples.
      +    :param numRows: Number of rows in the matrix. A non-positive
      +                    value means unknown, at which point the number
      +                    of rows will be determined by the max row
      +                    index plus one.
      +    :param numCols: Number of columns in the matrix. A non-positive
      +                    value means unknown, at which point the number
      +                    of columns will be determined by the size of
      +                    the first row.
      +    """
      +    def __init__(self, rows, numRows=0, numCols=0):
      +        """
      +        Note: This docstring is not shown publicly.
      +
      +        Create a wrapper over a Java IndexedRowMatrix.
      +
      +        Publicly, we require that `rows` be an RDD.  However, for
      +        internal usage, `rows` can also be a Java IndexedRowMatrix
      +        object, in which case we can wrap it directly.  This
      +        assists in clean matrix conversions.
      +
      +        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
      +        ...                        IndexedRow(1, [4, 5, 6])])
      +        >>> mat = IndexedRowMatrix(rows)
      +
      +        >>> mat_diff = IndexedRowMatrix(rows)
      +        >>> (mat_diff._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        False
      +
      +        >>> mat_same = IndexedRowMatrix(mat._java_matrix_wrapper._java_model)
      +        >>> (mat_same._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        True
      +        """
      +        if isinstance(rows, RDD):
      +            rows = rows.map(_convert_to_indexed_row)
      +            # We use DataFrames for serialization of IndexedRows from
      +            # Python, so first convert the RDD to a DataFrame on this
      +            # side. This will convert each IndexedRow to a Row
      +            # containing the 'index' and 'vector' values, which can
      +            # both be easily serialized.  We will convert back to
      +            # IndexedRows on the Scala side.
      +            java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(),
      +                                        long(numRows), int(numCols))
      +        elif (isinstance(rows, JavaObject)
      +              and rows.getClass().getSimpleName() == "IndexedRowMatrix"):
      +            java_matrix = rows
      +        else:
      +            raise TypeError("rows should be an RDD of IndexedRows or (long, vector) tuples, "
      +                            "got %s" % type(rows))
      +
      +        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)
      +
      +    @property
      +    def rows(self):
      +        """
      +        Rows of the IndexedRowMatrix stored as an RDD of IndexedRows.
      +
      +        >>> mat = IndexedRowMatrix(sc.parallelize([IndexedRow(0, [1, 2, 3]),
      +        ...                                        IndexedRow(1, [4, 5, 6])]))
      +        >>> rows = mat.rows
      +        >>> rows.first()
      +        IndexedRow(0, [1.0,2.0,3.0])
      +        """
      +        # We use DataFrames for serialization of IndexedRows from
      +        # Java, so we first convert the RDD of rows to a DataFrame
      +        # on the Scala/Java side. Then we map each Row in the
      +        # DataFrame back to an IndexedRow on this side.
      +        rows_df = callMLlibFunc("getIndexedRows", self._java_matrix_wrapper._java_model)
      +        rows = rows_df.map(lambda row: IndexedRow(row[0], row[1]))
      +        return rows
      +
      +    def numRows(self):
      +        """
      +        Get or compute the number of rows.
      +
      +        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
      +        ...                        IndexedRow(1, [4, 5, 6]),
      +        ...                        IndexedRow(2, [7, 8, 9]),
      +        ...                        IndexedRow(3, [10, 11, 12])])
      +
      +        >>> mat = IndexedRowMatrix(rows)
      +        >>> print(mat.numRows())
      +        4
      +
      +        >>> mat = IndexedRowMatrix(rows, 7, 6)
      +        >>> print(mat.numRows())
      +        7
      +        """
      +        return self._java_matrix_wrapper.call("numRows")
      +
      +    def numCols(self):
      +        """
      +        Get or compute the number of cols.
      +
      +        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
      +        ...                        IndexedRow(1, [4, 5, 6]),
      +        ...                        IndexedRow(2, [7, 8, 9]),
      +        ...                        IndexedRow(3, [10, 11, 12])])
      +
      +        >>> mat = IndexedRowMatrix(rows)
      +        >>> print(mat.numCols())
      +        3
      +
      +        >>> mat = IndexedRowMatrix(rows, 7, 6)
      +        >>> print(mat.numCols())
      +        6
      +        """
      +        return self._java_matrix_wrapper.call("numCols")
      +
      +    def toRowMatrix(self):
      +        """
      +        Convert this matrix to a RowMatrix.
      +
      +        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
      +        ...                        IndexedRow(6, [4, 5, 6])])
      +        >>> mat = IndexedRowMatrix(rows).toRowMatrix()
      +        >>> mat.rows.collect()
      +        [DenseVector([1.0, 2.0, 3.0]), DenseVector([4.0, 5.0, 6.0])]
      +        """
      +        java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix")
      +        return RowMatrix(java_row_matrix)
      +
      +    def toCoordinateMatrix(self):
      +        """
      +        Convert this matrix to a CoordinateMatrix.
      +
      +        >>> rows = sc.parallelize([IndexedRow(0, [1, 0]),
      +        ...                        IndexedRow(6, [0, 5])])
      +        >>> mat = IndexedRowMatrix(rows).toCoordinateMatrix()
      +        >>> mat.entries.take(3)
      +        [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 0.0), MatrixEntry(6, 0, 0.0)]
      +        """
      +        java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix")
      +        return CoordinateMatrix(java_coordinate_matrix)
      +
      +    def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024):
      +        """
      +        Convert this matrix to a BlockMatrix.
      +
      +        :param rowsPerBlock: Number of rows that make up each block.
      +                             The blocks forming the final rows are not
      +                             required to have the given number of rows.
      +        :param colsPerBlock: Number of columns that make up each block.
      +                             The blocks forming the final columns are not
      +                             required to have the given number of columns.
      +
      +        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
      +        ...                        IndexedRow(6, [4, 5, 6])])
      +        >>> mat = IndexedRowMatrix(rows).toBlockMatrix()
      +
      +        >>> # This IndexedRowMatrix will have 7 effective rows, due to
      +        >>> # the highest row index being 6, and the ensuing
      +        >>> # BlockMatrix will have 7 rows as well.
      +        >>> print(mat.numRows())
      +        7
      +
      +        >>> print(mat.numCols())
      +        3
      +        """
      +        java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix",
      +                                                           rowsPerBlock,
      +                                                           colsPerBlock)
      +        return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock)
      +
      +
      +class MatrixEntry(object):
      +    """
      +    .. note:: Experimental
      +
      +    Represents an entry of a CoordinateMatrix.
      +
      +    Just a wrapper over a (long, long, float) tuple.
      +
      +    :param i: The row index of the matrix.
      +    :param j: The column index of the matrix.
      +    :param value: The (i, j)th entry of the matrix, as a float.
      +    """
      +    def __init__(self, i, j, value):
      +        self.i = long(i)
      +        self.j = long(j)
      +        self.value = float(value)
      +
      +    def __repr__(self):
      +        return "MatrixEntry(%s, %s, %s)" % (self.i, self.j, self.value)
      +
      +
      +def _convert_to_matrix_entry(entry):
      +    if isinstance(entry, MatrixEntry):
      +        return entry
      +    elif isinstance(entry, tuple) and len(entry) == 3:
      +        return MatrixEntry(*entry)
      +    else:
      +        raise TypeError("Cannot convert type %s into MatrixEntry" % type(entry))
      +
      +
      +class CoordinateMatrix(DistributedMatrix):
      +    """
      +    .. note:: Experimental
      +
      +    Represents a matrix in coordinate format.
      +
      +    :param entries: An RDD of MatrixEntry inputs or
      +                    (long, long, float) tuples.
      +    :param numRows: Number of rows in the matrix. A non-positive
      +                    value means unknown, at which point the number
      +                    of rows will be determined by the max row
      +                    index plus one.
      +    :param numCols: Number of columns in the matrix. A non-positive
      +                    value means unknown, at which point the number
      +                    of columns will be determined by the max column
      +                    index plus one.
      +    """
      +    def __init__(self, entries, numRows=0, numCols=0):
      +        """
      +        Note: This docstring is not shown publicly.
      +
      +        Create a wrapper over a Java CoordinateMatrix.
      +
      +        Publicly, we require that `entries` be an RDD.  However, for
      +        internal usage, `entries` can also be a Java CoordinateMatrix
      +        object, in which case we can wrap it directly.  This
      +        assists in clean matrix conversions.
      +
      +        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
      +        ...                           MatrixEntry(6, 4, 2.1)])
      +        >>> mat = CoordinateMatrix(entries)
      +
      +        >>> mat_diff = CoordinateMatrix(entries)
      +        >>> (mat_diff._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        False
      +
      +        >>> mat_same = CoordinateMatrix(mat._java_matrix_wrapper._java_model)
      +        >>> (mat_same._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        True
      +        """
      +        if isinstance(entries, RDD):
      +            entries = entries.map(_convert_to_matrix_entry)
      +            # We use DataFrames for serialization of MatrixEntry entries
      +            # from Python, so first convert the RDD to a DataFrame on
      +            # this side. This will convert each MatrixEntry to a Row
      +            # containing the 'i', 'j', and 'value' values, which can
      +            # each be easily serialized. We will convert back to
      +            # MatrixEntry inputs on the Scala side.
      +            java_matrix = callMLlibFunc("createCoordinateMatrix", entries.toDF(),
      +                                        long(numRows), long(numCols))
      +        elif (isinstance(entries, JavaObject)
      +              and entries.getClass().getSimpleName() == "CoordinateMatrix"):
      +            java_matrix = entries
      +        else:
      +            raise TypeError("entries should be an RDD of MatrixEntry entries or "
      +                            "(long, long, float) tuples, got %s" % type(entries))
      +
      +        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)
      +
      +    @property
      +    def entries(self):
      +        """
      +        Entries of the CoordinateMatrix stored as an RDD of
      +        MatrixEntries.
      +
      +        >>> mat = CoordinateMatrix(sc.parallelize([MatrixEntry(0, 0, 1.2),
      +        ...                                        MatrixEntry(6, 4, 2.1)]))
      +        >>> entries = mat.entries
      +        >>> entries.first()
      +        MatrixEntry(0, 0, 1.2)
      +        """
      +        # We use DataFrames for serialization of MatrixEntry entries
      +        # from Java, so we first convert the RDD of entries to a
      +        # DataFrame on the Scala/Java side. Then we map each Row in
      +        # the DataFrame back to a MatrixEntry on this side.
      +        entries_df = callMLlibFunc("getMatrixEntries", self._java_matrix_wrapper._java_model)
      +        entries = entries_df.map(lambda row: MatrixEntry(row[0], row[1], row[2]))
      +        return entries
      +
      +    def numRows(self):
      +        """
      +        Get or compute the number of rows.
      +
      +        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
      +        ...                           MatrixEntry(1, 0, 2),
      +        ...                           MatrixEntry(2, 1, 3.7)])
      +
      +        >>> mat = CoordinateMatrix(entries)
      +        >>> print(mat.numRows())
      +        3
      +
      +        >>> mat = CoordinateMatrix(entries, 7, 6)
      +        >>> print(mat.numRows())
      +        7
      +        """
      +        return self._java_matrix_wrapper.call("numRows")
      +
      +    def numCols(self):
      +        """
      +        Get or compute the number of cols.
      +
      +        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
      +        ...                           MatrixEntry(1, 0, 2),
      +        ...                           MatrixEntry(2, 1, 3.7)])
      +
      +        >>> mat = CoordinateMatrix(entries)
      +        >>> print(mat.numCols())
      +        2
      +
      +        >>> mat = CoordinateMatrix(entries, 7, 6)
      +        >>> print(mat.numCols())
      +        6
      +        """
      +        return self._java_matrix_wrapper.call("numCols")
      +
      +    def toRowMatrix(self):
      +        """
      +        Convert this matrix to a RowMatrix.
      +
      +        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
      +        ...                           MatrixEntry(6, 4, 2.1)])
      +        >>> mat = CoordinateMatrix(entries).toRowMatrix()
      +
      +        >>> # This CoordinateMatrix will have 7 effective rows, due to
      +        >>> # the highest row index being 6, but the ensuing RowMatrix
      +        >>> # will only have 2 rows since there are only entries on 2
      +        >>> # unique rows.
      +        >>> print(mat.numRows())
      +        2
      +
      +        >>> # This CoordinateMatrix will have 5 columns, due to the
      +        >>> # highest column index being 4, and the ensuing RowMatrix
      +        >>> # will have 5 columns as well.
      +        >>> print(mat.numCols())
      +        5
      +        """
      +        java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix")
      +        return RowMatrix(java_row_matrix)
      +
      +    def toIndexedRowMatrix(self):
      +        """
      +        Convert this matrix to an IndexedRowMatrix.
      +
      +        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
      +        ...                           MatrixEntry(6, 4, 2.1)])
      +        >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix()
      +
      +        >>> # This CoordinateMatrix will have 7 effective rows, due to
      +        >>> # the highest row index being 6, and the ensuing
      +        >>> # IndexedRowMatrix will have 7 rows as well.
      +        >>> print(mat.numRows())
      +        7
      +
      +        >>> # This CoordinateMatrix will have 5 columns, due to the
      +        >>> # highest column index being 4, and the ensuing
      +        >>> # IndexedRowMatrix will have 5 columns as well.
      +        >>> print(mat.numCols())
      +        5
      +        """
      +        java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix")
      +        return IndexedRowMatrix(java_indexed_row_matrix)
      +
      +    def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024):
      +        """
      +        Convert this matrix to a BlockMatrix.
      +
      +        :param rowsPerBlock: Number of rows that make up each block.
      +                             The blocks forming the final rows are not
      +                             required to have the given number of rows.
      +        :param colsPerBlock: Number of columns that make up each block.
      +                             The blocks forming the final columns are not
      +                             required to have the given number of columns.
      +
      +        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
      +        ...                           MatrixEntry(6, 4, 2.1)])
      +        >>> mat = CoordinateMatrix(entries).toBlockMatrix()
      +
      +        >>> # This CoordinateMatrix will have 7 effective rows, due to
      +        >>> # the highest row index being 6, and the ensuing
      +        >>> # BlockMatrix will have 7 rows as well.
      +        >>> print(mat.numRows())
      +        7
      +
      +        >>> # This CoordinateMatrix will have 5 columns, due to the
      +        >>> # highest column index being 4, and the ensuing
      +        >>> # BlockMatrix will have 5 columns as well.
      +        >>> print(mat.numCols())
      +        5
      +        """
      +        java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix",
      +                                                           rowsPerBlock,
      +                                                           colsPerBlock)
      +        return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock)
      +
      +
      +def _convert_to_matrix_block_tuple(block):
      +    if (isinstance(block, tuple) and len(block) == 2
      +            and isinstance(block[0], tuple) and len(block[0]) == 2
      +            and isinstance(block[1], Matrix)):
      +        blockRowIndex = int(block[0][0])
      +        blockColIndex = int(block[0][1])
      +        subMatrix = block[1]
      +        return ((blockRowIndex, blockColIndex), subMatrix)
      +    else:
      +        raise TypeError("Cannot convert type %s into a sub-matrix block tuple" % type(block))
      +
      +
      +class BlockMatrix(DistributedMatrix):
      +    """
      +    .. note:: Experimental
      +
      +    Represents a distributed matrix in blocks of local matrices.
      +
      +    :param blocks: An RDD of sub-matrix blocks
      +                   ((blockRowIndex, blockColIndex), sub-matrix) that
      +                   form this distributed matrix. If multiple blocks
      +                   with the same index exist, the results for
      +                   operations like add and multiply will be
      +                   unpredictable.
      +    :param rowsPerBlock: Number of rows that make up each block.
      +                         The blocks forming the final rows are not
      +                         required to have the given number of rows.
      +    :param colsPerBlock: Number of columns that make up each block.
      +                         The blocks forming the final columns are not
      +                         required to have the given number of columns.
      +    :param numRows: Number of rows of this matrix. If the supplied
      +                    value is less than or equal to zero, the number
      +                    of rows will be calculated when `numRows` is
      +                    invoked.
      +    :param numCols: Number of columns of this matrix. If the supplied
      +                    value is less than or equal to zero, the number
      +                    of columns will be calculated when `numCols` is
      +                    invoked.
      +    """
      +    def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0):
      +        """
      +        Note: This docstring is not shown publicly.
      +
      +        Create a wrapper over a Java BlockMatrix.
      +
      +        Publicly, we require that `blocks` be an RDD.  However, for
      +        internal usage, `blocks` can also be a Java BlockMatrix
      +        object, in which case we can wrap it directly.  This
      +        assists in clean matrix conversions.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +        >>> mat = BlockMatrix(blocks, 3, 2)
      +
      +        >>> mat_diff = BlockMatrix(blocks, 3, 2)
      +        >>> (mat_diff._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        False
      +
      +        >>> mat_same = BlockMatrix(mat._java_matrix_wrapper._java_model, 3, 2)
      +        >>> (mat_same._java_matrix_wrapper._java_model ==
      +        ...  mat._java_matrix_wrapper._java_model)
      +        True
      +        """
      +        if isinstance(blocks, RDD):
      +            blocks = blocks.map(_convert_to_matrix_block_tuple)
      +            # We use DataFrames for serialization of sub-matrix blocks
      +            # from Python, so first convert the RDD to a DataFrame on
      +            # this side. This will convert each sub-matrix block
      +            # tuple to a Row containing the 'blockRowIndex',
      +            # 'blockColIndex', and 'subMatrix' values, which can
      +            # each be easily serialized.  We will convert back to
      +            # ((blockRowIndex, blockColIndex), sub-matrix) tuples on
      +            # the Scala side.
      +            java_matrix = callMLlibFunc("createBlockMatrix", blocks.toDF(),
      +                                        int(rowsPerBlock), int(colsPerBlock),
      +                                        long(numRows), long(numCols))
      +        elif (isinstance(blocks, JavaObject)
      +              and blocks.getClass().getSimpleName() == "BlockMatrix"):
      +            java_matrix = blocks
      +        else:
      +            raise TypeError("blocks should be an RDD of sub-matrix blocks as "
      +                            "((int, int), matrix) tuples, got %s" % type(blocks))
      +
      +        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)
      +
      +    @property
      +    def blocks(self):
      +        """
      +        The RDD of sub-matrix blocks
      +        ((blockRowIndex, blockColIndex), sub-matrix) that form this
      +        distributed matrix.
      +
      +        >>> mat = BlockMatrix(
      +        ...     sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                     ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]), 3, 2)
      +        >>> blocks = mat.blocks
      +        >>> blocks.first()
      +        ((0, 0), DenseMatrix(3, 2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 0))
      +
      +        """
      +        # We use DataFrames for serialization of sub-matrix blocks
      +        # from Java, so we first convert the RDD of blocks to a
      +        # DataFrame on the Scala/Java side. Then we map each Row in
      +        # the DataFrame back to a sub-matrix block on this side.
      +        blocks_df = callMLlibFunc("getMatrixBlocks", self._java_matrix_wrapper._java_model)
      +        blocks = blocks_df.map(lambda row: ((row[0][0], row[0][1]), row[1]))
      +        return blocks
      +
      +    @property
      +    def rowsPerBlock(self):
      +        """
      +        Number of rows that make up each block.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +        >>> mat = BlockMatrix(blocks, 3, 2)
      +        >>> mat.rowsPerBlock
      +        3
      +        """
      +        return self._java_matrix_wrapper.call("rowsPerBlock")
      +
      +    @property
      +    def colsPerBlock(self):
      +        """
      +        Number of columns that make up each block.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +        >>> mat = BlockMatrix(blocks, 3, 2)
      +        >>> mat.colsPerBlock
      +        2
      +        """
      +        return self._java_matrix_wrapper.call("colsPerBlock")
      +
      +    @property
      +    def numRowBlocks(self):
      +        """
      +        Number of rows of blocks in the BlockMatrix.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +        >>> mat = BlockMatrix(blocks, 3, 2)
      +        >>> mat.numRowBlocks
      +        2
      +        """
      +        return self._java_matrix_wrapper.call("numRowBlocks")
      +
      +    @property
      +    def numColBlocks(self):
      +        """
      +        Number of columns of blocks in the BlockMatrix.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +        >>> mat = BlockMatrix(blocks, 3, 2)
      +        >>> mat.numColBlocks
      +        1
      +        """
      +        return self._java_matrix_wrapper.call("numColBlocks")
      +
      +    def numRows(self):
      +        """
      +        Get or compute the number of rows.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +
      +        >>> mat = BlockMatrix(blocks, 3, 2)
      +        >>> print(mat.numRows())
      +        6
      +
      +        >>> mat = BlockMatrix(blocks, 3, 2, 7, 6)
      +        >>> print(mat.numRows())
      +        7
      +        """
      +        return self._java_matrix_wrapper.call("numRows")
      +
      +    def numCols(self):
      +        """
      +        Get or compute the number of cols.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +
      +        >>> mat = BlockMatrix(blocks, 3, 2)
      +        >>> print(mat.numCols())
      +        2
      +
      +        >>> mat = BlockMatrix(blocks, 3, 2, 7, 6)
      +        >>> print(mat.numCols())
      +        6
      +        """
      +        return self._java_matrix_wrapper.call("numCols")
      +
      +    def toLocalMatrix(self):
      +        """
      +        Collect the distributed matrix on the driver as a DenseMatrix.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +        >>> mat = BlockMatrix(blocks, 3, 2).toLocalMatrix()
      +
      +        >>> # This BlockMatrix will have 6 effective rows, due to
      +        >>> # having two sub-matrix blocks stacked, each with 3 rows.
      +        >>> # The ensuing DenseMatrix will also have 6 rows.
      +        >>> print(mat.numRows)
      +        6
      +
      +        >>> # This BlockMatrix will have 2 effective columns, due to
      +        >>> # having two sub-matrix blocks stacked, each with 2
      +        >>> # columns. The ensuing DenseMatrix will also have 2 columns.
      +        >>> print(mat.numCols)
      +        2
      +        """
      +        return self._java_matrix_wrapper.call("toLocalMatrix")
      +
      +    def toIndexedRowMatrix(self):
      +        """
      +        Convert this matrix to an IndexedRowMatrix.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
      +        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
      +        >>> mat = BlockMatrix(blocks, 3, 2).toIndexedRowMatrix()
      +
      +        >>> # This BlockMatrix will have 6 effective rows, due to
      +        >>> # having two sub-matrix blocks stacked, each with 3 rows.
      +        >>> # The ensuing IndexedRowMatrix will also have 6 rows.
      +        >>> print(mat.numRows())
      +        6
      +
      +        >>> # This BlockMatrix will have 2 effective columns, due to
      +        >>> # having two sub-matrix blocks stacked, each with 2 columns.
      +        >>> # The ensuing IndexedRowMatrix will also have 2 columns.
      +        >>> print(mat.numCols())
      +        2
      +        """
      +        java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix")
      +        return IndexedRowMatrix(java_indexed_row_matrix)
      +
      +    def toCoordinateMatrix(self):
      +        """
      +        Convert this matrix to a CoordinateMatrix.
      +
      +        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(1, 2, [1, 2])),
      +        ...                          ((1, 0), Matrices.dense(1, 2, [7, 8]))])
      +        >>> mat = BlockMatrix(blocks, 1, 2).toCoordinateMatrix()
      +        >>> mat.entries.take(3)
      +        [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 2.0), MatrixEntry(1, 0, 7.0)]
      +        """
      +        java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix")
      +        return CoordinateMatrix(java_coordinate_matrix)
      +
      +
      +def _test():
      +    import doctest
      +    from pyspark import SparkContext
      +    from pyspark.sql import SQLContext
      +    from pyspark.mllib.linalg import Matrices
      +    import pyspark.mllib.linalg.distributed
      +    globs = pyspark.mllib.linalg.distributed.__dict__.copy()
      +    globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
      +    globs['sqlContext'] = SQLContext(globs['sc'])
      +    globs['Matrices'] = Matrices
      +    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
      +    globs['sc'].stop()
      +    if failure_count:
      +        exit(-1)
      +
      +if __name__ == "__main__":
      +    _test()
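The distributed-matrix additions above (CoordinateMatrix, BlockMatrix, and the conversions between them) compose naturally. The following minimal sketch ties them together; it assumes an already running SparkContext bound to `sc`, and the toy entries are illustrative only, not part of the patch:

```python
from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry

# A sparse matrix given as (row, col, value) entries; dimensions are
# inferred from the largest indices seen (6 and 4), i.e. 7 x 5.
entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(6, 4, 2.1)])
coord = CoordinateMatrix(entries)
print(coord.numRows())                 # 7
print(coord.numCols())                 # 5

# Re-block into sub-matrices holding at most 2 x 2 values each.
blocked = coord.toBlockMatrix(rowsPerBlock=2, colsPerBlock=2)
print(blocked.numRowBlocks)            # 4
print(blocked.numColBlocks)            # 3

# Round-trip back to coordinate form and inspect the entries.
print(blocked.toCoordinateMatrix().entries.collect())
```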
      diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
      index 06fbc0eb6aef..6a3c643b6641 100644
      --- a/python/pyspark/mllib/random.py
      +++ b/python/pyspark/mllib/random.py
      @@ -21,6 +21,7 @@
       
       from functools import wraps
       
      +from pyspark import since
       from pyspark.mllib.common import callMLlibFunc
       
       
      @@ -39,9 +40,12 @@ class RandomRDDs(object):
           """
           Generator methods for creating RDDs comprised of i.i.d samples from
           some distribution.
      +
      +    .. versionadded:: 1.1.0
           """
       
           @staticmethod
      +    @since("1.1.0")
           def uniformRDD(sc, size, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of i.i.d. samples from the
      @@ -72,6 +76,7 @@ def uniformRDD(sc, size, numPartitions=None, seed=None):
               return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed)
       
           @staticmethod
      +    @since("1.1.0")
           def normalRDD(sc, size, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of i.i.d. samples from the standard normal
      @@ -100,6 +105,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None):
               return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed)
       
           @staticmethod
      +    @since("1.3.0")
           def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of i.i.d. samples from the log normal
      @@ -132,6 +138,7 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None):
                                    size, numPartitions, seed)
       
           @staticmethod
      +    @since("1.1.0")
           def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of i.i.d. samples from the Poisson
      @@ -158,6 +165,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
               return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed)
       
           @staticmethod
      +    @since("1.3.0")
           def exponentialRDD(sc, mean, size, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of i.i.d. samples from the Exponential
      @@ -184,6 +192,7 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None):
               return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed)
       
           @staticmethod
      +    @since("1.3.0")
           def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of i.i.d. samples from the Gamma
      @@ -216,6 +225,7 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None):
       
           @staticmethod
           @toArray
      +    @since("1.1.0")
           def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of vectors containing i.i.d. samples drawn
      @@ -241,6 +251,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
       
           @staticmethod
           @toArray
      +    @since("1.1.0")
           def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of vectors containing i.i.d. samples drawn
      @@ -266,6 +277,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
       
           @staticmethod
           @toArray
      +    @since("1.3.0")
           def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of vectors containing i.i.d. samples drawn
      @@ -300,6 +312,7 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed
       
           @staticmethod
           @toArray
      +    @since("1.1.0")
           def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of vectors containing i.i.d. samples drawn
      @@ -330,6 +343,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
       
           @staticmethod
           @toArray
      +    @since("1.3.0")
           def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of vectors containing i.i.d. samples drawn
      @@ -360,6 +374,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No
       
           @staticmethod
           @toArray
      +    @since("1.3.0")
           def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None):
               """
               Generates an RDD comprised of vectors containing i.i.d. samples drawn
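An aside on the `@since(...)` annotations added throughout this file: they record the version in which each API appeared. As a rough, hypothetical illustration of the pattern (not the actual `pyspark.since` implementation, whose details may differ), such a decorator can simply append a Sphinx `.. versionadded::` directive to the wrapped callable's docstring:

```python
import re


def since(version):
    """Illustrative stand-in for a version-annotation decorator.

    Appends a Sphinx ``.. versionadded::`` directive to the wrapped
    callable's docstring so generated API docs record when it appeared.
    """
    def deco(f):
        indent = ""
        if f.__doc__:
            # Reuse the docstring's existing indentation for the note.
            match = re.search(r"\n( +)", f.__doc__)
            if match:
                indent = match.group(1)
        f.__doc__ = (f.__doc__ or "").rstrip() + \
            "\n\n%s.. versionadded:: %s\n" % (indent, version)
        return f
    return deco
```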
      diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
      index 9c4647ddfdcf..95047b5b7b4b 100644
      --- a/python/pyspark/mllib/recommendation.py
      +++ b/python/pyspark/mllib/recommendation.py
      @@ -18,7 +18,7 @@
       import array
       from collections import namedtuple
       
      -from pyspark import SparkContext
      +from pyspark import SparkContext, since
       from pyspark.rdd import RDD
       from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
       from pyspark.mllib.util import JavaLoader, JavaSaveable
      @@ -36,6 +36,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])):
           (1, 2, 5.0)
           >>> (r[0], r[1], r[2])
           (1, 2, 5.0)
      +
      +    .. versionadded:: 1.2.0
           """
       
           def __reduce__(self):
      @@ -106,17 +108,22 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
           0.4...
           >>> sameModel.predictAll(testset).collect()
           [Rating(...
      +    >>> from shutil import rmtree
           >>> try:
      -    ...     os.removedirs(path)
      +    ...     rmtree(path)
           ... except OSError:
           ...     pass
      +
      +    .. versionadded:: 0.9.0
           """
      +    @since("0.9.0")
           def predict(self, user, product):
               """
               Predicts rating for the given user and product.
               """
               return self._java_model.predict(int(user), int(product))
       
      +    @since("0.9.0")
           def predictAll(self, user_product):
               """
               Returns a list of predicted ratings for input user and product pairs.
      @@ -127,6 +134,7 @@ def predictAll(self, user_product):
               user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1])))
               return self.call("predict", user_product)
       
      +    @since("1.2.0")
           def userFeatures(self):
               """
               Returns a paired RDD, where the first element is the user and the
      @@ -134,6 +142,7 @@ def userFeatures(self):
               """
               return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v))
       
      +    @since("1.2.0")
           def productFeatures(self):
               """
               Returns a paired RDD, where the first element is the product and the
      @@ -141,6 +150,7 @@ def productFeatures(self):
               """
               return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v))
       
      +    @since("1.4.0")
           def recommendUsers(self, product, num):
               """
               Recommends the top "num" number of users for a given product and returns a list
      @@ -148,6 +158,7 @@ def recommendUsers(self, product, num):
               """
               return list(self.call("recommendUsers", product, num))
       
      +    @since("1.4.0")
           def recommendProducts(self, user, num):
               """
               Recommends the top "num" number of products for a given user and returns a list
      @@ -156,17 +167,25 @@ def recommendProducts(self, user, num):
               return list(self.call("recommendProducts", user, num))
       
           @property
      +    @since("1.4.0")
           def rank(self):
      +        """Rank for the features in this model"""
               return self.call("rank")
       
           @classmethod
      +    @since("1.3.1")
           def load(cls, sc, path):
      +        """Load a model from the given path"""
               model = cls._load_java(sc, path)
               wrapper = sc._jvm.MatrixFactorizationModelWrapper(model)
               return MatrixFactorizationModel(wrapper)
       
       
       class ALS(object):
      +    """Alternating Least Squares matrix factorization
      +
      +    .. versionadded:: 0.9.0
      +    """
       
           @classmethod
           def _prepare(cls, ratings):
      @@ -187,15 +206,31 @@ def _prepare(cls, ratings):
               return ratings
       
           @classmethod
      +    @since("0.9.0")
           def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False,
                     seed=None):
      +        """
+        Train a matrix factorization model from an RDD of ratings given by users to some products,
      +        in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
      +        product of two lower-rank matrices of a given rank (number of features). To solve for these
      +        features, we run a given number of iterations of ALS. This is done using a level of
      +        parallelism given by `blocks`.
      +        """
               model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations,
                                     lambda_, blocks, nonnegative, seed)
               return MatrixFactorizationModel(model)
       
           @classmethod
      +    @since("0.9.0")
           def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01,
                             nonnegative=False, seed=None):
      +        """
+        Train a matrix factorization model from an RDD of 'implicit preferences' given by users
      +        to some products, in the form of (userID, productID, preference) pairs. We approximate the
      +        ratings matrix as the product of two lower-rank matrices of a given rank (number of
      +        features).  To solve for these features, we run a given number of iterations of ALS.
      +        This is done using a level of parallelism given by `blocks`.
      +        """
               model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank,
                                     iterations, lambda_, blocks, alpha, nonnegative, seed)
               return MatrixFactorizationModel(model)
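The ALS additions above document `train`/`trainImplicit` together with the model's predict and recommend methods. A short usage sketch, assuming a SparkContext `sc` and made-up toy ratings (illustrative only):

```python
from pyspark.mllib.recommendation import ALS, Rating

# Toy (user, product, rating) data; values invented for illustration.
ratings = sc.parallelize([
    Rating(1, 1, 5.0), Rating(1, 2, 1.0),
    Rating(2, 1, 4.0), Rating(2, 2, 2.0),
])

# Factorize the user x product ratings matrix with rank-2 features.
model = ALS.train(ratings, rank=2, iterations=10, seed=42)

print(model.predict(1, 2))                                   # single pair
print(model.predictAll(sc.parallelize([(1, 1), (2, 2)])).collect())
print(model.recommendProducts(1, 1))                         # top product for user 1
```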
      diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
      index 0c4d7d3bbee0..256b7537fef6 100644
      --- a/python/pyspark/mllib/regression.py
      +++ b/python/pyspark/mllib/regression.py
      @@ -19,6 +19,7 @@
       from numpy import array
       
       from pyspark import RDD
      +from pyspark.streaming.dstream import DStream
       from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
       from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector
       from pyspark.mllib.util import Saveable, Loader
      @@ -27,7 +28,8 @@
                  'LinearRegressionModel', 'LinearRegressionWithSGD',
                  'RidgeRegressionModel', 'RidgeRegressionWithSGD',
                  'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel',
      -           'IsotonicRegression']
      +           'IsotonicRegression', 'StreamingLinearAlgorithm',
      +           'StreamingLinearRegressionWithSGD']
       
       
       class LabeledPoint(object):
      @@ -96,9 +98,11 @@ class LinearRegressionModelBase(LinearModel):
       
           def predict(self, x):
               """
      -        Predict the value of the dependent variable given a vector x
      -        containing values for the independent variables.
      +        Predict the value of the dependent variable given a vector or
      +        an RDD of vectors containing values for the independent variables.
               """
      +        if isinstance(x, RDD):
      +            return x.map(self.predict)
               x = _convert_to_vector(x)
               return self.weights.dot(x) + self.intercept
       
      @@ -123,6 +127,8 @@ class LinearRegressionModel(LinearRegressionModelBase):
           True
           >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
           True
      +    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
      +    True
           >>> import os, tempfile
           >>> path = tempfile.mkdtemp()
           >>> lrm.save(sc, path)
      @@ -133,10 +139,11 @@ class LinearRegressionModel(LinearRegressionModelBase):
           True
           >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
           True
      +    >>> from shutil import rmtree
           >>> try:
      -    ...    os.removedirs(path)
      +    ...     rmtree(path)
           ... except:
      -    ...    pass
      +    ...     pass
           >>> data = [
           ...     LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
           ...     LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
      @@ -196,13 +203,15 @@ class LinearRegressionWithSGD(object):
           @classmethod
           def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
                     initialWeights=None, regParam=0.0, regType=None, intercept=False,
      -              validateData=True):
      +              validateData=True, convergenceTol=0.001):
               """
               Train a linear regression model using Stochastic Gradient
               Descent (SGD).
               This solves the least squares regression formulation
      -                f(weights) = 1/n ||A weights-y||^2^
      -        (which is the mean squared error).
      +
      +            f(weights) = 1/(2n) ||A weights - y||^2,
      +
+        which is half the mean squared error.
               Here the data matrix has n rows, and the input RDD holds the
               set of rows of A, each with its corresponding right hand side
               label y. See also the documentation for the precise formulation.
      @@ -236,11 +245,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
               :param validateData:      Boolean parameter which indicates if
                                         the algorithm should validate data
                                         before training. (default: True)
+        :param convergenceTol:    Convergence tolerance used to decide when
+                                  to stop iterating early. (default: 0.001)
               """
               def train(rdd, i):
                   return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
                                        float(step), float(miniBatchFraction), i, float(regParam),
      -                                 regType, bool(intercept), bool(validateData))
      +                                 regType, bool(intercept), bool(validateData),
      +                                 float(convergenceTol))
       
               return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights)
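For context on the new `convergenceTol` parameter documented in the hunk above: MLlib's mini-batch gradient descent stops early once the change in the weight vector between iterations is small relative to the weights themselves. A rough Python sketch of that check (an approximation of the Scala-side logic, not code from this patch):

```python
import numpy as np


def has_converged(prev_weights, curr_weights, convergence_tol=0.001):
    # Stop iterating once the update to the weight vector is small
    # relative to the size of the current weights (or to 1.0 for
    # near-zero weights).
    prev = np.asarray(prev_weights, dtype=float)
    curr = np.asarray(curr_weights, dtype=float)
    diff = np.linalg.norm(curr - prev)
    return diff < convergence_tol * max(np.linalg.norm(curr), 1.0)
```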
       
      @@ -265,6 +277,8 @@ class LassoModel(LinearRegressionModelBase):
           True
           >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
           True
      +    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
      +    True
           >>> import os, tempfile
           >>> path = tempfile.mkdtemp()
           >>> lrm.save(sc, path)
      @@ -275,8 +289,9 @@ class LassoModel(LinearRegressionModelBase):
           True
           >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
           True
      +    >>> from shutil import rmtree
           >>> try:
      -    ...    os.removedirs(path)
      +    ...    rmtree(path)
           ... except:
           ...    pass
           >>> data = [
      @@ -319,13 +334,15 @@ class LassoWithSGD(object):
           @classmethod
           def train(cls, data, iterations=100, step=1.0, regParam=0.01,
                     miniBatchFraction=1.0, initialWeights=None, intercept=False,
      -              validateData=True):
      +              validateData=True, convergenceTol=0.001):
               """
               Train a regression model with L1-regularization using
               Stochastic Gradient Descent.
               This solves the l1-regularized least squares regression
               formulation
      -            f(weights) = 1/2n ||A weights-y||^2^  + regParam ||weights||_1
      +
      +            f(weights) = 1/(2n) ||A weights - y||^2  + regParam ||weights||_1.
      +
               Here the data matrix has n rows, and the input RDD holds the
               set of rows of A, each with its corresponding right hand side
               label y. See also the documentation for the precise formulation.
      @@ -349,11 +366,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
               :param validateData:      Boolean parameter which indicates if
                                         the algorithm should validate data
                                         before training. (default: True)
+        :param convergenceTol:    Convergence tolerance used to decide when
+                                  to stop iterating early. (default: 0.001)
               """
               def train(rdd, i):
                   return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
                                        float(regParam), float(miniBatchFraction), i, bool(intercept),
      -                                 bool(validateData))
      +                                 bool(validateData), float(convergenceTol))
       
               return _regression_train_wrapper(train, LassoModel, data, initialWeights)
       
      @@ -379,6 +398,8 @@ class RidgeRegressionModel(LinearRegressionModelBase):
           True
           >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
           True
      +    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
      +    True
           >>> import os, tempfile
           >>> path = tempfile.mkdtemp()
           >>> lrm.save(sc, path)
      @@ -389,8 +410,9 @@ class RidgeRegressionModel(LinearRegressionModelBase):
           True
           >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
           True
      +    >>> from shutil import rmtree
           >>> try:
      -    ...    os.removedirs(path)
      +    ...    rmtree(path)
           ... except:
           ...    pass
           >>> data = [
      @@ -433,13 +455,15 @@ class RidgeRegressionWithSGD(object):
           @classmethod
           def train(cls, data, iterations=100, step=1.0, regParam=0.01,
                     miniBatchFraction=1.0, initialWeights=None, intercept=False,
      -              validateData=True):
      +              validateData=True, convergenceTol=0.001):
               """
               Train a regression model with L2-regularization using
               Stochastic Gradient Descent.
               This solves the l2-regularized least squares regression
               formulation
      -            f(weights) = 1/2n ||A weights-y||^2^  + regParam/2 ||weights||^2^
      +
      +            f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2.
      +
               Here the data matrix has n rows, and the input RDD holds the
               set of rows of A, each with its corresponding right hand side
               label y. See also the documentation for the precise formulation.
      @@ -463,11 +487,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
               :param validateData:      Boolean parameter which indicates if
                                         the algorithm should validate data
                                         before training. (default: True)
+        :param convergenceTol:    Convergence tolerance used to decide when
+                                  to stop iterating early. (default: 0.001)
               """
               def train(rdd, i):
                   return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
                                        float(regParam), float(miniBatchFraction), i, bool(intercept),
      -                                 bool(validateData))
      +                                 bool(validateData), float(convergenceTol))
       
               return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights)
       
      @@ -500,8 +526,9 @@ class IsotonicRegressionModel(Saveable, Loader):
           2.0
           >>> sameModel.predict(5)
           16.5
      +    >>> from shutil import rmtree
           >>> try:
      -    ...     os.removedirs(path)
      +    ...     rmtree(path)
           ... except OSError:
           ...     pass
           """
      @@ -566,6 +593,97 @@ def train(cls, data, isotonic=True):
               return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)
       
       
      +class StreamingLinearAlgorithm(object):
      +    """
+    Base class for streaming algorithms built on linear models;
+    concrete streaming algorithms should inherit from it.
+
+    It provides predictOn and predictOnValues so that subclasses do
+    not need to reimplement them.
      +    """
      +    def __init__(self, model):
      +        self._model = model
      +
      +    def latestModel(self):
      +        """
      +        Returns the latest model.
      +        """
      +        return self._model
      +
      +    def _validate(self, dstream):
      +        if not isinstance(dstream, DStream):
      +            raise TypeError(
      +                "dstream should be a DStream object, got %s" % type(dstream))
      +        if not self._model:
      +            raise ValueError(
      +                "Model must be intialized using setInitialWeights")
      +
      +    def predictOn(self, dstream):
      +        """
      +        Make predictions on a dstream.
      +
      +        :return: Transformed dstream object.
      +        """
      +        self._validate(dstream)
      +        return dstream.map(lambda x: self._model.predict(x))
      +
      +    def predictOnValues(self, dstream):
      +        """
      +        Make predictions on a keyed dstream.
      +
      +        :return: Transformed dstream object.
      +        """
      +        self._validate(dstream)
      +        return dstream.mapValues(lambda x: self._model.predict(x))
      +
      +
      +@inherit_doc
      +class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm):
      +    """
      +    Run LinearRegression with SGD on a batch of data.
      +
      +    The problem minimized is (1 / n_samples) * (y - weights'X)**2.
      +    After training on a batch of data, the weights obtained at the end of
      +    training are used as initial weights for the next batch.
      +
      +    :param stepSize: Step size for each iteration of gradient descent.
      +    :param numIterations: Total number of iterations run.
      +    :param miniBatchFraction: Fraction of data on which SGD is run for each
      +                              iteration.
+    :param convergenceTol: Convergence tolerance used to decide when to
+                           stop iterating on each batch of data.
      +    """
      +    def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001):
      +        self.stepSize = stepSize
      +        self.numIterations = numIterations
      +        self.miniBatchFraction = miniBatchFraction
      +        self.convergenceTol = convergenceTol
      +        self._model = None
      +        super(StreamingLinearRegressionWithSGD, self).__init__(
      +            model=self._model)
      +
      +    def setInitialWeights(self, initialWeights):
      +        """
      +        Set the initial value of weights.
      +
      +        This must be set before running trainOn and predictOn
      +        """
      +        initialWeights = _convert_to_vector(initialWeights)
      +        self._model = LinearRegressionModel(initialWeights, 0)
      +        return self
      +
      +    def trainOn(self, dstream):
      +        """Train the model on the incoming dstream."""
      +        self._validate(dstream)
      +
      +        def update(rdd):
      +            # LinearRegressionWithSGD.train raises an error for an empty RDD.
      +            if not rdd.isEmpty():
+                self._model = LinearRegressionWithSGD.train(
+                    rdd, self.numIterations, self.stepSize,
+                    self.miniBatchFraction, self._model.weights,
+                    intercept=self._model.intercept,
+                    convergenceTol=self.convergenceTol)
      +
      +        dstream.foreachRDD(update)
      +
      +
       def _test():
           import doctest
           from pyspark import SparkContext
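The StreamingLinearRegressionWithSGD class added above is driven entirely by DStreams: the weights must be seeded with setInitialWeights, trainOn updates the model per batch, and predictOn/predictOnValues transform an input stream. A hedged end-to-end sketch using a queue stream (toy data; assumes an existing SparkContext `sc` and is not part of the patch):

```python
import time

from pyspark.streaming import StreamingContext
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD

ssc = StreamingContext(sc, 1)   # 1-second batches

# Two tiny training batches of y = 2 * x points, plus one test batch.
train_batches = [
    sc.parallelize([LabeledPoint(2.0, [1.0]), LabeledPoint(4.0, [2.0])]),
    sc.parallelize([LabeledPoint(6.0, [3.0]), LabeledPoint(8.0, [4.0])]),
]
test_batches = [sc.parallelize([[5.0]])]

slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
slr.setInitialWeights([0.0])              # required before trainOn/predictOn

slr.trainOn(ssc.queueStream(train_batches))
slr.predictOn(ssc.queueStream(test_batches)).pprint()

ssc.start()
time.sleep(10)                            # let a few batches run
ssc.stop(stopSparkContext=False)
print(slr.latestModel().weights)          # should move toward [2.0]
```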
      diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
      index b475be4b4d95..36c8f48a4a88 100644
      --- a/python/pyspark/mllib/stat/_statistics.py
      +++ b/python/pyspark/mllib/stat/_statistics.py
      @@ -15,11 +15,15 @@
       # limitations under the License.
       #
       
      +import sys
      +if sys.version >= '3':
      +    basestring = str
      +
       from pyspark.rdd import RDD, ignore_unicode_prefix
       from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
       from pyspark.mllib.linalg import Matrix, _convert_to_vector
       from pyspark.mllib.regression import LabeledPoint
      -from pyspark.mllib.stat.test import ChiSqTestResult
      +from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult
       
       
       __all__ = ['MultivariateStatisticalSummary', 'Statistics']
      @@ -238,6 +242,67 @@ def chiSqTest(observed, expected=None):
                   jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected)
               return ChiSqTestResult(jmodel)
       
      +    @staticmethod
      +    @ignore_unicode_prefix
      +    def kolmogorovSmirnovTest(data, distName="norm", *params):
      +        """
      +        .. note:: Experimental
      +
      +        Performs the Kolmogorov-Smirnov (KS) test for data sampled from
      +        a continuous distribution. It tests the null hypothesis that
      +        the data is generated from a particular distribution.
      +
+        The given data is sorted and the Empirical Cumulative
+        Distribution Function (ECDF) is calculated; for a given point,
+        the ECDF is the number of sample points less than or equal to
+        it divided by the total number of points.
+
+        Since the data is sorted, this is a step function
+        that rises by (1 / length of data) at every ordered point.
      +
      +        The KS statistic gives us the maximum distance between the
      +        ECDF and the CDF. Intuitively if this statistic is large, the
+        probability that the null hypothesis is true becomes small.
      +        For specific details of the implementation, please have a look
      +        at the Scala documentation.
      +
      +        :param data: RDD, samples from the data
+        :param distName: string, the name of the theoretical distribution
+                         to test the data against; currently only "norm"
+                         (the normal distribution) is supported.
      +        :param params: additional values which need to be provided for
      +                       a certain distribution.
      +                       If not provided, the default values are used.
+        :return: KolmogorovSmirnovTestResult object containing the test
+                 statistic, the p-value, the method used, and the
+                 null hypothesis.
      +
      +        >>> kstest = Statistics.kolmogorovSmirnovTest
      +        >>> data = sc.parallelize([-1.0, 0.0, 1.0])
      +        >>> ksmodel = kstest(data, "norm")
      +        >>> print(round(ksmodel.pValue, 3))
      +        1.0
      +        >>> print(round(ksmodel.statistic, 3))
      +        0.175
      +        >>> ksmodel.nullHypothesis
      +        u'Sample follows theoretical distribution'
      +
      +        >>> data = sc.parallelize([2.0, 3.0, 4.0])
      +        >>> ksmodel = kstest(data, "norm", 3.0, 1.0)
      +        >>> print(round(ksmodel.pValue, 3))
      +        1.0
      +        >>> print(round(ksmodel.statistic, 3))
      +        0.175
      +        """
      +        if not isinstance(data, RDD):
      +            raise TypeError("data should be an RDD, got %s." % type(data))
      +        if not isinstance(distName, basestring):
      +            raise TypeError("distName should be a string, got %s." % type(distName))
      +
      +        params = [float(param) for param in params]
      +        return KolmogorovSmirnovTestResult(
      +            callMLlibFunc("kolmogorovSmirnovTest", data, distName, params))
      +
       
       def _test():
           import doctest
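To connect the ECDF description above with the reported statistic, the following standalone numpy/scipy sketch recomputes the KS statistic for the three-point doctest sample [-1.0, 0.0, 1.0] against the standard normal CDF. scipy is used here only for the normal CDF; none of this is part of the patch:

```python
import numpy as np
from scipy.stats import norm

data = np.sort(np.array([-1.0, 0.0, 1.0]))
n = len(data)

cdf = norm.cdf(data)                          # theoretical CDF at each sample
ecdf_below = np.arange(0, n) / float(n)       # ECDF just below each sample
ecdf_at = np.arange(1, n + 1) / float(n)      # ECDF at/after each sample

# KS statistic: the largest gap between the ECDF step function and the CDF.
ks_stat = max(np.max(ecdf_at - cdf), np.max(cdf - ecdf_below))
print(round(ks_stat, 3))   # 0.175, matching the doctest above
```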
      diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py
      index 762506e952b4..0abe104049ff 100644
      --- a/python/pyspark/mllib/stat/test.py
      +++ b/python/pyspark/mllib/stat/test.py
      @@ -15,24 +15,16 @@
       # limitations under the License.
       #
       
      -from pyspark.mllib.common import JavaModelWrapper
      +from pyspark.mllib.common import inherit_doc, JavaModelWrapper
       
       
      -__all__ = ["ChiSqTestResult"]
      +__all__ = ["ChiSqTestResult", "KolmogorovSmirnovTestResult"]
       
       
      -class ChiSqTestResult(JavaModelWrapper):
      +class TestResult(JavaModelWrapper):
           """
      -    .. note:: Experimental
      -
      -    Object containing the test results for the chi-squared hypothesis test.
      +    Base class for all test results.
           """
      -    @property
      -    def method(self):
      -        """
      -        Name of the test method
      -        """
      -        return self._java_model.method()
       
           @property
           def pValue(self):
      @@ -67,3 +59,24 @@ def nullHypothesis(self):
       
           def __str__(self):
               return self._java_model.toString()
      +
      +
      +@inherit_doc
      +class ChiSqTestResult(TestResult):
      +    """
      +    Contains test results for the chi-squared hypothesis test.
      +    """
      +
      +    @property
      +    def method(self):
      +        """
      +        Name of the test method
      +        """
      +        return self._java_model.method()
      +
      +
      +@inherit_doc
      +class KolmogorovSmirnovTestResult(TestResult):
      +    """
      +    Contains test results for the Kolmogorov-Smirnov test.
      +    """
      diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
      index 744dc112d920..636f9a06cab7 100644
      --- a/python/pyspark/mllib/tests.py
      +++ b/python/pyspark/mllib/tests.py
      @@ -24,11 +24,17 @@
       import tempfile
       import array as pyarray
       from time import time, sleep
      +from shutil import rmtree
       
      -from numpy import array, array_equal, zeros, inf, all, random
      +from numpy import (
      +    array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
       from numpy import sum as array_sum
      +
       from py4j.protocol import Py4JJavaError
       
      +if sys.version > '3':
      +    basestring = str
      +
       if sys.version_info[:2] <= (2, 6):
           try:
               import unittest2 as unittest
      @@ -43,16 +49,19 @@
       from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
       from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
           DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
      -from pyspark.mllib.regression import LabeledPoint
      +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
      +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
       from pyspark.mllib.random import RandomRDDs
       from pyspark.mllib.stat import Statistics
       from pyspark.mllib.feature import Word2Vec
       from pyspark.mllib.feature import IDF
      -from pyspark.mllib.feature import StandardScaler
      -from pyspark.mllib.feature import ElementwiseProduct
      +from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
      +from pyspark.mllib.util import LinearDataGenerator
      +from pyspark.mllib.util import MLUtils
       from pyspark.serializers import PickleSerializer
       from pyspark.streaming import StreamingContext
       from pyspark.sql import SQLContext
       
       _have_scipy = False
       try:
      @@ -80,9 +89,42 @@ def tearDown(self):
               self.ssc.stop(False)
       
           @staticmethod
      -    def _ssc_wait(start_time, end_time, sleep_time):
      -        while time() - start_time < end_time:
      +    def _eventually(condition, timeout=30.0, catch_assertions=False):
      +        """
      +        Wait a given amount of time for a condition to pass, else fail with an error.
      +        This is a helper utility for streaming ML tests.
      +        :param condition: Function that checks for termination conditions.
      +                          condition() can return:
      +                           - True: Conditions met. Return without error.
      +                           - other value: Conditions not met yet. Continue. Upon timeout,
      +                                          include last such value in error message.
      +                          Note that this method may be called at any time during
      +                          streaming execution (e.g., even before any results
      +                          have been created).
      +        :param timeout: Number of seconds to wait.  Default 30 seconds.
      +        :param catch_assertions: If False (default), do not catch AssertionErrors.
      +                                 If True, catch AssertionErrors; continue, but save
      +                                 error to throw upon timeout.
      +        """
      +        start_time = time()
      +        lastValue = None
      +        while time() - start_time < timeout:
      +            if catch_assertions:
      +                try:
      +                    lastValue = condition()
      +                except AssertionError as e:
      +                    lastValue = e
      +            else:
      +                lastValue = condition()
      +            if lastValue is True:
      +                return
                   sleep(0.01)
      +        if isinstance(lastValue, AssertionError):
      +            raise lastValue
      +        else:
      +            raise AssertionError(
      +                "Test failed due to timeout after %g sec, with last condition returning: %s"
      +                % (timeout, lastValue))
       
       
       def _squared_distance(a, b):
      @@ -123,17 +165,22 @@ def test_dot(self):
                            [1., 2., 3., 4.],
                            [1., 2., 3., 4.],
                            [1., 2., 3., 4.]])
      +        arr = pyarray.array('d', [0, 1, 2, 3])
               self.assertEquals(10.0, sv.dot(dv))
               self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
               self.assertEquals(30.0, dv.dot(dv))
               self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
               self.assertEquals(30.0, lst.dot(dv))
               self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
      +        self.assertEquals(7.0, sv.dot(arr))
       
           def test_squared_distance(self):
               sv = SparseVector(4, {1: 1, 3: 2})
               dv = DenseVector(array([1., 2., 3., 4.]))
               lst = DenseVector([4, 3, 2, 1])
      +        lst1 = [4, 3, 2, 1]
      +        arr = pyarray.array('d', [0, 2, 1, 3])
      +        narr = array([0, 2, 1, 3])
               self.assertEquals(15.0, _squared_distance(sv, dv))
               self.assertEquals(25.0, _squared_distance(sv, lst))
               self.assertEquals(20.0, _squared_distance(dv, lst))
      @@ -143,6 +190,41 @@ def test_squared_distance(self):
               self.assertEquals(0.0, _squared_distance(sv, sv))
               self.assertEquals(0.0, _squared_distance(dv, dv))
               self.assertEquals(0.0, _squared_distance(lst, lst))
      +        self.assertEquals(25.0, _squared_distance(sv, lst1))
      +        self.assertEquals(3.0, _squared_distance(sv, arr))
      +        self.assertEquals(3.0, _squared_distance(sv, narr))
      +
      +    def test_hash(self):
      +        v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
      +        v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
      +        v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
      +        v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
      +        self.assertEquals(hash(v1), hash(v2))
      +        self.assertEquals(hash(v1), hash(v3))
      +        self.assertEquals(hash(v2), hash(v3))
      +        self.assertFalse(hash(v1) == hash(v4))
      +        self.assertFalse(hash(v2) == hash(v4))
      +
      +    def test_eq(self):
      +        v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
      +        v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
      +        v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
      +        v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
      +        v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
      +        v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
      +        self.assertEquals(v1, v2)
      +        self.assertEquals(v1, v3)
      +        self.assertFalse(v2 == v4)
      +        self.assertFalse(v1 == v5)
      +        self.assertFalse(v1 == v6)
      +
      +    def test_equals(self):
      +        indices = [1, 2, 4]
      +        values = [1., 3., 2.]
      +        self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
      +        self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
      +        self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
      +        self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
       
           def test_conversion(self):
               # numpy arrays should be automatically upcast to float64
      @@ -175,6 +257,53 @@ def test_matrix_indexing(self):
                   for j in range(2):
                       self.assertEquals(mat[i, j], expected[i][j])
       
      +    def test_repr_dense_matrix(self):
      +        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
      +        self.assertTrue(
      +            repr(mat),
      +            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
      +
      +        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
      +        self.assertTrue(
      +            repr(mat),
      +            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
      +
      +        mat = DenseMatrix(6, 3, zeros(18))
      +        self.assertTrue(
      +            repr(mat),
      +            'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
      +                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
      +
      +    def test_repr_sparse_matrix(self):
      +        sm1t = SparseMatrix(
      +            3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
      +            isTransposed=True)
      +        self.assertTrue(
      +            repr(sm1t),
      +            'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
      +
      +        indices = tile(arange(6), 3)
      +        values = ones(18)
      +        sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
      +        self.assertTrue(
      +            repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
      +                [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
      +                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
      +                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
      +
      +        self.assertTrue(
      +            str(sm),
      +            "6 X 3 CSCMatrix\n\
      +            (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
      +            (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
      +            (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
      +
      +        sm = SparseMatrix(1, 18, zeros(19), [], [])
      +        self.assertTrue(
      +            repr(sm),
      +            'SparseMatrix(1, 18, \
      +                [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
      +
           def test_sparse_matrix(self):
               # Test sparse matrix creation.
               sm1 = SparseMatrix(
      @@ -184,6 +313,9 @@ def test_sparse_matrix(self):
               self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
               self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2])
               self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
      +        self.assertTrue(
      +            repr(sm1),
      +            'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
       
               # Test indexing
               expected = [
      @@ -398,7 +530,7 @@ def test_classification(self):
               self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
       
               try:
      -            os.removedirs(temp_dir)
      +            rmtree(temp_dir)
               except OSError:
                   pass
       
      @@ -462,6 +594,13 @@ def test_regression(self):
               except ValueError:
                   self.fail()
       
      +        # Verify that maxBins is being passed through
      +        GradientBoostedTrees.trainRegressor(
      +            rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32)
      +        with self.assertRaises(Exception) as cm:
      +            GradientBoostedTrees.trainRegressor(
      +                rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1)
      +
       
       class StatTests(MLlibTestCase):
           # SPARK-4023
      @@ -798,6 +937,25 @@ def test_right_number_of_results(self):
               self.assertIsNotNone(chi[1000])
       
       
      +class KolmogorovSmirnovTest(MLlibTestCase):
      +
      +    def test_R_implementation_equivalence(self):
      +        data = self.sc.parallelize([
      +            1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
      +            -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
      +            -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
      +            -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
      +            0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
      +        ])
      +        model = Statistics.kolmogorovSmirnovTest(data, "norm")
      +        self.assertAlmostEqual(model.statistic, 0.189, 3)
      +        self.assertAlmostEqual(model.pValue, 0.422, 3)
      +
      +        model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1)
      +        self.assertAlmostEqual(model.statistic, 0.189, 3)
      +        self.assertAlmostEqual(model.pValue, 0.422, 3)
      +
      +
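
The reference statistic (0.189) and p-value (0.422) above were computed with R's ks.test. As a rough cross-check of what Statistics.kolmogorovSmirnovTest(data, "norm") asserts, the same one-sample test against a standard normal can be run with SciPy; this is only a sketch and assumes SciPy is installed (it is not a Spark dependency):

    import numpy as np
    from scipy import stats

    # Any 1-D sample will do; here we draw 20 standard-normal values.
    sample = np.random.RandomState(7).randn(20)

    # One-sample Kolmogorov-Smirnov test against the standard normal CDF,
    # i.e. the same null hypothesis as kolmogorovSmirnovTest(data, "norm").
    statistic, p_value = stats.kstest(sample, "norm")
    print(statistic, p_value)
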
       class SerDeTest(MLlibTestCase):
           def test_to_java_object_rdd(self):  # SPARK-6660
               data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
      @@ -909,10 +1067,13 @@ def test_accuracy_for_single_center(self):
                   [self.sc.parallelize(batch, 1) for batch in batches])
               stkm.trainOn(input_stream)
       
      -        t = time()
               self.ssc.start()
      -        self._ssc_wait(t, 10.0, 0.01)
      -        self.assertEquals(stkm.latestModel().clusterWeights, [25.0])
      +
      +        def condition():
      +            self.assertEquals(stkm.latestModel().clusterWeights, [25.0])
      +            return True
      +        self._eventually(condition, catch_assertions=True)
      +
               realCenters = array_sum(array(centers), axis=0)
               for i in range(5):
                   modelCenters = stkm.latestModel().centers[0][i]
      @@ -937,7 +1098,7 @@ def test_trainOn_model(self):
               stkm.setInitialCenters(
                   centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0])
       
      -        # Create a toy dataset by setting a tiny offest for each point.
      +        # Create a toy dataset by setting a tiny offset for each point.
               offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]]
               batches = []
               for offset in offsets:
      @@ -947,14 +1108,15 @@ def test_trainOn_model(self):
               batches = [self.sc.parallelize(batch, 1) for batch in batches]
               input_stream = self.ssc.queueStream(batches)
               stkm.trainOn(input_stream)
      -        t = time()
               self.ssc.start()
       
               # Give enough time to train the model.
      -        self._ssc_wait(t, 6.0, 0.01)
      -        finalModel = stkm.latestModel()
      -        self.assertTrue(all(finalModel.centers == array(initCenters)))
      -        self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
      +        def condition():
      +            finalModel = stkm.latestModel()
      +            self.assertTrue(all(finalModel.centers == array(initCenters)))
      +            self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
      +            return True
      +        self._eventually(condition, catch_assertions=True)
       
           def test_predictOn_model(self):
               """Test that the model predicts correctly on toy data."""
      @@ -976,10 +1138,13 @@ def update(rdd):
                       result.append(rdd_collect)
       
               predict_val.foreachRDD(update)
      -        t = time()
               self.ssc.start()
      -        self._ssc_wait(t, 6.0, 0.01)
      -        self.assertEquals(result, [[0], [1], [2], [3]])
      +
      +        def condition():
      +            self.assertEquals(result, [[0], [1], [2], [3]])
      +            return True
      +
      +        self._eventually(condition, catch_assertions=True)
       
           def test_trainOn_predictOn(self):
               """Test that prediction happens on the updated model."""
      @@ -1005,10 +1170,357 @@ def collect(rdd):
               predict_stream = stkm.predictOn(input_stream)
               predict_stream.foreachRDD(collect)
       
      -        t = time()
               self.ssc.start()
      -        self._ssc_wait(t, 6.0, 0.01)
      -        self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
      +
      +        def condition():
      +            self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
      +            return True
      +
      +        self._eventually(condition, catch_assertions=True)
      +
      +
      +class LinearDataGeneratorTests(MLlibTestCase):
      +    def test_dim(self):
      +        linear_data = LinearDataGenerator.generateLinearInput(
      +            intercept=0.0, weights=[0.0, 0.0, 0.0],
      +            xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
      +            nPoints=4, seed=0, eps=0.1)
      +        self.assertEqual(len(linear_data), 4)
      +        for point in linear_data:
      +            self.assertEqual(len(point.features), 3)
      +
      +        linear_data = LinearDataGenerator.generateLinearRDD(
      +            sc=sc, nexamples=6, nfeatures=2, eps=0.1,
      +            nParts=2, intercept=0.0).collect()
      +        self.assertEqual(len(linear_data), 6)
      +        for point in linear_data:
      +            self.assertEqual(len(point.features), 2)
      +
      +
      +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
      +
      +    @staticmethod
      +    def generateLogisticInput(offset, scale, nPoints, seed):
      +        """
      +        Generate 1 / (1 + exp(-x * scale + offset))
      +
      +        where,
      +        x is randomnly distributed and the threshold
      +        and labels for each sample in x is obtained from a random uniform
      +        distribution.
      +        """
      +        rng = random.RandomState(seed)
      +        x = rng.randn(nPoints)
      +        sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset)))
      +        y_p = rng.rand(nPoints)
      +        cut_off = y_p <= sigmoid
      +        y_p[cut_off] = 1.0
      +        y_p[~cut_off] = 0.0
      +        return [
      +            LabeledPoint(y_p[i], Vectors.dense([x[i]]))
      +            for i in range(nPoints)]
      +
      +    def test_parameter_accuracy(self):
      +        """
      +        Test that the final value of weights is close to the desired value.
      +        """
      +        input_batches = [
      +            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
      +            for i in range(20)]
      +        input_stream = self.ssc.queueStream(input_batches)
      +
      +        slr = StreamingLogisticRegressionWithSGD(
      +            stepSize=0.2, numIterations=25)
      +        slr.setInitialWeights([0.0])
      +        slr.trainOn(input_stream)
      +
      +        self.ssc.start()
      +
      +        def condition():
      +            rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5
      +            self.assertAlmostEqual(rel, 0.1, 1)
      +            return True
      +
      +        self._eventually(condition, catch_assertions=True)
      +
      +    def test_convergence(self):
      +        """
      +        Test that weights converge to the required value on toy data.
      +        """
      +        input_batches = [
      +            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
      +            for i in range(20)]
      +        input_stream = self.ssc.queueStream(input_batches)
      +        models = []
      +
      +        slr = StreamingLogisticRegressionWithSGD(
      +            stepSize=0.2, numIterations=25)
      +        slr.setInitialWeights([0.0])
      +        slr.trainOn(input_stream)
      +        input_stream.foreachRDD(
      +            lambda x: models.append(slr.latestModel().weights[0]))
      +
      +        self.ssc.start()
      +
      +        def condition():
      +            self.assertEquals(len(models), len(input_batches))
      +            return True
      +
      +        # We want all batches to finish for this test.
      +        self._eventually(condition, 60.0, catch_assertions=True)
      +
      +        t_models = array(models)
      +        diff = t_models[1:] - t_models[:-1]
      +        # Test that weights improve with a small tolerance
      +        self.assertTrue(all(diff >= -0.1))
      +        self.assertTrue(array_sum(diff > 0) > 1)
      +
      +    @staticmethod
      +    def calculate_accuracy_error(true, predicted):
      +        return sum(abs(array(true) - array(predicted))) / len(true)
      +
      +    def test_predictions(self):
      +        """Test predicted values on a toy model."""
      +        input_batches = []
      +        for i in range(20):
      +            batch = self.sc.parallelize(
      +                self.generateLogisticInput(0, 1.5, 100, 42 + i))
      +            input_batches.append(batch.map(lambda x: (x.label, x.features)))
      +        input_stream = self.ssc.queueStream(input_batches)
      +
      +        slr = StreamingLogisticRegressionWithSGD(
      +            stepSize=0.2, numIterations=25)
      +        slr.setInitialWeights([1.5])
      +        predict_stream = slr.predictOnValues(input_stream)
      +        true_predicted = []
      +        predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect()))
      +        self.ssc.start()
      +
      +        def condition():
      +            self.assertEquals(len(true_predicted), len(input_batches))
      +            return True
      +
      +        self._eventually(condition, catch_assertions=True)
      +
      +        # Test that the accuracy error is no more than 0.4 on each batch.
      +        for batch in true_predicted:
      +            true, predicted = zip(*batch)
      +            self.assertTrue(
      +                self.calculate_accuracy_error(true, predicted) < 0.4)
      +
      +    def test_training_and_prediction(self):
      +        """Test that the model improves on toy data with no. of batches"""
      +        input_batches = [
      +            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
      +            for i in range(20)]
      +        predict_batches = [
      +            b.map(lambda lp: (lp.label, lp.features)) for b in input_batches]
      +
      +        slr = StreamingLogisticRegressionWithSGD(
      +            stepSize=0.01, numIterations=25)
      +        slr.setInitialWeights([-0.1])
      +        errors = []
      +
      +        def collect_errors(rdd):
      +            true, predicted = zip(*rdd.collect())
      +            errors.append(self.calculate_accuracy_error(true, predicted))
      +
      +        true_predicted = []
      +        input_stream = self.ssc.queueStream(input_batches)
      +        predict_stream = self.ssc.queueStream(predict_batches)
      +        slr.trainOn(input_stream)
      +        ps = slr.predictOnValues(predict_stream)
      +        ps.foreachRDD(lambda x: collect_errors(x))
      +
      +        self.ssc.start()
      +
      +        def condition():
      +            # Test that the improvement in error is > 0.3
      +            if len(errors) == len(predict_batches):
      +                self.assertGreater(errors[1] - errors[-1], 0.3)
      +            if len(errors) >= 3 and errors[1] - errors[-1] > 0.3:
      +                return True
      +            return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
      +
      +        self._eventually(condition)
      +
      +
      +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
      +
      +    def assertArrayAlmostEqual(self, array1, array2, dec):
       +        for i, j in zip(array1, array2):
      +            self.assertAlmostEqual(i, j, dec)
      +
      +    def test_parameter_accuracy(self):
      +        """Test that coefs are predicted accurately by fitting on toy data."""
      +
      +        # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients
      +        # (10, 10)
      +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
      +        slr.setInitialWeights([0.0, 0.0])
      +        xMean = [0.0, 0.0]
      +        xVariance = [1.0 / 3.0, 1.0 / 3.0]
      +
      +        # Create ten batches with 100 sample points in each.
      +        batches = []
      +        for i in range(10):
      +            batch = LinearDataGenerator.generateLinearInput(
      +                0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1)
      +            batches.append(sc.parallelize(batch))
      +
      +        input_stream = self.ssc.queueStream(batches)
      +        slr.trainOn(input_stream)
      +        self.ssc.start()
      +
      +        def condition():
      +            self.assertArrayAlmostEqual(
      +                slr.latestModel().weights.array, [10., 10.], 1)
      +            self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1)
      +            return True
      +
      +        self._eventually(condition, catch_assertions=True)
      +
      +    def test_parameter_convergence(self):
      +        """Test that the model parameters improve with streaming data."""
      +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
      +        slr.setInitialWeights([0.0])
      +
      +        # Create ten batches with 100 sample points in each.
      +        batches = []
      +        for i in range(10):
      +            batch = LinearDataGenerator.generateLinearInput(
      +                0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
      +            batches.append(sc.parallelize(batch))
      +
      +        model_weights = []
      +        input_stream = self.ssc.queueStream(batches)
      +        input_stream.foreachRDD(
      +            lambda x: model_weights.append(slr.latestModel().weights[0]))
      +        slr.trainOn(input_stream)
      +        self.ssc.start()
      +
      +        def condition():
      +            self.assertEquals(len(model_weights), len(batches))
      +            return True
      +
      +        # We want all batches to finish for this test.
      +        self._eventually(condition, catch_assertions=True)
      +
      +        w = array(model_weights)
      +        diff = w[1:] - w[:-1]
      +        self.assertTrue(all(diff >= -0.1))
      +
      +    def test_prediction(self):
      +        """Test prediction on a model with weights already set."""
       +        # Create a model with initial weights equal to the true coefficients
      +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
      +        slr.setInitialWeights([10.0, 10.0])
      +
      +        # Create ten batches with 100 sample points in each.
      +        batches = []
      +        for i in range(10):
      +            batch = LinearDataGenerator.generateLinearInput(
      +                0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0],
      +                100, 42 + i, 0.1)
      +            batches.append(
      +                sc.parallelize(batch).map(lambda lp: (lp.label, lp.features)))
      +
      +        input_stream = self.ssc.queueStream(batches)
      +        output_stream = slr.predictOnValues(input_stream)
      +        samples = []
      +        output_stream.foreachRDD(lambda x: samples.append(x.collect()))
      +
      +        self.ssc.start()
      +
      +        def condition():
      +            self.assertEquals(len(samples), len(batches))
      +            return True
      +
      +        # We want all batches to finish for this test.
      +        self._eventually(condition, catch_assertions=True)
      +
      +        # Test that mean absolute error on each batch is less than 0.1
      +        for batch in samples:
      +            true, predicted = zip(*batch)
      +            self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1)
      +
      +    def test_train_prediction(self):
      +        """Test that error on test data improves as model is trained."""
      +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
      +        slr.setInitialWeights([0.0])
      +
      +        # Create ten batches with 100 sample points in each.
      +        batches = []
      +        for i in range(10):
      +            batch = LinearDataGenerator.generateLinearInput(
      +                0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
      +            batches.append(sc.parallelize(batch))
      +
      +        predict_batches = [
      +            b.map(lambda lp: (lp.label, lp.features)) for b in batches]
      +        errors = []
      +
      +        def func(rdd):
      +            true, predicted = zip(*rdd.collect())
      +            errors.append(mean(abs(true) - abs(predicted)))
      +
      +        input_stream = self.ssc.queueStream(batches)
      +        output_stream = self.ssc.queueStream(predict_batches)
      +        slr.trainOn(input_stream)
      +        output_stream = slr.predictOnValues(output_stream)
      +        output_stream.foreachRDD(func)
      +        self.ssc.start()
      +
      +        def condition():
      +            if len(errors) == len(predict_batches):
      +                self.assertGreater(errors[1] - errors[-1], 2)
      +            if len(errors) >= 3 and errors[1] - errors[-1] > 2:
      +                return True
      +            return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
      +
      +        self._eventually(condition)
      +
      +
      +class MLUtilsTests(MLlibTestCase):
      +    def test_append_bias(self):
      +        data = [2.0, 2.0, 2.0]
      +        ret = MLUtils.appendBias(data)
      +        self.assertEqual(ret[3], 1.0)
      +        self.assertEqual(type(ret), DenseVector)
      +
      +    def test_append_bias_with_vector(self):
      +        data = Vectors.dense([2.0, 2.0, 2.0])
      +        ret = MLUtils.appendBias(data)
      +        self.assertEqual(ret[3], 1.0)
      +        self.assertEqual(type(ret), DenseVector)
      +
      +    def test_append_bias_with_sp_vector(self):
      +        data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
      +        expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
      +        # Returned value must be SparseVector
      +        ret = MLUtils.appendBias(data)
      +        self.assertEqual(ret, expected)
      +        self.assertEqual(type(ret), SparseVector)
      +
      +    def test_load_vectors(self):
      +        import shutil
      +        data = [
      +            [1.0, 2.0, 3.0],
      +            [1.0, 2.0, 3.0]
      +        ]
      +        temp_dir = tempfile.mkdtemp()
      +        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
      +        try:
      +            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
      +            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
      +            ret = ret_rdd.collect()
      +            self.assertEqual(len(ret), 2)
      +            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
      +            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
      +        except:
      +            self.fail()
      +        finally:
      +            shutil.rmtree(load_vectors_path)
       
       
       if __name__ == "__main__":
      diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
      index cfcbea573fd2..372b86a7c95d 100644
      --- a/python/pyspark/mllib/tree.py
      +++ b/python/pyspark/mllib/tree.py
      @@ -299,7 +299,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
                        1 internal node + 2 leaf nodes. (default: 4)
               :param maxBins: maximum number of bins used for splitting
                        features
      -                 (default: 100)
      +                 (default: 32)
               :param seed: Random seed for bootstrapping and choosing feature
                        subsets.
               :return: RandomForestModel that can be used for prediction
      @@ -377,7 +377,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt
                        1 leaf node; depth 1 means 1 internal node + 2 leaf
                        nodes. (default: 4)
               :param maxBins: maximum number of bins used for splitting
      -                 features (default: 100)
      +                 features (default: 32)
               :param seed: Random seed for bootstrapping and choosing feature
                        subsets.
               :return: RandomForestModel that can be used for prediction
      @@ -435,16 +435,17 @@ class GradientBoostedTrees(object):
       
           @classmethod
           def _train(cls, data, algo, categoricalFeaturesInfo,
      -               loss, numIterations, learningRate, maxDepth):
      +               loss, numIterations, learningRate, maxDepth, maxBins):
               first = data.first()
               assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
               model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo,
      -                              loss, numIterations, learningRate, maxDepth)
      +                              loss, numIterations, learningRate, maxDepth, maxBins)
               return GradientBoostedTreesModel(model)
       
           @classmethod
           def trainClassifier(cls, data, categoricalFeaturesInfo,
      -                        loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3):
      +                        loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3,
      +                        maxBins=32):
               """
               Method to train a gradient-boosted trees model for
               classification.
      @@ -467,6 +468,8 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
               :param maxDepth: Maximum depth of the tree. E.g., depth 0 means
                        1 leaf node; depth 1 means 1 internal node + 2 leaf
                        nodes. (default: 3)
       +        :param maxBins: maximum number of bins used for splitting
       +                 features (default: 32). DecisionTree requires maxBins >= max categories.
               :return: GradientBoostedTreesModel that can be used for
                          prediction
       
      @@ -499,11 +502,12 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
               [1.0, 0.0]
               """
               return cls._train(data, "classification", categoricalFeaturesInfo,
      -                          loss, numIterations, learningRate, maxDepth)
      +                          loss, numIterations, learningRate, maxDepth, maxBins)
       
           @classmethod
           def trainRegressor(cls, data, categoricalFeaturesInfo,
      -                       loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3):
      +                       loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3,
      +                       maxBins=32):
               """
               Method to train a gradient-boosted trees model for regression.
       
      @@ -522,6 +526,8 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
                        contribution of each estimator. The learning rate
                        should be between in the interval (0, 1].
                        (default: 0.1)
       +        :param maxBins: maximum number of bins used for splitting
       +                 features (default: 32). DecisionTree requires maxBins >= max categories.
               :param maxDepth: Maximum depth of the tree. E.g., depth 0 means
                        1 leaf node; depth 1 means 1 internal node + 2 leaf
                        nodes.  (default: 3)
      @@ -556,7 +562,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
               [1.0, 0.0]
               """
               return cls._train(data, "regression", categoricalFeaturesInfo,
      -                          loss, numIterations, learningRate, maxDepth)
      +                          loss, numIterations, learningRate, maxDepth, maxBins)
       
       
       def _test():
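
The new maxBins keyword is forwarded unchanged to trainGradientBoostedTreesModel, and the docstring note is the practical constraint: with categorical features, maxBins must be at least the largest number of categories. A usage sketch (assumes a running SparkContext named sc; the toy data is invented for illustration):

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import GradientBoostedTrees

    # Feature 0 is categorical with 3 categories (values 0, 1 and 2).
    data = sc.parallelize([
        LabeledPoint(0.0, [0.0, 1.2]),
        LabeledPoint(1.0, [1.0, 0.5]),
        LabeledPoint(2.0, [2.0, 3.1]),
        LabeledPoint(1.5, [1.0, 2.0]),
    ])

    # maxBins must be >= 3 here; the default of 32 is safe, while a
    # too-small value (the test above tries maxBins=1) is rejected.
    model = GradientBoostedTrees.trainRegressor(
        data, categoricalFeaturesInfo={0: 3}, numIterations=4, maxBins=32)
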
      diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
      index 16a90db146ef..10a1e4b3eb0f 100644
      --- a/python/pyspark/mllib/util.py
      +++ b/python/pyspark/mllib/util.py
      @@ -21,7 +21,9 @@
       
       if sys.version > '3':
           xrange = range
      +    basestring = str
       
      +from pyspark import SparkContext
       from pyspark.mllib.common import callMLlibFunc, inherit_doc
       from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
       
      @@ -169,6 +171,28 @@ def loadLabeledPoints(sc, path, minPartitions=None):
               minPartitions = minPartitions or min(sc.defaultParallelism, 2)
               return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
       
      +    @staticmethod
      +    def appendBias(data):
      +        """
      +        Returns a new vector with `1.0` (bias) appended to
      +        the end of the input vector.
      +        """
      +        vec = _convert_to_vector(data)
      +        if isinstance(vec, SparseVector):
      +            newIndices = np.append(vec.indices, len(vec))
      +            newValues = np.append(vec.values, 1.0)
      +            return SparseVector(len(vec) + 1, newIndices, newValues)
      +        else:
      +            return _convert_to_vector(np.append(vec.toArray(), 1.0))
      +
      +    @staticmethod
      +    def loadVectors(sc, path):
      +        """
      +        Loads vectors saved using `RDD[Vector].saveAsTextFile`
      +        with the default number of partitions.
      +        """
      +        return callMLlibFunc("loadVectors", sc, path)
      +
       
       class Saveable(object):
           """
      @@ -201,6 +225,10 @@ class JavaSaveable(Saveable):
           """
       
           def save(self, sc, path):
      +        if not isinstance(sc, SparkContext):
      +            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
      +        if not isinstance(path, basestring):
      +            raise TypeError("path should be a basestring, got type %s" % type(path))
               self._java_model.save(sc._jsc.sc(), path)
       
       
      @@ -257,6 +285,42 @@ def load(cls, sc, path):
               return cls(java_model)
       
       
      +class LinearDataGenerator(object):
      +    """Utils for generating linear data"""
      +
      +    @staticmethod
      +    def generateLinearInput(intercept, weights, xMean, xVariance,
      +                            nPoints, seed, eps):
      +        """
       +        :param intercept: bias factor, the term c in X'w + c
       +        :param weights: feature vector, the term w in X'w + c
       +        :param xMean: point around which the data X is centered
       +        :param xVariance: variance of the given data
       +        :param nPoints: number of points to be generated
       +        :param seed: random seed
       +        :param eps: used to scale the noise; the higher eps is,
       +                    the more Gaussian noise is added
       +
       +        Returns a list of LabeledPoints of length nPoints
      +        """
      +        weights = [float(weight) for weight in weights]
      +        xMean = [float(mean) for mean in xMean]
      +        xVariance = [float(var) for var in xVariance]
      +        return list(callMLlibFunc(
      +            "generateLinearInputWrapper", float(intercept), weights, xMean,
      +            xVariance, int(nPoints), int(seed), float(eps)))
      +
      +    @staticmethod
      +    def generateLinearRDD(sc, nexamples, nfeatures, eps,
      +                          nParts=2, intercept=0.0):
      +        """
       +        Generate an RDD of LabeledPoints.
      +        """
      +        return callMLlibFunc(
      +            "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures),
      +            float(eps), int(nParts), float(intercept))
      +
      +
       def _test():
           import doctest
           from pyspark.context import SparkContext
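
The helpers added above are exercised by the new MLUtilsTests and LinearDataGeneratorTests earlier in this patch (LinearDataGenerator exists mainly to support those tests). A small usage sketch, assuming a running SparkContext named sc:

    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.util import MLUtils, LinearDataGenerator

    # appendBias adds a trailing 1.0 and preserves sparsity.
    MLUtils.appendBias([2.0, 2.0, 2.0])              # DenseVector([2.0, 2.0, 2.0, 1.0])
    MLUtils.appendBias(Vectors.sparse(3, {0: 2.0}))  # SparseVector(4, {0: 2.0, 3: 1.0})

    # Generate a small RDD of LabeledPoints with 2 features per point.
    points = LinearDataGenerator.generateLinearRDD(
        sc, nexamples=6, nfeatures=2, eps=0.1, nParts=2, intercept=0.0).collect()
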
      diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
      index d18daaabfcb3..44d17bd62947 100644
      --- a/python/pyspark/profiler.py
      +++ b/python/pyspark/profiler.py
      @@ -90,9 +90,11 @@ class Profiler(object):
           >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
           >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10)
           [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
      +    >>> sc.parallelize(range(1000)).count()
      +    1000
           >>> sc.show_profiles()
           My custom profiles for RDD:1
      -    My custom profiles for RDD:2
      +    My custom profiles for RDD:3
           >>> sc.stop()
           """
       
      @@ -169,4 +171,6 @@ def stats(self):
       
       if __name__ == "__main__":
           import doctest
      -    doctest.testmod()
      +    (failure_count, test_count) = doctest.testmod()
      +    if failure_count:
      +        exit(-1)
      diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
      index 20c0bc93f413..9ef60a7e2c84 100644
      --- a/python/pyspark/rdd.py
      +++ b/python/pyspark/rdd.py
      @@ -121,10 +121,23 @@ def _parse_memory(s):
       
       
       def _load_from_socket(port, serializer):
      -    sock = socket.socket()
      -    sock.settimeout(3)
      +    sock = None
      +    # Support for both IPv4 and IPv6.
       +    # On most IPv6-ready systems, IPv6 will take precedence.
      +    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
      +        af, socktype, proto, canonname, sa = res
      +        sock = socket.socket(af, socktype, proto)
      +        try:
      +            sock.settimeout(3)
      +            sock.connect(sa)
      +        except socket.error:
      +            sock.close()
      +            sock = None
      +            continue
      +        break
      +    if not sock:
      +        raise Exception("could not open socket")
           try:
      -        sock.connect(("localhost", port))
               rf = sock.makefile("rb", 65536)
               for item in serializer.load_stream(rf):
                   yield item
      @@ -687,13 +700,18 @@ def groupBy(self, f, numPartitions=None):
               return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
       
           @ignore_unicode_prefix
      -    def pipe(self, command, env={}):
      +    def pipe(self, command, env=None, checkCode=False):
               """
               Return an RDD created by piping elements to a forked external process.
       
               >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
               [u'1', u'2', u'', u'3']
      +
      +        :param checkCode: whether or not to check the return value of the shell command.
               """
      +        if env is None:
      +            env = dict()
      +
               def func(iterator):
                   pipe = Popen(
                       shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
      @@ -704,7 +722,17 @@ def pipe_objs(out):
                           out.write(s.encode('utf-8'))
                       out.close()
                   Thread(target=pipe_objs, args=[pipe.stdin]).start()
      -            return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
      +
      +            def check_return_code():
      +                pipe.wait()
      +                if checkCode and pipe.returncode:
      +                    raise Exception("Pipe function `%s' exited "
      +                                    "with error code %d" % (command, pipe.returncode))
      +                else:
      +                    for i in range(0):
      +                        yield i
      +            return (x.rstrip(b'\n').decode('utf-8') for x in
      +                    chain(iter(pipe.stdout.readline, b''), check_return_code()))
               return self.mapPartitions(func)
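
With the new checkCode flag, a non-zero exit status of the piped command surfaces as an error when the partition is evaluated instead of being silently ignored. A usage sketch (assumes a running SparkContext named sc):

    # 'cat' exits with status 0, so this behaves exactly as before.
    sc.parallelize(['1', '2', '', '3']).pipe('cat', checkCode=True).collect()

    # 'grep 9' matches nothing and exits 1; with checkCode=True the job
    # fails with "Pipe function `grep 9' exited with error code 1" rather
    # than silently returning an empty result.
    sc.parallelize(['1', '2', '3']).pipe('grep 9', checkCode=True).collect()
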
       
           def foreach(self, f):
      @@ -837,6 +865,9 @@ def func(iterator):
                   for obj in iterator:
                       acc = op(obj, acc)
                   yield acc
       +        # Collecting the result of mapPartitions here ensures that the copy of
       +        # zeroValue provided to each partition is distinct from the one provided
       +        # to the final reduce call.
               vals = self.mapPartitions(func).collect()
               return reduce(op, vals, zeroValue)
       
      @@ -866,8 +897,11 @@ def func(iterator):
                   for obj in iterator:
                       acc = seqOp(acc, obj)
                   yield acc
      -
      -        return self.mapPartitions(func).fold(zeroValue, combOp)
       +        # Collecting the result of mapPartitions here ensures that the copy of
       +        # zeroValue provided to each partition is distinct from the one provided
       +        # to the final reduce call.
      +        vals = self.mapPartitions(func).collect()
      +        return reduce(combOp, vals, zeroValue)
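
The comment above is the reason the driver-side reduce is safe even for mutable accumulators: every partition mutates its own deserialized copy of zeroValue, and the final reduce starts from the driver's untouched original. A sketch with an in-place seqOp (assumes a running SparkContext named sc):

    def seq_op(acc, x):        # mutates the per-partition accumulator in place
        acc[0] += x
        acc[1] += 1
        return acc

    def comb_op(a, b):         # returns a fresh list, so nothing is shared
        return [a[0] + b[0], a[1] + b[1]]

    zero = [0, 0]
    total, count = sc.parallelize([1, 2, 3, 4], 2).aggregate(zero, seq_op, comb_op)
    # total == 10, count == 4, and `zero` on the driver is still [0, 0].
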
       
           def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
               """
      @@ -1262,7 +1296,7 @@ def takeUpToNumLeft(iterator):
                           taken += 1
       
                   p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
      -            res = self.context.runJob(self, takeUpToNumLeft, p, True)
      +            res = self.context.runJob(self, takeUpToNumLeft, p)
       
                   items += res
                   partsScanned += numPartsToTry
      @@ -2162,7 +2196,7 @@ def lookup(self, key):
               values = self.filter(lambda kv: kv[0] == key).values()
       
               if self.partitioner is not None:
      -            return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
      +            return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)])
       
               return values.collect()
       
      @@ -2198,7 +2232,7 @@ def sumApprox(self, timeout, confidence=0.95):
       
               >>> rdd = sc.parallelize(range(1000), 10)
               >>> r = sum(range(1000))
      -        >>> (rdd.sumApprox(1000) - r) / r < 0.05
      +        >>> abs(rdd.sumApprox(1000) - r) / r < 0.05
               True
               """
               jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
      @@ -2215,7 +2249,7 @@ def meanApprox(self, timeout, confidence=0.95):
       
               >>> rdd = sc.parallelize(range(1000), 10)
               >>> r = sum(range(1000)) / 1000.0
      -        >>> (rdd.meanApprox(1000) - r) / r < 0.05
      +        >>> abs(rdd.meanApprox(1000) - r) / r < 0.05
               True
               """
               jrdd = self.map(float)._to_java_object_rdd()
      diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
      index 7f9d0a338d31..2a1326947f4f 100644
      --- a/python/pyspark/serializers.py
      +++ b/python/pyspark/serializers.py
      @@ -44,8 +44,8 @@
       
       >>> rdd.glom().collect()
       [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
      ->>> rdd._jrdd.count()
      -8L
      +>>> int(rdd._jrdd.count())
      +8
       >>> sc.stop()
       """
       
      @@ -359,6 +359,7 @@ def _hack_namedtuple(cls):
           def __reduce__(self):
               return (_restore, (name, fields, tuple(self)))
           cls.__reduce__ = __reduce__
      +    cls._is_namedtuple_ = True
           return cls
       
       
      @@ -556,4 +557,6 @@ def write_with_length(obj, stream):
       
       if __name__ == '__main__':
           import doctest
      -    doctest.testmod()
      +    (failure_count, test_count) = doctest.testmod()
      +    if failure_count:
      +        exit(-1)
      diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
      index 144cdf0b0cdd..99331297c19f 100644
      --- a/python/pyspark/shell.py
      +++ b/python/pyspark/shell.py
      @@ -40,7 +40,7 @@
       if os.environ.get("SPARK_EXECUTOR_URI"):
           SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
       
      -sc = SparkContext(appName="PySparkShell", pyFiles=add_files)
      +sc = SparkContext(pyFiles=add_files)
       atexit.register(lambda: sc.stop())
       
       try:
      diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
      index 67752c0d150b..b8118bdb7ca7 100644
      --- a/python/pyspark/shuffle.py
      +++ b/python/pyspark/shuffle.py
      @@ -606,7 +606,7 @@ def _open_file(self):
               if not os.path.exists(d):
                   os.makedirs(d)
               p = os.path.join(d, str(id(self)))
      -        self._file = open(p, "wb+", 65536)
      +        self._file = open(p, "w+b", 65536)
               self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
               os.unlink(p)
       
      @@ -838,4 +838,6 @@ def load_partition(j):
       
       if __name__ == "__main__":
           import doctest
      -    doctest.testmod()
      +    (failure_count, test_count) = doctest.testmod()
      +    if failure_count:
      +        exit(-1)
      diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
      index ad9c891ba1c0..98eaf52866d2 100644
      --- a/python/pyspark/sql/__init__.py
      +++ b/python/pyspark/sql/__init__.py
      @@ -44,21 +44,6 @@
       from __future__ import absolute_import
       
       
      -def since(version):
      -    """
      -    A decorator that annotates a function to append the version of Spark the function was added.
      -    """
      -    import re
      -    indent_p = re.compile(r'\n( +)')
      -
      -    def deco(f):
      -        indents = indent_p.findall(f.__doc__)
      -        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
      -        f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
      -        return f
      -    return deco
      -
      -
       from pyspark.sql.types import Row
       from pyspark.sql.context import SQLContext, HiveContext
       from pyspark.sql.column import Column
      diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
      index 1ecec5b12650..9ca8e1f264cf 100644
      --- a/python/pyspark/sql/column.py
      +++ b/python/pyspark/sql/column.py
      @@ -16,14 +16,15 @@
       #
       
       import sys
      +import warnings
       
       if sys.version >= '3':
           basestring = str
           long = int
       
      +from pyspark import since
       from pyspark.context import SparkContext
       from pyspark.rdd import ignore_unicode_prefix
      -from pyspark.sql import since
       from pyspark.sql.types import *
       
       __all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions",
      @@ -60,6 +61,18 @@ def _to_seq(sc, cols, converter=None):
           return sc._jvm.PythonUtils.toSeq(cols)
       
       
      +def _to_list(sc, cols, converter=None):
      +    """
      +    Convert a list of Column (or names) into a JVM (Scala) List of Column.
      +
      +    An optional `converter` could be used to convert items in `cols`
      +    into JVM Column objects.
      +    """
      +    if converter:
      +        cols = [converter(c) for c in cols]
      +    return sc._jvm.PythonUtils.toList(cols)
      +
      +
       def _unary_op(name, doc="unary operator"):
           """ Create a method for given unary operator """
           def _(self):
      @@ -78,6 +91,17 @@ def _(self):
           return _
       
       
      +def _bin_func_op(name, reverse=False, doc="binary function"):
      +    def _(self, other):
      +        sc = SparkContext._active_spark_context
      +        fn = getattr(sc._jvm.functions, name)
      +        jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
      +        njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
      +        return Column(njc)
      +    _.__doc__ = doc
      +    return _
      +
      +
       def _bin_op(name, doc="binary operator"):
           """ Create a method for given binary operator
           """
      @@ -138,6 +162,8 @@ def __init__(self, jc):
           __rdiv__ = _reverse_op("divide")
           __rtruediv__ = _reverse_op("divide")
           __rmod__ = _reverse_op("mod")
      +    __pow__ = _bin_func_op("pow")
      +    __rpow__ = _bin_func_op("pow", reverse=True)
       
           # logistic operators
           __eq__ = _bin_op("equalTo")
      @@ -213,6 +239,9 @@ def __getattr__(self, item):
                   raise AttributeError(item)
               return self.getField(item)
       
      +    def __iter__(self):
      +        raise TypeError("Column is not iterable")
      +
           # string methods
           rlike = _bin_op("rlike")
           like = _bin_op("like")
      @@ -254,12 +283,29 @@ def inSet(self, *cols):
               [Row(age=5, name=u'Bob')]
               >>> df[df.age.inSet([1, 2, 3])].collect()
               [Row(age=2, name=u'Alice')]
      +
      +        .. note:: Deprecated in 1.5, use :func:`Column.isin` instead.
      +        """
      +        warnings.warn("inSet is deprecated. Use isin() instead.")
      +        return self.isin(*cols)
      +
      +    @ignore_unicode_prefix
      +    @since(1.5)
      +    def isin(self, *cols):
      +        """
      +        A boolean expression that is evaluated to true if the value of this
      +        expression is contained by the evaluated values of the arguments.
      +
      +        >>> df[df.name.isin("Bob", "Mike")].collect()
      +        [Row(age=5, name=u'Bob')]
      +        >>> df[df.age.isin([1, 2, 3])].collect()
      +        [Row(age=2, name=u'Alice')]
               """
               if len(cols) == 1 and isinstance(cols[0], (list, set)):
                   cols = cols[0]
               cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
               sc = SparkContext._active_spark_context
      -        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
      +        jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
               return Column(jc)
       
           # order
      @@ -396,6 +442,11 @@ def over(self, window):
               jc = self._jc.over(window._jspec)
               return Column(jc)
       
      +    def __nonzero__(self):
      +        raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
      +                         "'~' for 'not' when building DataFrame boolean expressions.")
      +    __bool__ = __nonzero__
      +
           def __repr__(self):
               return 'Column<%s>' % self._jc.toString().encode('utf8')
       
      diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
      index 599c9ac5794a..89c8c6e0d94f 100644
      --- a/python/pyspark/sql/context.py
      +++ b/python/pyspark/sql/context.py
      @@ -26,18 +26,20 @@
       
       from py4j.protocol import Py4JError
       
      +from pyspark import since
       from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
       from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
      -from pyspark.sql import since
       from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
      -    _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
      +    _infer_schema, _has_nulltype, _merge_type, _create_converter
       from pyspark.sql.dataframe import DataFrame
       from pyspark.sql.readwriter import DataFrameReader
      +from pyspark.sql.utils import install_exception_handler
      +from pyspark.sql.functions import UserDefinedFunction
       
       try:
           import pandas
           has_pandas = True
      -except ImportError:
      +except Exception:
           has_pandas = False
       
       __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
      @@ -86,7 +88,8 @@ def __init__(self, sparkContext, sqlContext=None):
               >>> df.registerTempTable("allTypes")
               >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
               ...            'from allTypes where b and i > 0').collect()
      -        [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
      +        [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
      +            time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
               >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
               [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
               """
      @@ -95,6 +98,7 @@ def __init__(self, sparkContext, sqlContext=None):
               self._jvm = self._sc._jvm
               self._scala_SQLContext = sqlContext
               _monkey_patch_RDD(self)
      +        install_exception_handler()
       
           @property
           def _ssql_ctx(self):
      @@ -176,33 +180,52 @@ def registerFunction(self, name, f, returnType=StringType()):
       
               >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
               >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
      -        [Row(c0=u'4')]
      +        [Row(_c0=u'4')]
       
               >>> from pyspark.sql.types import IntegerType
               >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
               >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
      -        [Row(c0=4)]
      +        [Row(_c0=4)]
       
               >>> from pyspark.sql.types import IntegerType
               >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
               >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
      -        [Row(c0=4)]
      -        """
      -        func = lambda _, it: map(lambda x: f(*x), it)
      -        ser = AutoBatchedSerializer(PickleSerializer())
      -        command = (func, None, ser, ser)
      -        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
      -        self._ssql_ctx.udf().registerPython(name,
      -                                            bytearray(pickled_cmd),
      -                                            env,
      -                                            includes,
      -                                            self._sc.pythonExec,
      -                                            self._sc.pythonVer,
      -                                            bvars,
      -                                            self._sc._javaAccumulator,
      -                                            returnType.json())
      +        [Row(_c0=4)]
      +        """
      +        udf = UserDefinedFunction(f, returnType, name)
      +        self._ssql_ctx.udf().registerPython(name, udf._judf)
      +
      +    def _inferSchemaFromList(self, data):
      +        """
      +        Infer schema from list of Row or tuple.
      +
      +        :param data: list of Row or tuple
      +        :return: StructType
      +        """
      +        if not data:
      +            raise ValueError("can not infer schema from empty dataset")
      +        first = data[0]
      +        if type(first) is dict:
      +            warnings.warn("inferring schema from dict is deprecated,"
      +                          "please use pyspark.sql.Row instead")
      +        schema = _infer_schema(first)
      +        if _has_nulltype(schema):
      +            for r in data:
      +                schema = _merge_type(schema, _infer_schema(r))
      +                if not _has_nulltype(schema):
      +                    break
      +            else:
      +                raise ValueError("Some of types cannot be determined after inferring")
      +        return schema
       
           def _inferSchema(self, rdd, samplingRatio=None):
      +        """
      +        Infer schema from an RDD of Row or tuple.
      +
      +        :param rdd: an RDD of Row or tuple
      +        :param samplingRatio: sampling ratio, or no sampling (default)
      +        :return: StructType
      +        """
               first = rdd.first()
               if not first:
                   raise ValueError("The first row in RDD is empty, "
      @@ -254,6 +277,66 @@ def applySchema(self, rdd, schema):
       
               return self.createDataFrame(rdd, schema)
       
      +    def _createFromRDD(self, rdd, schema, samplingRatio):
      +        """
       +        Create an RDD for a DataFrame from an existing RDD; returns the RDD and schema.
      +        """
      +        if schema is None or isinstance(schema, (list, tuple)):
      +            struct = self._inferSchema(rdd, samplingRatio)
      +            converter = _create_converter(struct)
      +            rdd = rdd.map(converter)
      +            if isinstance(schema, (list, tuple)):
      +                for i, name in enumerate(schema):
      +                    struct.fields[i].name = name
      +                    struct.names[i] = name
      +            schema = struct
      +
      +        elif isinstance(schema, StructType):
      +            # take the first few rows to verify schema
      +            rows = rdd.take(10)
      +            for row in rows:
      +                _verify_type(row, schema)
      +
      +        else:
      +            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
      +
      +        # convert python objects to sql data
      +        rdd = rdd.map(schema.toInternal)
      +        return rdd, schema
      +
      +    def _createFromLocal(self, data, schema):
      +        """
       +        Create an RDD for a DataFrame from a list or pandas.DataFrame; returns
       +        the RDD and schema.
      +        """
      +        if has_pandas and isinstance(data, pandas.DataFrame):
      +            if schema is None:
      +                schema = [str(x) for x in data.columns]
      +            data = [r.tolist() for r in data.to_records(index=False)]
      +
       +        # make sure data can be consumed multiple times
      +        if not isinstance(data, list):
      +            data = list(data)
      +
      +        if schema is None or isinstance(schema, (list, tuple)):
      +            struct = self._inferSchemaFromList(data)
      +            if isinstance(schema, (list, tuple)):
      +                for i, name in enumerate(schema):
      +                    struct.fields[i].name = name
      +                    struct.names[i] = name
      +            schema = struct
      +
      +        elif isinstance(schema, StructType):
      +            for row in data:
      +                _verify_type(row, schema)
      +
      +        else:
      +            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
      +
      +        # convert python objects to sql data
      +        data = [schema.toInternal(row) for row in data]
      +        return self._sc.parallelize(data), schema
      +
           @since(1.3)
           @ignore_unicode_prefix
           def createDataFrame(self, data, schema=None, samplingRatio=None):
      @@ -311,54 +394,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
       
               >>> sqlContext.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
               [Row(name=u'Alice', age=1)]
       +        >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
      +        [Row(0=1, 1=2)]
               """
               if isinstance(data, DataFrame):
                   raise TypeError("data is already a DataFrame")
       
      -        if has_pandas and isinstance(data, pandas.DataFrame):
      -            if schema is None:
      -                schema = list(data.columns)
      -            data = [r.tolist() for r in data.to_records(index=False)]
      -
      -        if not isinstance(data, RDD):
      -            try:
      -                # data could be list, tuple, generator ...
      -                rdd = self._sc.parallelize(data)
      -            except Exception:
      -                raise TypeError("cannot create an RDD from type: %s" % type(data))
      +        if isinstance(data, RDD):
      +            rdd, schema = self._createFromRDD(data, schema, samplingRatio)
               else:
      -            rdd = data
      -
      -        if schema is None:
      -            schema = self._inferSchema(rdd, samplingRatio)
      -            converter = _create_converter(schema)
      -            rdd = rdd.map(converter)
      -
      -        if isinstance(schema, (list, tuple)):
      -            first = rdd.first()
      -            if not isinstance(first, (list, tuple)):
      -                raise TypeError("each row in `rdd` should be list or tuple, "
      -                                "but got %r" % type(first))
      -            row_cls = Row(*schema)
      -            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
      -
      -        # take the first few rows to verify schema
      -        rows = rdd.take(10)
      -        # Row() cannot been deserialized by Pyrolite
      -        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
      -            rdd = rdd.map(tuple)
      -            rows = rdd.take(10)
      -
      -        for row in rows:
      -            _verify_type(row, schema)
      -
      -        # convert python objects to sql data
      -        converter = _python_to_sql_converter(schema)
      -        rdd = rdd.map(converter)
      -
      +            rdd, schema = self._createFromLocal(data, schema)
               jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
      -        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
      -        return DataFrame(df, self)
      +        jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
      +        df = DataFrame(jdf, self)
      +        df._schema = schema
      +        return df
       
           @since(1.3)
           def registerDataFrameAsTable(self, df, tableName):
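
createDataFrame now routes local data through _createFromLocal and _inferSchemaFromList, which keeps merging row schemas until no NullType field remains. That is why a local list whose first row contains a None can still be converted. A hedged sketch, assuming an active SQLContext named sqlContext:

    from pyspark.sql import Row

    rows = [Row(name=None, age=1), Row(name="Alice", age=2)]

    # The first row alone would leave `name` as NullType; merging with the
    # second row resolves it to StringType before the DataFrame is built.
    df = sqlContext.createDataFrame(rows)
    df.dtypes   # roughly [('age', 'bigint'), ('name', 'string')]
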
      diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
      index 152b87351db3..fb995fa3a76b 100644
      --- a/python/pyspark/sql/dataframe.py
      +++ b/python/pyspark/sql/dataframe.py
      @@ -26,13 +26,13 @@
       else:
           from itertools import imap as map
       
      +from pyspark import since
       from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
       from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
       from pyspark.storagelevel import StorageLevel
       from pyspark.traceback_utils import SCCallSiteSync
      -from pyspark.sql import since
      -from pyspark.sql.types import _create_cls, _parse_datatype_json_string
      -from pyspark.sql.column import Column, _to_seq, _to_java_column
      +from pyspark.sql.types import _parse_datatype_json_string
      +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
       from pyspark.sql.readwriter import DataFrameWriter
       from pyspark.sql.types import *
       
      @@ -83,15 +83,7 @@ def rdd(self):
               """
               if self._lazy_rdd is None:
                   jrdd = self._jdf.javaToPython()
      -            rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
      -            schema = self.schema
      -
      -            def applySchema(it):
      -                cls = _create_cls(schema)
      -                return map(cls, it)
      -
      -            self._lazy_rdd = rdd.mapPartitions(applySchema)
      -
      +            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
               return self._lazy_rdd
       
           @property
      @@ -220,8 +212,7 @@ def explain(self, extended=False):
               :param extended: boolean, default ``False``. If ``False``, prints only the physical plan.
       
               >>> df.explain()
      -        PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at applySchemaToPythonRDD at\
      -          NativeMethodAccessorImpl.java:...
      +        Scan PhysicalRDD[age#0,name#1]
       
               >>> df.explain(True)
               == Parsed Logical Plan ==
      @@ -232,7 +223,6 @@ def explain(self, extended=False):
               ...
               == Physical Plan ==
               ...
      -        == RDD ==
               """
               if extended:
                   print(self._jdf.queryExecution().toString())
      @@ -247,9 +237,12 @@ def isLocal(self):
               return self._jdf.isLocal()
       
           @since(1.3)
      -    def show(self, n=20):
      +    def show(self, n=20, truncate=True):
               """Prints the first ``n`` rows to the console.
       
      +        :param n: Number of rows to show.
       +        :param truncate: Whether to truncate long strings and align cells right.
      +
               >>> df
               DataFrame[age: int, name: string]
               >>> df.show()
      @@ -260,7 +253,7 @@ def show(self, n=20):
               |  5|  Bob|
               +---+-----+
               """
      -        print(self._jdf.showString(n))
      +        print(self._jdf.showString(n, truncate))
       
           def __repr__(self):
               return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
      @@ -284,9 +277,7 @@ def collect(self):
               """
               with SCCallSiteSync(self._sc) as css:
                   port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
      -        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
      -        cls = _create_cls(self.schema)
      -        return [cls(r) for r in rs]
      +        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
       
           @ignore_unicode_prefix
           @since(1.3)
      @@ -448,6 +439,42 @@ def sample(self, withReplacement, fraction, seed=None):
               rdd = self._jdf.sample(withReplacement, fraction, long(seed))
               return DataFrame(rdd, self.sql_ctx)
       
      +    @since(1.5)
      +    def sampleBy(self, col, fractions, seed=None):
      +        """
      +        Returns a stratified sample without replacement based on the
      +        fraction given on each stratum.
      +
      +        :param col: column that defines strata
      +        :param fractions:
      +            sampling fraction for each stratum. If a stratum is not
      +            specified, we treat its fraction as zero.
      +        :param seed: random seed
      +        :return: a new DataFrame that represents the stratified sample
      +
      +        >>> from pyspark.sql.functions import col
      +        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
      +        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
      +        >>> sampled.groupBy("key").count().orderBy("key").show()
      +        +---+-----+
      +        |key|count|
      +        +---+-----+
      +        |  0|    3|
      +        |  1|    8|
      +        +---+-----+
      +
      +        """
      +        if not isinstance(col, str):
      +            raise ValueError("col must be a string, but got %r" % type(col))
      +        if not isinstance(fractions, dict):
      +            raise ValueError("fractions must be a dict but got %r" % type(fractions))
      +        for k, v in fractions.items():
      +            if not isinstance(k, (float, int, long, basestring)):
      +                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
      +            fractions[k] = float(v)
      +        seed = seed if seed is not None else random.randint(0, sys.maxsize)
      +        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
      +
           @since(1.4)
           def randomSplit(self, weights, seed=None):
               """Randomly splits this :class:`DataFrame` with the provided weights.
      @@ -467,7 +494,7 @@ def randomSplit(self, weights, seed=None):
                   if w < 0.0:
                       raise ValueError("Weights must be positive. Found weight value: %s" % w)
               seed = seed if seed is not None else random.randint(0, sys.maxsize)
      -        rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
      +        rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed))
               return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
       
           @property
      @@ -481,13 +508,12 @@ def dtypes(self):
               return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
       
           @property
      -    @ignore_unicode_prefix
           @since(1.3)
           def columns(self):
               """Returns all column names as a list.
       
               >>> df.columns
      -        [u'age', u'name']
      +        ['age', 'name']
               """
               return [f.name for f in self.schema.fields]
       
      @@ -540,8 +566,7 @@ def join(self, other, on=None, how=None):
       
               if on is None or len(on) == 0:
                   jdf = self._jdf.join(other._jdf)
      -
      -        if isinstance(on[0], basestring):
      +        elif isinstance(on[0], basestring):
                   jdf = self._jdf.join(other._jdf, self._jseq(on))
               else:
                   assert isinstance(on[0], Column), "on should be Column or list of Column"
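The `elif` makes the three join paths mutually exclusive; previously a join with `on=None` fell through into the `isinstance(on[0], basestring)` check and failed. A hedged sketch of the three call shapes, assuming two DataFrames `df` and `df2` that share a `name` column:

    df.join(df2)                                    # no condition: cross join
    df.join(df2, ['name'])                          # join on a column name (string)
    df.join(df2, df.name == df2.name, 'outer')      # join on a Column expression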
      @@ -628,25 +653,25 @@ def describe(self, *cols):
               guarantee about the backward compatibility of the schema of the resulting DataFrame.
       
               >>> df.describe().show()
      -        +-------+---+
      -        |summary|age|
      -        +-------+---+
      -        |  count|  2|
      -        |   mean|3.5|
      -        | stddev|1.5|
      -        |    min|  2|
      -        |    max|  5|
      -        +-------+---+
      +        +-------+------------------+
      +        |summary|               age|
      +        +-------+------------------+
      +        |  count|                 2|
      +        |   mean|               3.5|
      +        | stddev|2.1213203435596424|
      +        |    min|                 2|
      +        |    max|                 5|
      +        +-------+------------------+
               >>> df.describe(['age', 'name']).show()
      -        +-------+---+-----+
      -        |summary|age| name|
      -        +-------+---+-----+
      -        |  count|  2|    2|
      -        |   mean|3.5| null|
      -        | stddev|1.5| null|
      -        |    min|  2|Alice|
      -        |    max|  5|  Bob|
      -        +-------+---+-----+
      +        +-------+------------------+-----+
      +        |summary|               age| name|
      +        +-------+------------------+-----+
      +        |  count|                 2|    2|
      +        |   mean|               3.5| null|
      +        | stddev|2.1213203435596424| null|
      +        |    min|                 2|Alice|
      +        |    max|                 5|  Bob|
      +        +-------+------------------+-----+
               """
               if len(cols) == 1 and isinstance(cols[0], list):
                   cols = cols[0]
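The expected output changes because `stddev` is now the sample (Bessel-corrected) standard deviation rather than the population value 1.5. A quick pure-Python check on the two ages in the doctest fixture:

    import math
    ages = [2, 5]
    mean = sum(ages) / float(len(ages))                                        # 3.5
    sample_sd = math.sqrt(sum((a - mean) ** 2 for a in ages) / (len(ages) - 1))
    print(sample_sd)                                                           # 2.1213203435596424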
      @@ -697,8 +722,6 @@ def __getitem__(self, item):
               [Row(age=5, name=u'Bob')]
               """
               if isinstance(item, basestring):
      -            if item not in self.columns:
      -                raise IndexError("no such column: %s" % item)
                   jc = self._jdf.apply(item)
                   return Column(jc)
               elif isinstance(item, Column):
      @@ -800,11 +823,11 @@ def groupBy(self, *cols):
                   Each element should be a column name (string) or an expression (:class:`Column`).
       
               >>> df.groupBy().avg().collect()
      -        [Row(AVG(age)=3.5)]
      +        [Row(avg(age)=3.5)]
               >>> df.groupBy('name').agg({'age': 'mean'}).collect()
      -        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
      +        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
               >>> df.groupBy(df.name).avg().collect()
      -        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
      +        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
               >>> df.groupBy(['name', df.age]).count().collect()
               [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
               """
      @@ -862,10 +885,10 @@ def agg(self, *exprs):
               (shorthand for ``df.groupBy.agg()``).
       
               >>> df.agg({"age": "max"}).collect()
      -        [Row(MAX(age)=5)]
      +        [Row(max(age)=5)]
               >>> from pyspark.sql import functions as F
               >>> df.agg(F.min(df.age)).collect()
      -        [Row(MIN(age)=2)]
      +        [Row(min(age)=2)]
               """
               return self.groupBy().agg(*exprs)
       
      @@ -1138,7 +1161,7 @@ def crosstab(self, col1, col2):
               non-zero pair frequencies will be returned.
               The first column of each row will be the distinct values of `col1` and the column names
               will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
      -        Pairs that have no occurrences will have `null` as their counts.
      +        Pairs that have no occurrences will have zero as their counts.
               :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
       
               :param col1: The name of the first column. Distinct items will make the first item of
      @@ -1179,7 +1202,9 @@ def freqItems(self, cols, support=None):
           @ignore_unicode_prefix
           @since(1.3)
           def withColumn(self, colName, col):
      -        """Returns a new :class:`DataFrame` by adding a column.
      +        """
      +        Returns a new :class:`DataFrame` by adding a column or replacing the
      +        existing column that has the same name.
       
               :param colName: string, name of the new column.
               :param col: a :class:`Column` expression for the new column.
      @@ -1187,7 +1212,8 @@ def withColumn(self, colName, col):
               >>> df.withColumn('age2', df.age + 2).collect()
               [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
               """
      -        return self.select('*', col.alias(colName))
      +        assert isinstance(col, Column), "col should be Column"
      +        return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
       
           @ignore_unicode_prefix
           @since(1.3)
      @@ -1200,10 +1226,7 @@ def withColumnRenamed(self, existing, new):
               >>> df.withColumnRenamed('age', 'age2').collect()
               [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
               """
      -        cols = [Column(_to_java_column(c)).alias(new)
      -                if c == existing else c
      -                for c in self.columns]
      -        return self.select(*cols)
      +        return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)
       
           @since(1.4)
           @ignore_unicode_prefix
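Both methods now delegate to the JVM-side DataFrame, and per the updated docstring `withColumn` replaces an existing column with the same name instead of appending a duplicate. A short sketch assuming the doctest `df`:

    df.withColumn('age2', df.age + 2)         # adds a new column, as before
    df.withColumn('age', df.age + 10)         # now *replaces* the existing 'age' column
    df.withColumnRenamed('age', 'age_years')  # renamed on the JVM DataFrame directly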
      @@ -1322,6 +1345,11 @@ def freqItems(self, cols, support=None):
       
           freqItems.__doc__ = DataFrame.freqItems.__doc__
       
      +    def sampleBy(self, col, fractions, seed=None):
      +        return self.df.sampleBy(col, fractions, seed)
      +
      +    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
      +
       
       def _test():
           import doctest
      diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
      index cfa87aeea193..26b8662718a6 100644
      --- a/python/pyspark/sql/functions.py
      +++ b/python/pyspark/sql/functions.py
      @@ -24,32 +24,13 @@
       if sys.version < "3":
           from itertools import imap as map
       
      -from pyspark import SparkContext
      +from pyspark import since, SparkContext
       from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
       from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
      -from pyspark.sql import since
       from pyspark.sql.types import StringType
       from pyspark.sql.column import Column, _to_java_column, _to_seq
       
       
      -__all__ = [
      -    'array',
      -    'approxCountDistinct',
      -    'bin',
      -    'coalesce',
      -    'countDistinct',
      -    'explode',
      -    'monotonicallyIncreasingId',
      -    'rand',
      -    'randn',
      -    'sparkPartitionId',
      -    'struct',
      -    'udf',
      -    'when']
      -
      -__all__ += ['lag', 'lead', 'ntile']
      -
      -
       def _create_function(name, doc=""):
           """ Create a function for aggregator by name"""
           def _(col):
      @@ -191,30 +172,6 @@ def _():
       for _name, _doc in _window_functions.items():
           globals()[_name] = since(1.4)(_create_window_function(_name, _doc))
       del _name, _doc
      -__all__ += _functions.keys()
      -__all__ += _functions_1_4.keys()
      -__all__ += _binary_mathfunctions.keys()
      -__all__ += _window_functions.keys()
      -__all__.sort()
      -
      -
      -@since(1.4)
      -def array(*cols):
      -    """Creates a new array column.
      -
      -    :param cols: list of column names (string) or list of :class:`Column` expressions that have
      -        the same data type.
      -
      -    >>> df.select(array('age', 'age').alias("arr")).collect()
      -    [Row(arr=[2, 2]), Row(arr=[5, 5])]
      -    >>> df.select(array([df.age, df.age]).alias("arr")).collect()
      -    [Row(arr=[2, 2]), Row(arr=[5, 5])]
      -    """
      -    sc = SparkContext._active_spark_context
      -    if len(cols) == 1 and isinstance(cols[0], (list, set)):
      -        cols = cols[0]
      -    jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
      -    return Column(jc)
       
       
       @since(1.3)
      @@ -232,19 +189,6 @@ def approxCountDistinct(col, rsd=None):
           return Column(jc)
       
       
      -@ignore_unicode_prefix
      -@since(1.5)
      -def bin(col):
      -    """Returns the string representation of the binary value of the given column.
      -
      -    >>> df.select(bin(df.age).alias('c')).collect()
      -    [Row(c=u'10'), Row(c=u'101')]
      -    """
      -    sc = SparkContext._active_spark_context
      -    jc = sc._jvm.functions.bin(_to_java_column(col))
      -    return Column(jc)
      -
      -
       @since(1.4)
       def coalesce(*cols):
           """Returns the first column that is not null.
      @@ -261,7 +205,7 @@ def coalesce(*cols):
       
           >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
           +-------------+
      -    |Coalesce(a,b)|
      +    |coalesce(a,b)|
           +-------------+
           |         null|
           |            1|
      @@ -270,7 +214,7 @@ def coalesce(*cols):
       
           >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
           +----+----+---------------+
      -    |   a|   b|Coalesce(a,0.0)|
      +    |   a|   b|coalesce(a,0.0)|
           +----+----+---------------+
           |null|null|            0.0|
           |   1|null|            1.0|
      @@ -297,27 +241,6 @@ def countDistinct(col, *cols):
           return Column(jc)
       
       
      -@since(1.4)
      -def explode(col):
      -    """Returns a new row for each element in the given array or map.
      -
      -    >>> from pyspark.sql import Row
      -    >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
      -    >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
      -    [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
      -
      -    >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
      -    +---+-----+
      -    |key|value|
      -    +---+-----+
      -    |  a|    b|
      -    +---+-----+
      -    """
      -    sc = SparkContext._active_spark_context
      -    jc = sc._jvm.functions.explode(_to_java_column(col))
      -    return Column(jc)
      -
      -
       @since(1.4)
       def monotonicallyIncreasingId():
           """A column that generates monotonically increasing 64-bit integers.
      @@ -344,7 +267,7 @@ def rand(seed=None):
           """Generates a random column with i.i.d. samples from U[0.0, 1.0].
           """
           sc = SparkContext._active_spark_context
      -    if seed:
      +    if seed is not None:
               jc = sc._jvm.functions.rand(seed)
           else:
               jc = sc._jvm.functions.rand()
      @@ -356,13 +279,62 @@ def randn(seed=None):
           """Generates a column with i.i.d. samples from the standard normal distribution.
           """
           sc = SparkContext._active_spark_context
      -    if seed:
      +    if seed is not None:
               jc = sc._jvm.functions.randn(seed)
           else:
               jc = sc._jvm.functions.randn()
           return Column(jc)
       
       
      +@since(1.5)
      +def round(col, scale=0):
      +    """
       +    Round the value of `col` to `scale` decimal places if `scale` >= 0,
       +    or to the integral part when `scale` < 0.
      +
      +    >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect()
      +    [Row(r=2.5)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.round(_to_java_column(col), scale))
      +
      +
      +@since(1.5)
      +def shiftLeft(col, numBits):
      +    """Shift the the given value numBits left.
      +
      +    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
      +    [Row(r=42)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits))
      +
      +
      +@since(1.5)
      +def shiftRight(col, numBits):
      +    """Shift the the given value numBits right.
      +
      +    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
      +    [Row(r=21)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
      +    return Column(jc)
      +
      +
      +@since(1.5)
      +def shiftRightUnsigned(col, numBits):
      +    """Unsigned shift the the given value numBits right.
      +
      +    >>> df = sqlContext.createDataFrame([(-42,)], ['a'])
      +    >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
      +    [Row(r=9223372036854775787)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
      +    return Column(jc)
      +
      +
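The doctest values for the shift functions follow directly from two's-complement arithmetic on 64-bit longs; a quick cross-check in plain Python (no Spark needed), treating -42 as an unsigned 64-bit value for the unsigned shift:

    print(21 << 1)                           # 42, matches shiftLeft('a', 1)
    print(42 >> 1)                           # 21, matches shiftRight('a', 1)
    print(((-42) & (2 ** 64 - 1)) >> 1)      # 9223372036854775787, matches shiftRightUnsigned('a', 1)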
       @since(1.4)
       def sparkPartitionId():
           """A column for partition ID of the Spark task.
      @@ -376,13 +348,23 @@ def sparkPartitionId():
           return Column(sc._jvm.functions.sparkPartitionId())
       
       
      +@since(1.5)
      +def expr(str):
      +    """Parses the expression string into the column that it represents
      +
      +    >>> df.select(expr("length(name)")).collect()
      +    [Row('length(name)=5), Row('length(name)=3)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.expr(str))
      +
      +
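`expr` is a SQL-expression escape hatch that returns a regular Column, so it composes with `select`, `filter`, and the other DataFrame operators. A small sketch assuming the doctest `df`:

    from pyspark.sql.functions import expr

    df.select(expr('length(name)').alias('name_len')).collect()
    df.filter(expr('age > 3')).collect()   # expr() yields a Column, usable anywhere a Column is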
       @ignore_unicode_prefix
       @since(1.4)
       def struct(*cols):
           """Creates a new struct column.
       
           :param cols: list of column names (string) or list of :class:`Column` expressions
      -        that are named or aliased.
       
           >>> df.select(struct('age', 'name').alias("struct")).collect()
           [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
      @@ -396,6 +378,38 @@ def struct(*cols):
           return Column(jc)
       
       
      +@since(1.5)
      +def greatest(*cols):
      +    """
      +    Returns the greatest value of the list of column names, skipping null values.
      +    This function takes at least 2 parameters. It will return null iff all parameters are null.
      +
      +    >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
      +    >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
      +    [Row(greatest=4)]
      +    """
      +    if len(cols) < 2:
      +        raise ValueError("greatest should take at least two columns")
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column)))
      +
      +
      +@since(1.5)
      +def least(*cols):
      +    """
      +    Returns the least value of the list of column names, skipping null values.
      +    This function takes at least 2 parameters. It will return null iff all parameters are null.
      +
      +    >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
      +    >>> df.select(least(df.a, df.b, df.c).alias("least")).collect()
      +    [Row(least=1)]
      +    """
      +    if len(cols) < 2:
      +        raise ValueError("least should take at least two columns")
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.least(_to_seq(sc, cols, _to_java_column)))
      +
      +
       @since(1.4)
       def when(condition, value):
           """Evaluates a list of conditions and returns one of multiple possible result expressions.
      @@ -438,6 +452,46 @@ def log(arg1, arg2=None):
           return Column(jc)
       
       
      +@since(1.5)
      +def log2(col):
      +    """Returns the base-2 logarithm of the argument.
      +
      +    >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect()
      +    [Row(log2=2.0)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.log2(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def conv(col, fromBase, toBase):
      +    """
      +    Convert a number in a string column from one base to another.
      +
      +    >>> df = sqlContext.createDataFrame([("010101",)], ['n'])
      +    >>> df.select(conv(df.n, 2, 16).alias('hex')).collect()
      +    [Row(hex=u'15')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase))
      +
      +
      +@since(1.5)
      +def factorial(col):
      +    """
      +    Computes the factorial of the given value.
      +
      +    >>> df = sqlContext.createDataFrame([(5,)], ['n'])
      +    >>> df.select(factorial(df.n).alias('f')).collect()
      +    [Row(f=120)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.factorial(_to_java_column(col)))
      +
      +
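The `conv` and `factorial` doctest values can be verified with Python's built-in base handling and math module, outside Spark: "010101" in base 2 is 21, which is "15" in base 16, and 5! is 120.

    n = int('010101', 2)        # 21
    print(format(n, 'x'))       # '15'  -> matches conv(df.n, 2, 16)
    import math
    print(math.factorial(5))    # 120   -> matches factorial(df.n)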
      +# ---------------  Window functions ------------------------
      +
       @since(1.4)
       def lag(col, count=1, default=None):
           """
      @@ -475,9 +529,10 @@ def lead(col, count=1, default=None):
       @since(1.4)
       def ntile(n):
           """
      -    Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in
      -    a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will
      -    get 2, the third row will get 3, and the fourth row will get 1...
      +    Window function: returns the ntile group id (from 1 to `n` inclusive)
      +    in an ordered window partition. For example, if `n` is 4, the first
      +    quarter of the rows will get value 1, the second quarter will get 2,
      +    the third quarter will get 3, and the last quarter will get 4.
       
           This is equivalent to the NTILE function in SQL.
       
      @@ -487,54 +542,917 @@ def ntile(n):
           return Column(sc._jvm.functions.ntile(int(n)))
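ntile, like lag and lead, is only meaningful over a window specification. A hedged sketch, assuming the `pyspark.sql.window.Window` API (present since 1.4, not part of this patch) and a DataFrame with `name`/`age` columns:

    from pyspark.sql.window import Window
    from pyspark.sql import functions as F

    w = Window.partitionBy('name').orderBy('age')
    quartiles = df.select('name', 'age', F.ntile(4).over(w).alias('quartile'))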
       
       
      -class UserDefinedFunction(object):
      +# ---------------------- Date/Timestamp functions ------------------------------
      +
      +@since(1.5)
      +def current_date():
           """
      -    User defined function in Python
      +    Returns the current date as a date column.
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.current_date())
       
      -    .. versionadded:: 1.3
      +
      +def current_timestamp():
           """
      -    def __init__(self, func, returnType):
      -        self.func = func
      -        self.returnType = returnType
      -        self._broadcast = None
      -        self._judf = self._create_judf()
      +    Returns the current timestamp as a timestamp column.
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.current_timestamp())
       
      -    def _create_judf(self):
      -        f = self.func  # put it in closure `func`
      -        func = lambda _, it: map(lambda x: f(*x), it)
      -        ser = AutoBatchedSerializer(PickleSerializer())
      -        command = (func, None, ser, ser)
      -        sc = SparkContext._active_spark_context
      -        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
      -        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
      -        jdt = ssql_ctx.parseDataType(self.returnType.json())
      -        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
      -        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
      -                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
      -                                                 sc._javaAccumulator, jdt)
      -        return judf
       
      -    def __del__(self):
      -        if self._broadcast is not None:
      -            self._broadcast.unpersist()
      -            self._broadcast = None
      +@ignore_unicode_prefix
      +@since(1.5)
      +def date_format(date, format):
      +    """
      +    Converts a date/timestamp/string to a value of string in the format specified by the date
      +    format given by the second argument.
       
      -    def __call__(self, *cols):
      -        sc = SparkContext._active_spark_context
      -        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
      -        return Column(jc)
      +    A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All
      +    pattern letters of the Java class `java.text.SimpleDateFormat` can be used.
       
       +    NOTE: Whenever possible, use specialized functions like `year`, as they benefit from a
       +    specialized implementation.
       
      -@since(1.3)
      -def udf(f, returnType=StringType()):
      -    """Creates a :class:`Column` expression representing a user defined function (UDF).
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
      +    >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect()
      +    [Row(date=u'04/08/2015')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.date_format(_to_java_column(date), format))
       
      -    >>> from pyspark.sql.types import IntegerType
      -    >>> slen = udf(lambda s: len(s), IntegerType())
      -    >>> df.select(slen(df.name).alias('slen')).collect()
      -    [Row(slen=5), Row(slen=3)]
      +
      +@since(1.5)
      +def year(col):
           """
      -    return UserDefinedFunction(f, returnType)
      +    Extract the year of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
      +    >>> df.select(year('a').alias('year')).collect()
      +    [Row(year=2015)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.year(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def quarter(col):
      +    """
      +    Extract the quarter of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
      +    >>> df.select(quarter('a').alias('quarter')).collect()
      +    [Row(quarter=2)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.quarter(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def month(col):
      +    """
      +    Extract the month of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
      +    >>> df.select(month('a').alias('month')).collect()
      +    [Row(month=4)]
      +   """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.month(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def dayofmonth(col):
      +    """
      +    Extract the day of the month of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
      +    >>> df.select(dayofmonth('a').alias('day')).collect()
      +    [Row(day=8)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def dayofyear(col):
      +    """
      +    Extract the day of the year of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
      +    >>> df.select(dayofyear('a').alias('day')).collect()
      +    [Row(day=98)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def hour(col):
      +    """
      +    Extract the hours of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
      +    >>> df.select(hour('a').alias('hour')).collect()
      +    [Row(hour=13)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.hour(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def minute(col):
      +    """
      +    Extract the minutes of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
      +    >>> df.select(minute('a').alias('minute')).collect()
      +    [Row(minute=8)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.minute(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def second(col):
      +    """
      +    Extract the seconds of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
      +    >>> df.select(second('a').alias('second')).collect()
      +    [Row(second=15)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.second(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def weekofyear(col):
      +    """
      +    Extract the week number of a given date as integer.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
      +    >>> df.select(weekofyear(df.a).alias('week')).collect()
      +    [Row(week=15)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def date_add(start, days):
      +    """
      +    Returns the date that is `days` days after `start`
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
      +    >>> df.select(date_add(df.d, 1).alias('d')).collect()
      +    [Row(d=datetime.date(2015, 4, 9))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
      +
      +
      +@since(1.5)
      +def date_sub(start, days):
      +    """
      +    Returns the date that is `days` days before `start`
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
      +    >>> df.select(date_sub(df.d, 1).alias('d')).collect()
      +    [Row(d=datetime.date(2015, 4, 7))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
      +
      +
      +@since(1.5)
      +def datediff(end, start):
      +    """
      +    Returns the number of days from `start` to `end`.
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
      +    >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()
      +    [Row(diff=32)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.datediff(_to_java_column(end), _to_java_column(start)))
      +
      +
      +@since(1.5)
      +def add_months(start, months):
      +    """
      +    Returns the date that is `months` months after `start`
      +
      +    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
      +    >>> df.select(add_months(df.d, 1).alias('d')).collect()
      +    [Row(d=datetime.date(2015, 5, 8))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
      +
      +
      +@since(1.5)
      +def months_between(date1, date2):
      +    """
      +    Returns the number of months between date1 and date2.
      +
      +    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
      +    >>> df.select(months_between(df.t, df.d).alias('months')).collect()
      +    [Row(months=3.9495967...)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))
      +
      +
      +@since(1.5)
      +def to_date(col):
      +    """
      +    Converts the column of StringType or TimestampType into DateType.
      +
      +    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
      +    >>> df.select(to_date(df.t).alias('date')).collect()
      +    [Row(date=datetime.date(1997, 2, 28))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.to_date(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def trunc(date, format):
      +    """
      +    Returns date truncated to the unit specified by the format.
      +
      +    :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
      +
      +    >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d'])
      +    >>> df.select(trunc(df.d, 'year').alias('year')).collect()
      +    [Row(year=datetime.date(1997, 1, 1))]
      +    >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
      +    [Row(month=datetime.date(1997, 2, 1))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
      +
      +
      +@since(1.5)
      +def next_day(date, dayOfWeek):
      +    """
      +    Returns the first date which is later than the value of the date column.
      +
      +    Day of the week parameter is case insensitive, and accepts:
      +        "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
      +
      +    >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d'])
      +    >>> df.select(next_day(df.d, 'Sun').alias('date')).collect()
      +    [Row(date=datetime.date(2015, 8, 2))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.next_day(_to_java_column(date), dayOfWeek))
      +
      +
      +@since(1.5)
      +def last_day(date):
      +    """
      +    Returns the last day of the month which the given date belongs to.
      +
      +    >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d'])
      +    >>> df.select(last_day(df.d).alias('date')).collect()
      +    [Row(date=datetime.date(1997, 2, 28))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.last_day(_to_java_column(date)))
      +
      +
      +@since(1.5)
      +def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"):
      +    """
      +    Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
      +    representing the timestamp of that moment in the current system time zone in the given
      +    format.
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format))
      +
      +
      +@since(1.5)
      +def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'):
      +    """
      +    Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default)
      +    to Unix time stamp (in seconds), using the default timezone and the default
       +    locale, and returns null on failure.
      +
      +    if `timestamp` is None, then it returns current timestamp.
      +    """
      +    sc = SparkContext._active_spark_context
      +    if timestamp is None:
      +        return Column(sc._jvm.functions.unix_timestamp())
      +    return Column(sc._jvm.functions.unix_timestamp(_to_java_column(timestamp), format))
      +
      +
      +@since(1.5)
      +def from_utc_timestamp(timestamp, tz):
      +    """
      +    Assumes given timestamp is UTC and converts to given timezone.
      +
      +    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
      +    >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect()
      +    [Row(t=datetime.datetime(1997, 2, 28, 2, 30))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz))
      +
      +
      +@since(1.5)
      +def to_utc_timestamp(timestamp, tz):
      +    """
      +    Assumes given timestamp is in given timezone and converts to UTC.
      +
      +    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
      +    >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect()
      +    [Row(t=datetime.datetime(1997, 2, 28, 18, 30))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
      +
      +
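The new date helpers compose with each other. A hedged end-to-end sketch (assumes `sqlContext` as in the doctests) that builds one small DataFrame and applies several of the functions added above:

    from pyspark.sql import functions as F

    d = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    d.select(
        F.to_date(d.d).alias('date'),
        F.year(d.d).alias('year'),
        F.weekofyear(d.d).alias('week'),
        F.date_add(d.d, 30).alias('plus_30_days'),
        F.last_day(d.d).alias('month_end'),
    ).show(truncate=False)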
      +# ---------------------------- misc functions ----------------------------------
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def crc32(col):
      +    """
       +    Calculates the cyclic redundancy check value (CRC32) of a binary column and
      +    returns the value as a bigint.
      +
      +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
      +    [Row(crc32=2743272264)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.crc32(_to_java_column(col)))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def md5(col):
      +    """Calculates the MD5 digest and returns the value as a 32 character hex string.
      +
      +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
      +    [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.md5(_to_java_column(col))
      +    return Column(jc)
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def sha1(col):
      +    """Returns the hex string result of SHA-1.
      +
      +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
      +    [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.sha1(_to_java_column(col))
      +    return Column(jc)
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def sha2(col, numBits):
      +    """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
      +    and SHA-512). The numBits indicates the desired bit length of the result, which must have a
      +    value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
      +
      +    >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
      +    >>> digests[0]
      +    Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
      +    >>> digests[1]
      +    Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.sha2(_to_java_column(col), numBits)
      +    return Column(jc)
      +
      +
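Since these wrap the standard MD5, SHA-1/SHA-2 and CRC-32 algorithms, the doctest digests for the input 'ABC' can be cross-checked against Python's own hashlib and zlib without Spark:

    import hashlib, zlib
    print(hashlib.md5(b'ABC').hexdigest())     # 902fbdd2b1df0c4f70b4a5d23525e932
    print(hashlib.sha1(b'ABC').hexdigest())    # 3c01bdbb26f358bab27f267924aa2c9a03fcfdb8
    print(zlib.crc32(b'ABC') & 0xffffffff)     # 2743272264, matching the crc32 doctest above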
      +# ---------------------- String/Binary functions ------------------------------
      +
      +_string_functions = {
      +    'ascii': 'Computes the numeric value of the first character of the string column.',
      +    'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
      +    'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
      +    'initcap': 'Returns a new string column by converting the first letter of each word to ' +
      +               'uppercase. Words are delimited by whitespace.',
      +    'lower': 'Converts a string column to lower case.',
      +    'upper': 'Converts a string column to upper case.',
      +    'reverse': 'Reverses the string column and returns it as a new string column.',
       +    'ltrim': 'Trim the spaces from left end for the specified string value.',
      +    'rtrim': 'Trim the spaces from right end for the specified string value.',
      +    'trim': 'Trim the spaces from both ends for the specified string column.',
      +}
      +
      +
      +for _name, _doc in _string_functions.items():
      +    globals()[_name] = since(1.5)(_create_function(_name, _doc))
      +del _name, _doc
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def concat(*cols):
      +    """
      +    Concatenates multiple input string columns together into a single string column.
      +
      +    >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd'])
      +    >>> df.select(concat(df.s, df.d).alias('s')).collect()
      +    [Row(s=u'abcd123')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def concat_ws(sep, *cols):
      +    """
      +    Concatenates multiple input string columns together into a single string column,
      +    using the given separator.
      +
      +    >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd'])
      +    >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
      +    [Row(s=u'abcd-123')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column)))
      +
      +
      +@since(1.5)
      +def decode(col, charset):
      +    """
      +    Computes the first argument into a string from a binary using the provided character set
      +    (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.decode(_to_java_column(col), charset))
      +
      +
      +@since(1.5)
      +def encode(col, charset):
      +    """
      +    Computes the first argument into a binary from a string using the provided character set
      +    (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.encode(_to_java_column(col), charset))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def format_number(col, d):
      +    """
       +    Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
      +    and returns the result as a string.
      +
      +    :param col: the column name of the numeric value to be formatted
       +    :param d: the number of decimal places to round to
      +
      +    >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
      +    [Row(v=u'5.0000')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def format_string(format, *cols):
      +    """
      +    Formats the arguments in printf-style and returns the result as a string column.
      +
       +    :param format: the printf-style format string
       +    :param cols: column names or :class:`Column` expressions whose values fill the format
      +
      +    >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b'])
      +    >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect()
      +    [Row(v=u'5 hello')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column)))
      +
      +
      +@since(1.5)
      +def instr(str, substr):
      +    """
      +    Locate the position of the first occurrence of substr column in the given string.
      +    Returns null if either of the arguments are null.
      +
       +    NOTE: The position is not zero-based, but 1-based. Returns 0 if substr
      +    could not be found in str.
      +
      +    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
      +    >>> df.select(instr(df.s, 'b').alias('s')).collect()
      +    [Row(s=2)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.instr(_to_java_column(str), substr))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def substring(str, pos, len):
      +    """
       +    Returns the substring that starts at `pos` and is of length `len` when str is
       +    String type, or the slice of the byte array that starts at `pos` and is of
       +    length `len` when str is Binary type.
      +
      +    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
      +    >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
      +    [Row(s=u'ab')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def substring_index(str, delim, count):
      +    """
      +    Returns the substring from string str before count occurrences of the delimiter delim.
       +    If count is positive, everything to the left of the final delimiter (counting from the
       +    left) is returned. If count is negative, everything to the right of the final delimiter
       +    (counting from the right) is returned. substring_index performs a case-sensitive match
       +    when searching for delim.
      +
      +    >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s'])
      +    >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
      +    [Row(s=u'a.b')]
      +    >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
      +    [Row(s=u'b.c.d')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def levenshtein(left, right):
      +    """Computes the Levenshtein distance of the two given strings.
      +
      +    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
      +    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
      +    [Row(d=3)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
      +    return Column(jc)
      +
      +
      +@since(1.5)
      +def locate(substr, str, pos=0):
      +    """
      +    Locate the position of the first occurrence of substr in a string column, after position pos.
      +
       +    NOTE: The position is not zero-based, but 1-based. Returns 0 if substr
      +    could not be found in str.
      +
      +    :param substr: a string
      +    :param str: a Column of StringType
      +    :param pos: start position (zero based)
      +
      +    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
      +    >>> df.select(locate('b', df.s, 1).alias('s')).collect()
      +    [Row(s=2)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def lpad(col, len, pad):
      +    """
      +    Left-pad the string column to width `len` with `pad`.
      +
      +    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
      +    >>> df.select(lpad(df.s, 6, '#').alias('s')).collect()
      +    [Row(s=u'##abcd')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def rpad(col, len, pad):
      +    """
      +    Right-pad the string column to width `len` with `pad`.
      +
      +    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
      +    >>> df.select(rpad(df.s, 6, '#').alias('s')).collect()
      +    [Row(s=u'abcd##')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def repeat(col, n):
      +    """
      +    Repeats a string column n times, and returns it as a new string column.
      +
      +    >>> df = sqlContext.createDataFrame([('ab',)], ['s',])
      +    >>> df.select(repeat(df.s, 3).alias('s')).collect()
      +    [Row(s=u'ababab')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.repeat(_to_java_column(col), n))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def split(str, pattern):
      +    """
      +    Splits str around pattern (pattern is a regular expression).
      +
       +    NOTE: pattern is a string representing the regular expression.
      +
      +    >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',])
      +    >>> df.select(split(df.s, '[0-9]+').alias('s')).collect()
      +    [Row(s=[u'ab', u'cd'])]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.split(_to_java_column(str), pattern))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def regexp_extract(str, pattern, idx):
      +    """Extract a specific(idx) group identified by a java regex, from the specified string column.
      +
      +    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
      +    >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
      +    [Row(d=u'100')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
      +    return Column(jc)
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def regexp_replace(str, pattern, replacement):
      +    """Replace all substrings of the specified string value that match regexp with rep.
      +
      +    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
      +    >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect()
      +    [Row(d=u'-----')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
      +    return Column(jc)
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def initcap(col):
      +    """Translate the first letter of each word to upper case in the sentence.
      +
      +    >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
      +    [Row(v=u'Ab Cd')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.initcap(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +@ignore_unicode_prefix
      +def soundex(col):
      +    """
      +    Returns the SoundEx encoding for a string
      +
      +    >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name'])
      +    >>> df.select(soundex(df.name).alias("soundex")).collect()
      +    [Row(soundex=u'P362'), Row(soundex=u'U612')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.soundex(_to_java_column(col)))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def bin(col):
      +    """Returns the string representation of the binary value of the given column.
      +
      +    >>> df.select(bin(df.age).alias('c')).collect()
      +    [Row(c=u'10'), Row(c=u'101')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.bin(_to_java_column(col))
      +    return Column(jc)
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def hex(col):
      +    """Computes hex value of the given column, which could be StringType,
      +    BinaryType, IntegerType or LongType.
      +
      +    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
      +    [Row(hex(a)=u'414243', hex(b)=u'3')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.hex(_to_java_column(col))
      +    return Column(jc)
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def unhex(col):
      +    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
+    and converts it to the byte representation of the number.
      +
      +    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
      +    [Row(unhex(a)=bytearray(b'ABC'))]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.unhex(_to_java_column(col)))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def length(col):
      +    """Calculates the length of a string or binary expression.
      +
      +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
      +    [Row(length=3)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.length(_to_java_column(col)))
      +
      +
      +@ignore_unicode_prefix
      +@since(1.5)
      +def translate(srcCol, matching, replace):
      +    """A function translate any character in the `srcCol` by a character in `matching`.
      +    The characters in `replace` is corresponding to the characters in `matching`.
      +    The translate will happen when any character in the string matching with the character
      +    in the `matching`.
      +
      +    >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\
      +    .alias('r')).collect()
      +    [Row(r=u'1a2s3ae')]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace))
      +
      +
      +# ---------------------- Collection functions ------------------------------
      +
      +@since(1.4)
      +def array(*cols):
      +    """Creates a new array column.
      +
      +    :param cols: list of column names (string) or list of :class:`Column` expressions that have
      +        the same data type.
      +
      +    >>> df.select(array('age', 'age').alias("arr")).collect()
      +    [Row(arr=[2, 2]), Row(arr=[5, 5])]
      +    >>> df.select(array([df.age, df.age]).alias("arr")).collect()
      +    [Row(arr=[2, 2]), Row(arr=[5, 5])]
      +    """
      +    sc = SparkContext._active_spark_context
      +    if len(cols) == 1 and isinstance(cols[0], (list, set)):
      +        cols = cols[0]
      +    jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
      +    return Column(jc)
      +
      +
      +@since(1.5)
      +def array_contains(col, value):
      +    """
      +    Collection function: returns True if the array contains the given value. The collection
      +    elements and value must be of the same type.
      +
      +    :param col: name of column containing array
      +    :param value: value to check for in array
      +
      +    >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
      +    >>> df.select(array_contains(df.data, "a")).collect()
      +    [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
      +
      +
      +@since(1.4)
      +def explode(col):
      +    """Returns a new row for each element in the given array or map.
      +
      +    >>> from pyspark.sql import Row
      +    >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
      +    >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
      +    [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
      +
      +    >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
      +    +---+-----+
      +    |key|value|
      +    +---+-----+
      +    |  a|    b|
      +    +---+-----+
      +    """
      +    sc = SparkContext._active_spark_context
      +    jc = sc._jvm.functions.explode(_to_java_column(col))
      +    return Column(jc)
      +
      +
      +@since(1.5)
      +def size(col):
      +    """
      +    Collection function: returns the length of the array or map stored in the column.
      +
      +    :param col: name of column or expression
      +
      +    >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
      +    >>> df.select(size(df.data)).collect()
      +    [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
      +    """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.size(_to_java_column(col)))
      +
      +
      +@since(1.5)
      +def sort_array(col, asc=True):
      +    """
      +    Collection function: sorts the input array for the given column in ascending order.
      +
      +    :param col: name of column or expression
      +
      +    >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
      +    >>> df.select(sort_array(df.data).alias('r')).collect()
      +    [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
      +    >>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
      +    [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
      +     """
      +    sc = SparkContext._active_spark_context
      +    return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
      +
      +
      +# ---------------------------- User Defined Function ----------------------------------
      +
      +class UserDefinedFunction(object):
      +    """
      +    User defined function in Python
      +
      +    .. versionadded:: 1.3
      +    """
      +    def __init__(self, func, returnType, name=None):
      +        self.func = func
      +        self.returnType = returnType
      +        self._broadcast = None
      +        self._judf = self._create_judf(name)
      +
      +    def _create_judf(self, name):
      +        f, returnType = self.func, self.returnType  # put them in closure `func`
      +        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
      +        ser = AutoBatchedSerializer(PickleSerializer())
      +        command = (func, None, ser, ser)
      +        sc = SparkContext._active_spark_context
      +        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
      +        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
      +        jdt = ssql_ctx.parseDataType(self.returnType.json())
      +        if name is None:
      +            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
      +        judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
      +                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
      +                                                 sc._javaAccumulator, jdt)
      +        return judf
      +
      +    def __del__(self):
      +        if self._broadcast is not None:
      +            self._broadcast.unpersist()
      +            self._broadcast = None
      +
      +    def __call__(self, *cols):
      +        sc = SparkContext._active_spark_context
      +        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
      +        return Column(jc)
      +
      +
      +@since(1.3)
      +def udf(f, returnType=StringType()):
      +    """Creates a :class:`Column` expression representing a user defined function (UDF).
      +
      +    >>> from pyspark.sql.types import IntegerType
      +    >>> slen = udf(lambda s: len(s), IntegerType())
      +    >>> df.select(slen(df.name).alias('slen')).collect()
      +    [Row(slen=5), Row(slen=3)]
      +    """
      +    return UserDefinedFunction(f, returnType)
      +
      +blacklist = ['map', 'since', 'ignore_unicode_prefix']
      +__all__ = [k for k, v in globals().items()
      +           if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist]
      +__all__.sort()
       
       
       def _test():
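The doctests above exercise each new function on its own. The sketch below is a minimal composite example, assuming a local SparkContext and the SQLContext entry point used throughout this patch; the app name and sample data are placeholders, not part of the change.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import regexp_extract, regexp_replace, split, explode

sc = SparkContext("local[2]", "functions-sketch")  # placeholder app name
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame([('100-200 ab12cd',)], ['s'])
df.select(
    regexp_extract(df.s, r'(\d+)-(\d+)', 1).alias('first'),  # u'100'
    regexp_replace(df.s, r'(\d+)', '#').alias('masked'),     # u'#-# ab#cd'
    split(df.s, '[0-9]+').alias('parts'),                    # array of the non-digit pieces
).show()

# explode() turns each element of an array column into its own row
sqlContext.createDataFrame([([1, 2, 3],)], ['xs']).select(explode('xs').alias('x')).show()

sc.stop()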
      diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
      index 5a37a673ee80..71c0bccc5eef 100644
      --- a/python/pyspark/sql/group.py
      +++ b/python/pyspark/sql/group.py
      @@ -15,8 +15,8 @@
       # limitations under the License.
       #
       
      +from pyspark import since
       from pyspark.rdd import ignore_unicode_prefix
      -from pyspark.sql import since
       from pyspark.sql.column import Column, _to_seq
       from pyspark.sql.dataframe import DataFrame
       from pyspark.sql.types import *
      @@ -75,11 +75,11 @@ def agg(self, *exprs):
       
               >>> gdf = df.groupBy(df.name)
               >>> gdf.agg({"*": "count"}).collect()
      -        [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
      +        [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]
       
               >>> from pyspark.sql import functions as F
               >>> gdf.agg(F.min(df.age)).collect()
      -        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
      +        [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
               """
               assert exprs, "exprs should not be empty"
               if len(exprs) == 1 and isinstance(exprs[0], dict):
      @@ -110,9 +110,9 @@ def mean(self, *cols):
               :param cols: list of column names (string). Non-numeric columns are ignored.
       
               >>> df.groupBy().mean('age').collect()
      -        [Row(AVG(age)=3.5)]
      +        [Row(avg(age)=3.5)]
               >>> df3.groupBy().mean('age', 'height').collect()
      -        [Row(AVG(age)=3.5, AVG(height)=82.5)]
      +        [Row(avg(age)=3.5, avg(height)=82.5)]
               """
       
           @df_varargs_api
      @@ -125,9 +125,9 @@ def avg(self, *cols):
               :param cols: list of column names (string). Non-numeric columns are ignored.
       
               >>> df.groupBy().avg('age').collect()
      -        [Row(AVG(age)=3.5)]
      +        [Row(avg(age)=3.5)]
               >>> df3.groupBy().avg('age', 'height').collect()
      -        [Row(AVG(age)=3.5, AVG(height)=82.5)]
      +        [Row(avg(age)=3.5, avg(height)=82.5)]
               """
       
           @df_varargs_api
      @@ -136,9 +136,9 @@ def max(self, *cols):
               """Computes the max value for each numeric columns for each group.
       
               >>> df.groupBy().max('age').collect()
      -        [Row(MAX(age)=5)]
      +        [Row(max(age)=5)]
               >>> df3.groupBy().max('age', 'height').collect()
      -        [Row(MAX(age)=5, MAX(height)=85)]
      +        [Row(max(age)=5, max(height)=85)]
               """
       
           @df_varargs_api
      @@ -149,9 +149,9 @@ def min(self, *cols):
               :param cols: list of column names (string). Non-numeric columns are ignored.
       
               >>> df.groupBy().min('age').collect()
      -        [Row(MIN(age)=2)]
      +        [Row(min(age)=2)]
               >>> df3.groupBy().min('age', 'height').collect()
      -        [Row(MIN(age)=2, MIN(height)=80)]
      +        [Row(min(age)=2, min(height)=80)]
               """
       
           @df_varargs_api
      @@ -162,9 +162,9 @@ def sum(self, *cols):
               :param cols: list of column names (string). Non-numeric columns are ignored.
       
               >>> df.groupBy().sum('age').collect()
      -        [Row(SUM(age)=7)]
      +        [Row(sum(age)=7)]
               >>> df3.groupBy().sum('age', 'height').collect()
      -        [Row(SUM(age)=7, SUM(height)=165)]
      +        [Row(sum(age)=7, sum(height)=165)]
               """
       
       
      diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
      index f036644acc96..f43d8bf646a9 100644
      --- a/python/pyspark/sql/readwriter.py
      +++ b/python/pyspark/sql/readwriter.py
      @@ -15,15 +15,30 @@
       # limitations under the License.
       #
       
      +import sys
      +
      +if sys.version >= '3':
      +    basestring = unicode = str
      +
       from py4j.java_gateway import JavaClass
       
      -from pyspark.sql import since
      +from pyspark import RDD, since
       from pyspark.sql.column import _to_seq
       from pyspark.sql.types import *
       
       __all__ = ["DataFrameReader", "DataFrameWriter"]
       
       
      +def to_str(value):
      +    """
+    A wrapper over str(), but converts bool values to lower case strings.
      +    """
      +    if isinstance(value, bool):
      +        return str(value).lower()
      +    else:
      +        return str(value)
      +
      +
       class DataFrameReader(object):
           """
           Interface used to load a :class:`DataFrame` from external storage systems
      @@ -73,12 +88,19 @@ def schema(self, schema):
               self._jreader = self._jreader.schema(jschema)
               return self
       
      +    @since(1.5)
      +    def option(self, key, value):
      +        """Adds an input option for the underlying data source.
      +        """
      +        self._jreader = self._jreader.option(key, to_str(value))
      +        return self
      +
           @since(1.4)
           def options(self, **options):
               """Adds input options for the underlying data source.
               """
               for k in options:
      -            self._jreader = self._jreader.option(k, options[k])
      +            self._jreader = self._jreader.option(k, to_str(options[k]))
               return self
       
           @since(1.4)
      @@ -90,7 +112,8 @@ def load(self, path=None, format=None, schema=None, **options):
               :param schema: optional :class:`StructType` for the input schema.
               :param options: all other string options
       
      -        >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned')
      +        >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
      +        ...     opt2=1, opt3='str')
               >>> df.dtypes
               [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
               """
      @@ -107,23 +130,33 @@ def load(self, path=None, format=None, schema=None, **options):
           @since(1.4)
           def json(self, path, schema=None):
               """
      -        Loads a JSON file (one object per line) and returns the result as
      -        a :class`DataFrame`.
      +        Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects
+        (one object per record) and returns the result as a :class:`DataFrame`.
       
               If the ``schema`` parameter is not specified, this function goes
               through the input once to determine the input schema.
       
      -        :param path: string, path to the JSON dataset.
+        :param path: string representing the path to the JSON dataset,
      +                     or RDD of Strings storing JSON objects.
               :param schema: an optional :class:`StructType` for the input schema.
       
      -        >>> df = sqlContext.read.json('python/test_support/sql/people.json')
      -        >>> df.dtypes
      +        >>> df1 = sqlContext.read.json('python/test_support/sql/people.json')
      +        >>> df1.dtypes
      +        [('age', 'bigint'), ('name', 'string')]
      +        >>> rdd = sc.textFile('python/test_support/sql/people.json')
      +        >>> df2 = sqlContext.read.json(rdd)
      +        >>> df2.dtypes
               [('age', 'bigint'), ('name', 'string')]
       
               """
               if schema is not None:
                   self.schema(schema)
      -        return self._df(self._jreader.json(path))
      +        if isinstance(path, basestring):
      +            return self._df(self._jreader.json(path))
      +        elif isinstance(path, RDD):
      +            return self._df(self._jreader.json(path._jrdd))
      +        else:
      +            raise TypeError("path can be only string or RDD")
       
           @since(1.4)
           def table(self, tableName):
      @@ -139,18 +172,32 @@ def table(self, tableName):
               return self._df(self._jreader.table(tableName))
       
           @since(1.4)
      -    def parquet(self, *path):
      +    def parquet(self, *paths):
               """Loads a Parquet file, returning the result as a :class:`DataFrame`.
       
               >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
               >>> df.dtypes
               [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
               """
      -        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))
      +        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths)))
      +
      +    @since(1.5)
      +    def orc(self, path):
      +        """
      +        Loads an ORC file, returning the result as a :class:`DataFrame`.
      +
+        .. note:: Currently ORC support is only available together with
+            :class:`HiveContext`.
      +
      +        >>> df = hiveContext.read.orc('python/test_support/sql/orc_partitioned')
      +        >>> df.dtypes
      +        [('a', 'bigint'), ('b', 'int'), ('c', 'int')]
      +        """
      +        return self._df(self._jreader.orc(path))
       
           @since(1.4)
           def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
      -             predicates=None, properties={}):
      +             predicates=None, properties=None):
               """
               Construct a :class:`DataFrame` representing the database table accessible
               via JDBC URL `url` named `table` and connection `properties`.
      @@ -176,6 +223,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar
                                  should be included.
               :return: a DataFrame
               """
      +        if properties is None:
      +            properties = dict()
               jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
               for k in properties:
                   jprop.setProperty(k, properties[k])
      @@ -218,7 +267,10 @@ def mode(self, saveMode):
       
               >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
               """
      -        self._jwrite = self._jwrite.mode(saveMode)
      +        # At the JVM side, the default value of mode is already set to "error".
      +        # So, if the given saveMode is None, we will not call JVM-side's mode method.
      +        if saveMode is not None:
      +            self._jwrite = self._jwrite.mode(saveMode)
               return self
       
           @since(1.4)
      @@ -232,6 +284,13 @@ def format(self, source):
               self._jwrite = self._jwrite.format(source)
               return self
       
      +    @since(1.5)
      +    def option(self, key, value):
      +        """Adds an output option for the underlying data source.
      +        """
      +        self._jwrite = self._jwrite.option(key, value)
      +        return self
      +
           @since(1.4)
           def options(self, **options):
               """Adds output options for the underlying data source.
      @@ -257,7 +316,7 @@ def partitionBy(self, *cols):
               return self
       
           @since(1.4)
      -    def save(self, path=None, format=None, mode="error", **options):
      +    def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
               """Saves the contents of the :class:`DataFrame` to a data source.
       
               The data source is specified by the ``format`` and a set of ``options``.
      @@ -272,11 +331,14 @@ def save(self, path=None, format=None, mode="error", **options):
                   * ``overwrite``: Overwrite existing data.
                   * ``ignore``: Silently ignore this operation if data already exists.
                   * ``error`` (default case): Throw an exception if data already exists.
      +        :param partitionBy: names of partitioning columns
               :param options: all other string options
       
               >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
               """
               self.mode(mode).options(**options)
      +        if partitionBy is not None:
      +            self.partitionBy(partitionBy)
               if format is not None:
                   self.format(format)
               if path is None:
      @@ -296,7 +358,7 @@ def insertInto(self, tableName, overwrite=False):
               self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
       
           @since(1.4)
      -    def saveAsTable(self, name, format=None, mode="error", **options):
      +    def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options):
               """Saves the content of the :class:`DataFrame` as the specified table.
       
               In the case the table already exists, behavior of this function depends on the
      @@ -312,15 +374,18 @@ def saveAsTable(self, name, format=None, mode="error", **options):
               :param name: the table name
               :param format: the format used to save
               :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
      +        :param partitionBy: names of partitioning columns
               :param options: all other string options
               """
               self.mode(mode).options(**options)
      +        if partitionBy is not None:
      +            self.partitionBy(partitionBy)
               if format is not None:
                   self.format(format)
               self._jwrite.saveAsTable(name)
       
           @since(1.4)
      -    def json(self, path, mode="error"):
      +    def json(self, path, mode=None):
               """Saves the content of the :class:`DataFrame` in JSON format at the specified path.
       
               :param path: the path in any Hadoop supported file system
      @@ -333,10 +398,10 @@ def json(self, path, mode="error"):
       
               >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
               """
      -        self._jwrite.mode(mode).json(path)
      +        self.mode(mode)._jwrite.json(path)
       
           @since(1.4)
      -    def parquet(self, path, mode="error"):
      +    def parquet(self, path, mode=None, partitionBy=None):
               """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
       
               :param path: the path in any Hadoop supported file system
      @@ -346,13 +411,40 @@ def parquet(self, path, mode="error"):
                   * ``overwrite``: Overwrite existing data.
                   * ``ignore``: Silently ignore this operation if data already exists.
                   * ``error`` (default case): Throw an exception if data already exists.
      +        :param partitionBy: names of partitioning columns
       
               >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
               """
      -        self._jwrite.mode(mode).parquet(path)
      +        self.mode(mode)
      +        if partitionBy is not None:
      +            self.partitionBy(partitionBy)
      +        self._jwrite.parquet(path)
      +
      +    def orc(self, path, mode=None, partitionBy=None):
      +        """Saves the content of the :class:`DataFrame` in ORC format at the specified path.
      +
+        .. note:: Currently ORC support is only available together with
+            :class:`HiveContext`.
      +
      +        :param path: the path in any Hadoop supported file system
      +        :param mode: specifies the behavior of the save operation when data already exists.
      +
      +            * ``append``: Append contents of this :class:`DataFrame` to existing data.
      +            * ``overwrite``: Overwrite existing data.
      +            * ``ignore``: Silently ignore this operation if data already exists.
      +            * ``error`` (default case): Throw an exception if data already exists.
      +        :param partitionBy: names of partitioning columns
      +
      +        >>> orc_df = hiveContext.read.orc('python/test_support/sql/orc_partitioned')
      +        >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
      +        """
      +        self.mode(mode)
      +        if partitionBy is not None:
      +            self.partitionBy(partitionBy)
      +        self._jwrite.orc(path)
       
           @since(1.4)
      -    def jdbc(self, url, table, mode="error", properties={}):
      +    def jdbc(self, url, table, mode=None, properties=None):
               """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
       
               .. note:: Don't create too many partitions in parallel on a large cluster;\
      @@ -370,6 +462,8 @@ def jdbc(self, url, table, mode="error", properties={}):
                                  arbitrary string tag/value. Normally at least a
                                  "user" and "password" property should be included.
               """
      +        if properties is None:
      +            properties = dict()
               jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
               for k in properties:
                   jprop.setProperty(k, properties[k])
      @@ -381,7 +475,7 @@ def _test():
           import os
           import tempfile
           from pyspark.context import SparkContext
      -    from pyspark.sql import Row, SQLContext
      +    from pyspark.sql import Row, SQLContext, HiveContext
           import pyspark.sql.readwriter
       
           os.chdir(os.environ["SPARK_HOME"])
      @@ -393,6 +487,7 @@ def _test():
           globs['os'] = os
           globs['sc'] = sc
           globs['sqlContext'] = SQLContext(sc)
      +    globs['hiveContext'] = HiveContext(sc)
           globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
       
           (failure_count, test_count) = doctest.testmod(
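Taken together, the reader/writer changes add a per-key option() on both DataFrameReader and DataFrameWriter, pass reader option values through to_str() (so booleans arrive as "true"/"false"), let json() accept an RDD of JSON strings, and let save()/parquet() take partitionBy. The sketch below is a minimal walk over that new surface, assuming a local SparkContext; the app name, option key and sample records are placeholders.

import os
import tempfile

from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext("local[2]", "readwriter-sketch")  # placeholder app name
sqlContext = SQLContext(sc)

# json() now also accepts an RDD of JSON strings, one object per record
rdd = sc.parallelize(['{"name": "a", "year": 2014}', '{"name": "b", "year": 2015}'])
df = sqlContext.read.json(rdd)

# option() adds a single input option; boolean values are lower-cased by to_str()
reader = sqlContext.read.format("json").option("someOption", True)  # hypothetical option key

# parquet() now accepts partitionBy; leaving mode unset keeps the JVM-side default ("error")
out = os.path.join(tempfile.mkdtemp(), 'data')
df.write.mode("overwrite").parquet(out, partitionBy='year')

sc.stop()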
      diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
      index b5fbb7d09882..f2172b7a27d8 100644
      --- a/python/pyspark/sql/tests.py
      +++ b/python/pyspark/sql/tests.py
      @@ -1,3 +1,4 @@
      +# -*- encoding: utf-8 -*-
       #
       # Licensed to the Apache Software Foundation (ASF) under one or more
       # contributor license agreements.  See the NOTICE file distributed with
      @@ -44,20 +45,22 @@
       from pyspark.sql.types import *
       from pyspark.sql.types import UserDefinedType, _infer_type
       from pyspark.tests import ReusedPySparkTestCase
      -from pyspark.sql.functions import UserDefinedFunction
      +from pyspark.sql.functions import UserDefinedFunction, sha2
       from pyspark.sql.window import Window
      +from pyspark.sql.utils import AnalysisException, IllegalArgumentException
       
       
      -class UTC(datetime.tzinfo):
      -    """UTC"""
      -    ZERO = datetime.timedelta(0)
      +class UTCOffsetTimezone(datetime.tzinfo):
      +    """
+    Specifies a timezone by its offset from UTC (in hours).
      +    """
      +
      +    def __init__(self, offset=0):
      +        self.ZERO = datetime.timedelta(hours=offset)
       
           def utcoffset(self, dt):
               return self.ZERO
       
      -    def tzname(self, dt):
      -        return "UTC"
      -
           def dst(self, dt):
               return self.ZERO
       
      @@ -73,7 +76,7 @@ def sqlType(self):
       
           @classmethod
           def module(cls):
      -        return 'pyspark.tests'
      +        return 'pyspark.sql.tests'
       
           @classmethod
           def scalaUDT(cls):
      @@ -104,10 +107,51 @@ def __str__(self):
               return "(%s,%s)" % (self.x, self.y)
       
           def __eq__(self, other):
      -        return isinstance(other, ExamplePoint) and \
      +        return isinstance(other, self.__class__) and \
                   other.x == self.x and other.y == self.y
       
       
      +class PythonOnlyUDT(UserDefinedType):
      +    """
+    User-defined type (UDT) for PythonOnlyPoint.
      +    """
      +
      +    @classmethod
      +    def sqlType(self):
      +        return ArrayType(DoubleType(), False)
      +
      +    @classmethod
      +    def module(cls):
      +        return '__main__'
      +
      +    def serialize(self, obj):
      +        return [obj.x, obj.y]
      +
      +    def deserialize(self, datum):
      +        return PythonOnlyPoint(datum[0], datum[1])
      +
      +    @staticmethod
      +    def foo():
      +        pass
      +
      +    @property
      +    def props(self):
      +        return {}
      +
      +
      +class PythonOnlyPoint(ExamplePoint):
      +    """
+    An example class to demonstrate a UDT defined only in Python
      +    """
      +    __UDT__ = PythonOnlyUDT()
      +
      +
      +class MyObject(object):
      +    def __init__(self, key, value):
      +        self.key = key
      +        self.value = value
      +
      +
       class DataTypeTests(unittest.TestCase):
           # regression test for SPARK-6055
           def test_data_type_eq(self):
      @@ -124,6 +168,11 @@ def test_decimal_type(self):
               t3 = DecimalType(8)
               self.assertNotEqual(t2, t3)
       
      +    # regression test for SPARK-10392
      +    def test_datetype_equal_zero(self):
      +        dt = DateType()
      +        self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
      +
       
       class SQLTests(ReusedPySparkTestCase):
       
      @@ -142,6 +191,21 @@ def tearDownClass(cls):
               ReusedPySparkTestCase.tearDownClass()
               shutil.rmtree(cls.tempdir.name, ignore_errors=True)
       
      +    def test_row_should_be_read_only(self):
      +        row = Row(a=1, b=2)
      +        self.assertEqual(1, row.a)
      +
      +        def foo():
      +            row.a = 3
      +        self.assertRaises(Exception, foo)
      +
      +        row2 = self.sqlCtx.range(10).first()
      +        self.assertEqual(0, row2.id)
      +
      +        def foo2():
      +            row2.id = 2
      +        self.assertRaises(Exception, foo2)
      +
           def test_range(self):
               self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
               self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
      @@ -149,6 +213,17 @@ def test_range(self):
               self.assertEqual(self.sqlCtx.range(-2).count(), 0)
               self.assertEqual(self.sqlCtx.range(3).count(), 3)
       
      +    def test_duplicated_column_names(self):
      +        df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
      +        row = df.select('*').first()
      +        self.assertEqual(1, row[0])
      +        self.assertEqual(2, row[1])
      +        self.assertEqual("Row(c=1, c=2)", str(row))
      +        # Cannot access columns
      +        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
      +        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
      +        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
      +
           def test_explode(self):
               from pyspark.sql.functions import explode
               d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
      @@ -164,6 +239,14 @@ def test_explode(self):
               self.assertEqual(result[0][0], "a")
               self.assertEqual(result[0][1], "b")
       
      +    def test_and_in_expression(self):
      +        self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
      +        self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
      +        self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
      +        self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
      +        self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
      +        self.assertRaises(ValueError, lambda: not self.df.key == 1)
      +
           def test_udf_with_callable(self):
               d = [Row(number=i, squared=i**2) for i in range(10)]
               rdd = self.sc.parallelize(d)
      @@ -312,6 +395,16 @@ def test_infer_nested_schema(self):
               df = self.sqlCtx.inferSchema(rdd)
               self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
       
      +    def test_create_dataframe_from_objects(self):
      +        data = [MyObject(1, "1"), MyObject(2, "2")]
      +        df = self.sqlCtx.createDataFrame(data)
      +        self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
      +        self.assertEqual(df.first(), Row(key=1, value="1"))
      +
      +    def test_select_null_literal(self):
      +        df = self.sqlCtx.sql("select null as col")
      +        self.assertEquals(Row(col=None), df.first())
      +
           def test_apply_schema(self):
               from datetime import date, datetime
               rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
      @@ -370,10 +463,39 @@ def test_convert_row_to_dict(self):
               self.assertEqual(1, row.asDict()["l"][0].a)
               self.assertEqual(1.0, row.asDict()['d']['key'].c)
       
      +    def test_udt(self):
      +        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
      +        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
      +
      +        def check_datatype(datatype):
      +            pickled = pickle.loads(pickle.dumps(datatype))
      +            assert datatype == pickled
      +            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
      +            python_datatype = _parse_datatype_json_string(scala_datatype.json())
      +            assert datatype == python_datatype
      +
      +        check_datatype(ExamplePointUDT())
      +        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
      +                                          StructField("point", ExamplePointUDT(), False)])
      +        check_datatype(structtype_with_udt)
      +        p = ExamplePoint(1.0, 2.0)
      +        self.assertEqual(_infer_type(p), ExamplePointUDT())
      +        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
      +        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
      +
      +        check_datatype(PythonOnlyUDT())
      +        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
      +                                          StructField("point", PythonOnlyUDT(), False)])
      +        check_datatype(structtype_with_udt)
      +        p = PythonOnlyPoint(1.0, 2.0)
      +        self.assertEqual(_infer_type(p), PythonOnlyUDT())
      +        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
      +        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
      +
           def test_infer_schema_with_udt(self):
               from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
               row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
      -        df = self.sc.parallelize([row]).toDF()
      +        df = self.sqlCtx.createDataFrame([row])
               schema = df.schema
               field = [f for f in schema.fields if f.name == "point"][0]
               self.assertEqual(type(field.dataType), ExamplePointUDT)
      @@ -381,34 +503,74 @@ def test_infer_schema_with_udt(self):
               point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
               self.assertEqual(point, ExamplePoint(1.0, 2.0))
       
      +        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
      +        df = self.sqlCtx.createDataFrame([row])
      +        schema = df.schema
      +        field = [f for f in schema.fields if f.name == "point"][0]
      +        self.assertEqual(type(field.dataType), PythonOnlyUDT)
      +        df.registerTempTable("labeled_point")
      +        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
      +        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
      +
           def test_apply_schema_with_udt(self):
               from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
               row = (1.0, ExamplePoint(1.0, 2.0))
      -        rdd = self.sc.parallelize([row])
               schema = StructType([StructField("label", DoubleType(), False),
                                    StructField("point", ExamplePointUDT(), False)])
      -        df = rdd.toDF(schema)
      +        df = self.sqlCtx.createDataFrame([row], schema)
               point = df.head().point
               self.assertEquals(point, ExamplePoint(1.0, 2.0))
       
      +        row = (1.0, PythonOnlyPoint(1.0, 2.0))
      +        schema = StructType([StructField("label", DoubleType(), False),
      +                             StructField("point", PythonOnlyUDT(), False)])
      +        df = self.sqlCtx.createDataFrame([row], schema)
      +        point = df.head().point
      +        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
      +
      +    def test_udf_with_udt(self):
      +        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
      +        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
      +        df = self.sqlCtx.createDataFrame([row])
      +        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
      +        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
      +        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
      +        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
      +        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
      +
      +        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
      +        df = self.sqlCtx.createDataFrame([row])
      +        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
      +        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
      +        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
      +        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
      +        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
      +
           def test_parquet_with_udt(self):
      -        from pyspark.sql.tests import ExamplePoint
      +        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
               row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
      -        df0 = self.sc.parallelize([row]).toDF()
      +        df0 = self.sqlCtx.createDataFrame([row])
               output_dir = os.path.join(self.tempdir.name, "labeled_point")
      -        df0.saveAsParquetFile(output_dir)
      +        df0.write.parquet(output_dir)
               df1 = self.sqlCtx.parquetFile(output_dir)
               point = df1.head().point
               self.assertEquals(point, ExamplePoint(1.0, 2.0))
       
      +        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
      +        df0 = self.sqlCtx.createDataFrame([row])
      +        df0.write.parquet(output_dir, mode='overwrite')
      +        df1 = self.sqlCtx.parquetFile(output_dir)
      +        point = df1.head().point
      +        self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
      +
           def test_column_operators(self):
               ci = self.df.key
               cs = self.df.value
               c = ci == cs
               self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
      -        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
      +        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
               self.assertTrue(all(isinstance(c, Column) for c in rcc))
      -        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
      +        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
               self.assertTrue(all(isinstance(c, Column) for c in cb))
               cbool = (ci & ci), (ci | ci), (~ci)
               self.assertTrue(all(isinstance(c, Column) for c in cbool))
      @@ -500,6 +662,16 @@ def test_rand_functions(self):
               for row in rndn:
                   assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
       
      +        # If the specified seed is 0, we should use it.
      +        # https://issues.apache.org/jira/browse/SPARK-9691
      +        rnd1 = df.select('key', functions.rand(0)).collect()
      +        rnd2 = df.select('key', functions.rand(0)).collect()
      +        self.assertEqual(sorted(rnd1), sorted(rnd2))
      +
      +        rndn1 = df.select('key', functions.randn(0)).collect()
      +        rndn2 = df.select('key', functions.randn(0)).collect()
      +        self.assertEqual(sorted(rndn1), sorted(rndn2))
      +
           def test_between_function(self):
               df = self.sc.parallelize([
                   Row(a=1, b=2, c=3),
      @@ -508,6 +680,35 @@ def test_between_function(self):
               self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
                                df.filter(df.a.between(df.b, df.c)).collect())
       
      +    def test_struct_type(self):
      +        from pyspark.sql.types import StructType, StringType, StructField
      +        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
      +        struct2 = StructType([StructField("f1", StringType(), True),
      +                              StructField("f2", StringType(), True, None)])
      +        self.assertEqual(struct1, struct2)
      +
      +        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
      +        struct2 = StructType([StructField("f1", StringType(), True)])
      +        self.assertNotEqual(struct1, struct2)
      +
      +        struct1 = (StructType().add(StructField("f1", StringType(), True))
      +                   .add(StructField("f2", StringType(), True, None)))
      +        struct2 = StructType([StructField("f1", StringType(), True),
      +                              StructField("f2", StringType(), True, None)])
      +        self.assertEqual(struct1, struct2)
      +
      +        struct1 = (StructType().add(StructField("f1", StringType(), True))
      +                   .add(StructField("f2", StringType(), True, None)))
      +        struct2 = StructType([StructField("f1", StringType(), True)])
      +        self.assertNotEqual(struct1, struct2)
      +
      +        # Catch exception raised during improper construction
      +        try:
      +            struct1 = StructType().add("name")
      +            self.assertEqual(1, 0)
      +        except ValueError:
      +            self.assertEqual(1, 1)
      +
           def test_save_and_load(self):
               df = self.df
               tmpPath = tempfile.mkdtemp()
      @@ -539,6 +740,39 @@ def test_save_and_load(self):
       
               shutil.rmtree(tmpPath)
       
      +    def test_save_and_load_builder(self):
      +        df = self.df
      +        tmpPath = tempfile.mkdtemp()
      +        shutil.rmtree(tmpPath)
      +        df.write.json(tmpPath)
      +        actual = self.sqlCtx.read.json(tmpPath)
      +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
      +
      +        schema = StructType([StructField("value", StringType(), True)])
      +        actual = self.sqlCtx.read.json(tmpPath, schema)
      +        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
      +
      +        df.write.mode("overwrite").json(tmpPath)
      +        actual = self.sqlCtx.read.json(tmpPath)
      +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
      +
      +        df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
      +                .option("noUse", "this option will not be used in save.")\
      +                .format("json").save(path=tmpPath)
      +        actual =\
      +            self.sqlCtx.read.format("json")\
      +                            .load(path=tmpPath, noUse="this options will not be used in load.")
      +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
      +
      +        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
      +                                                    "org.apache.spark.sql.parquet")
      +        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
      +        actual = self.sqlCtx.load(path=tmpPath)
      +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
      +        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
      +
      +        shutil.rmtree(tmpPath)
      +
           def test_help_command(self):
               # Regression test for SPARK-5464
               rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
      @@ -554,9 +788,17 @@ def test_access_column(self):
               self.assertTrue(isinstance(df['key'], Column))
               self.assertTrue(isinstance(df[0], Column))
               self.assertRaises(IndexError, lambda: df[2])
      -        self.assertRaises(IndexError, lambda: df["bad_key"])
      +        self.assertRaises(AnalysisException, lambda: df["bad_key"])
               self.assertRaises(TypeError, lambda: df[{}])
       
      +    def test_column_name_with_non_ascii(self):
      +        df = self.sqlCtx.createDataFrame([(1,)], ["数量"])
      +        self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
      +        self.assertEqual("DataFrame[数量: bigint]", str(df))
      +        self.assertEqual([("数量", 'bigint')], df.dtypes)
      +        self.assertEqual(1, df.select("数量").first()[0])
      +        self.assertEqual(1, df.select(df["数量"]).first()[0])
      +
           def test_access_nested_types(self):
               df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
               self.assertEqual(1, df.select(df.l[0]).first()[0])
      @@ -570,7 +812,9 @@ def test_field_accessor(self):
               df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
               self.assertEqual(1, df.select(df.l[0]).first()[0])
               self.assertEqual(1, df.select(df.r["a"]).first()[0])
      +        self.assertEqual(1, df.select(df["r.a"]).first()[0])
               self.assertEqual("b", df.select(df.r["b"]).first()[0])
      +        self.assertEqual("b", df.select(df["r.b"]).first()[0])
               self.assertEqual("v", df.select(df.d["k"]).first()[0])
       
           def test_infer_long_type(self):
      @@ -603,22 +847,43 @@ def test_filter_with_datetime(self):
               self.assertEqual(0, df.filter(df.date > date).count())
               self.assertEqual(0, df.filter(df.time > time).count())
       
      +    def test_filter_with_datetime_timezone(self):
      +        dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
      +        dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
      +        row = Row(date=dt1)
      +        df = self.sqlCtx.createDataFrame([row])
      +        self.assertEqual(0, df.filter(df.date == dt2).count())
      +        self.assertEqual(1, df.filter(df.date > dt2).count())
      +        self.assertEqual(0, df.filter(df.date < dt2).count())
      +
           def test_time_with_timezone(self):
               day = datetime.date.today()
               now = datetime.datetime.now()
      -        ts = time.mktime(now.timetuple()) + now.microsecond / 1e6
      +        ts = time.mktime(now.timetuple())
               # class in __main__ is not serializable
      -        from pyspark.sql.tests import UTC
      -        utc = UTC()
      -        utcnow = datetime.datetime.fromtimestamp(ts, utc)
      +        from pyspark.sql.tests import UTCOffsetTimezone
      +        utc = UTCOffsetTimezone()
      +        utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
      +        # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
      +        utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
               df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
               day1, now1, utcnow1 = df.first()
      -        # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version
      -        self.assertEqual(day1.date(), day)
      -        # Pyrolite does not support microsecond, the error should be
      -        # less than 1 millisecond
      -        self.assertTrue(now - now1 < datetime.timedelta(0.001))
      -        self.assertTrue(now - utcnow1 < datetime.timedelta(0.001))
      +        self.assertEqual(day1, day)
      +        self.assertEqual(now, now1)
      +        self.assertEqual(now, utcnow1)
      +
      +    def test_decimal(self):
      +        from decimal import Decimal
      +        schema = StructType([StructField("decimal", DecimalType(10, 5))])
      +        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
      +        row = df.select(df.decimal + 1).first()
      +        self.assertEqual(row[0], Decimal("4.14159"))
      +        tmpPath = tempfile.mkdtemp()
      +        shutil.rmtree(tmpPath)
      +        df.write.parquet(tmpPath)
      +        df2 = self.sqlCtx.read.parquet(tmpPath)
      +        row = df2.first()
      +        self.assertEqual(row[0], Decimal("3.14159"))
       
           def test_dropna(self):
               schema = StructType([
      @@ -729,6 +994,13 @@ def test_bitwise_operations(self):
               result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
               self.assertEqual(~75, result['~b'])
       
      +    def test_expr(self):
      +        from pyspark.sql import functions
      +        row = Row(a="length string", b=75)
      +        df = self.sqlCtx.createDataFrame([row])
      +        result = df.select(functions.expr("length(a)")).collect()[0].asDict()
      +        self.assertEqual(13, result["'length(a)"])
      +
           def test_replace(self):
               schema = StructType([
                   StructField("name", StringType(), True),
      @@ -777,6 +1049,32 @@ def test_replace(self):
               self.assertEqual(row.age, 10)
               self.assertEqual(row.height, None)
       
      +    def test_capture_analysis_exception(self):
      +        self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
      +        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
      +        # RuntimeException should not be captured
      +        self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc"))
      +
      +    def test_capture_illegalargument_exception(self):
      +        self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
      +                                lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1"))
      +        df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"])
      +        self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
      +                                lambda: df.select(sha2(df.a, 1024)).collect())
      +
      +    def test_with_column_with_existing_name(self):
      +        keys = self.df.withColumn("key", self.df.key).select("key").collect()
      +        self.assertEqual([r.key for r in keys], list(range(100)))
      +
      +    # regression test for SPARK-10417
      +    def test_column_iterator(self):
      +
      +        def foo():
      +            for x in self.df.key:
      +                break
      +
      +        self.assertRaises(TypeError, foo)
      +
       
       class HiveContextSQLTests(ReusedPySparkTestCase):
       
      @@ -868,5 +1166,28 @@ def test_window_functions(self):
               for r, ex in zip(rs, expected):
                   self.assertEqual(tuple(r), ex[:len(r)])
       
      +    def test_window_functions_without_partitionBy(self):
      +        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
      +        w = Window.orderBy("key", df.value)
      +        from pyspark.sql import functions as F
      +        sel = df.select(df.value, df.key,
      +                        F.max("key").over(w.rowsBetween(0, 1)),
      +                        F.min("key").over(w.rowsBetween(0, 1)),
      +                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
      +                        F.rowNumber().over(w),
      +                        F.rank().over(w),
      +                        F.denseRank().over(w),
      +                        F.ntile(2).over(w))
      +        rs = sorted(sel.collect())
      +        expected = [
      +            ("1", 1, 1, 1, 4, 1, 1, 1, 1),
      +            ("2", 1, 1, 1, 4, 2, 2, 2, 1),
      +            ("2", 1, 2, 1, 4, 3, 2, 2, 2),
      +            ("2", 2, 2, 2, 4, 4, 4, 3, 2)
      +        ]
      +        for r, ex in zip(rs, expected):
      +            self.assertEqual(tuple(r), ex[:len(r)])
      +
      +
       if __name__ == "__main__":
           unittest.main()
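Several of the new tests depend on the Python-side exception mapping imported from pyspark.sql.utils (AnalysisException, IllegalArgumentException). The sketch below mirrors test_capture_analysis_exception to show what that looks like from user code; it assumes a local SparkContext, and the app name is a placeholder.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.utils import AnalysisException

sc = SparkContext("local[2]", "exceptions-sketch")  # placeholder app name
sqlContext = SQLContext(sc)

try:
    # an unresolvable column now surfaces as AnalysisException on the Python side
    sqlContext.sql("select abc")
except AnalysisException as e:
    print("analysis error: %s" % e)

sc.stop()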
      diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
      index 23d9adb0daea..1f86894855cb 100644
      --- a/python/pyspark/sql/types.py
      +++ b/python/pyspark/sql/types.py
      @@ -20,13 +20,10 @@
       import time
       import datetime
       import calendar
      -import keyword
      -import warnings
       import json
       import re
      -import weakref
      +import base64
       from array import array
      -from operator import itemgetter
       
       if sys.version >= "3":
           long = int
      @@ -35,6 +32,8 @@
       from py4j.protocol import register_input_converter
       from py4j.java_gateway import JavaClass
       
      +from pyspark.serializers import CloudPickleSerializer
      +
       __all__ = [
           "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
           "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
      @@ -71,6 +70,26 @@ def json(self):
                                 separators=(',', ':'),
                                 sort_keys=True)
       
      +    def needConversion(self):
      +        """
+        Does this type need conversion between a Python object and an internal SQL object?
+
+        This is used to avoid unnecessary conversions for ArrayType/MapType/StructType.
      +        """
      +        return False
      +
      +    def toInternal(self, obj):
      +        """
      +        Converts a Python object into an internal SQL object.
      +        """
      +        return obj
      +
      +    def fromInternal(self, obj):
      +        """
      +        Converts an internal SQL object into a native Python object.
      +        """
      +        return obj
      +
       
       # This singleton pattern does not work with pickle, you will get
       # another object after pickle and unpickle
      @@ -143,6 +162,19 @@ class DateType(AtomicType):
       
           __metaclass__ = DataTypeSingleton
       
      +    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
      +
      +    def needConversion(self):
      +        return True
      +
      +    def toInternal(self, d):
      +        if d is not None:
      +            return d.toordinal() - self.EPOCH_ORDINAL
      +
      +    def fromInternal(self, v):
      +        if v is not None:
      +            return datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
      +
       
       class TimestampType(AtomicType):
           """Timestamp (datetime.datetime) data type.
      @@ -150,33 +182,50 @@ class TimestampType(AtomicType):
       
           __metaclass__ = DataTypeSingleton
       
      +    def needConversion(self):
      +        return True
      +
      +    def toInternal(self, dt):
      +        if dt is not None:
      +            seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
      +                       else time.mktime(dt.timetuple()))
      +            return int(seconds * 1e6 + dt.microsecond)
      +
      +    def fromInternal(self, ts):
      +        if ts is not None:
      +            # using int to avoid precision loss in float
      +            return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)
      +
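
As a quick check (not part of the patch) of the encodings above: DateType stores days since 1970-01-01 and TimestampType stores microseconds since the Unix epoch (local time for naive datetimes), so both should round-trip:

import datetime
from pyspark.sql.types import DateType, TimestampType

assert DateType().toInternal(datetime.date(1970, 1, 2)) == 1
assert DateType().fromInternal(1) == datetime.date(1970, 1, 2)

dt = datetime.datetime(2015, 7, 1, 12, 0, 0, 123456)
micros = TimestampType().toInternal(dt)          # naive datetime, interpreted as local time
assert TimestampType().fromInternal(micros) == dt
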
       
       class DecimalType(FractionalType):
           """Decimal (decimal.Decimal) data type.
      +
+    The DecimalType must have fixed precision (the maximum total number of digits)
+    and scale (the number of digits to the right of the decimal point). For example,
+    (5, 2) can support values in the range [-999.99, 999.99].
+
+    The precision can be up to 38; the scale must be less than or equal to the precision.
+
+    When creating a DecimalType, the default precision and scale are (10, 0). When
+    inferring a schema from decimal.Decimal objects, the type will be DecimalType(38, 18).
+
+    :param precision: the maximum total number of digits (default: 10)
+    :param scale: the number of digits to the right of the decimal point (default: 0)
           """
       
      -    def __init__(self, precision=None, scale=None):
      +    def __init__(self, precision=10, scale=0):
               self.precision = precision
               self.scale = scale
      -        self.hasPrecisionInfo = precision is not None
      +        self.hasPrecisionInfo = True  # this is public API
       
           def simpleString(self):
      -        if self.hasPrecisionInfo:
      -            return "decimal(%d,%d)" % (self.precision, self.scale)
      -        else:
      -            return "decimal(10,0)"
      +        return "decimal(%d,%d)" % (self.precision, self.scale)
       
           def jsonValue(self):
      -        if self.hasPrecisionInfo:
      -            return "decimal(%d,%d)" % (self.precision, self.scale)
      -        else:
      -            return "decimal"
      +        return "decimal(%d,%d)" % (self.precision, self.scale)
       
           def __repr__(self):
      -        if self.hasPrecisionInfo:
      -            return "DecimalType(%d,%d)" % (self.precision, self.scale)
      -        else:
      -            return "DecimalType()"
      +        return "DecimalType(%d,%d)" % (self.precision, self.scale)
       
       
       class DoubleType(FractionalType):
      @@ -259,6 +308,19 @@ def fromJson(cls, json):
               return ArrayType(_parse_datatype_json_value(json["elementType"]),
                                json["containsNull"])
       
      +    def needConversion(self):
      +        return self.elementType.needConversion()
      +
      +    def toInternal(self, obj):
      +        if not self.needConversion():
      +            return obj
      +        return obj and [self.elementType.toInternal(v) for v in obj]
      +
      +    def fromInternal(self, obj):
      +        if not self.needConversion():
      +            return obj
      +        return obj and [self.elementType.fromInternal(v) for v in obj]
      +
       
       class MapType(DataType):
           """Map data type.
      @@ -304,6 +366,21 @@ def fromJson(cls, json):
                              _parse_datatype_json_value(json["valueType"]),
                              json["valueContainsNull"])
       
      +    def needConversion(self):
      +        return self.keyType.needConversion() or self.valueType.needConversion()
      +
      +    def toInternal(self, obj):
      +        if not self.needConversion():
      +            return obj
      +        return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v))
      +                            for k, v in obj.items())
      +
      +    def fromInternal(self, obj):
      +        if not self.needConversion():
      +            return obj
      +        return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v))
      +                            for k, v in obj.items())
      +
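
A minimal sketch (not part of the patch) of the container conversions above: they only apply when the element, key, or value type itself needs conversion; otherwise the object passes through untouched:

import datetime
from pyspark.sql.types import ArrayType, MapType, StringType, DateType

assert ArrayType(StringType()).needConversion() is False         # pass-through

dates = ArrayType(DateType())
assert dates.toInternal([datetime.date(1970, 1, 3)]) == [2]      # days since epoch

m = MapType(StringType(), DateType())
assert m.fromInternal({"d": 2}) == {"d": datetime.date(1970, 1, 3)}
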
       
       class StructField(DataType):
           """A field in :class:`StructType`.
      @@ -311,7 +388,7 @@ class StructField(DataType):
           :param name: string, name of the field.
           :param dataType: :class:`DataType` of the field.
           :param nullable: boolean, whether the field can be null (None) or not.
      -    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
+    :param metadata: a dict from string to simple type that can be automatically serialized to JSON
           """
       
           def __init__(self, name, dataType, nullable=True, metadata=None):
      @@ -324,6 +401,8 @@ def __init__(self, name, dataType, nullable=True, metadata=None):
               False
               """
               assert isinstance(dataType, DataType), "dataType should be DataType"
      +        if not isinstance(name, str):
      +            name = name.encode('utf-8')
               self.name = name
               self.dataType = dataType
               self.nullable = nullable
      @@ -349,14 +428,22 @@ def fromJson(cls, json):
                                  json["nullable"],
                                  json["metadata"])
       
      +    def needConversion(self):
      +        return self.dataType.needConversion()
      +
      +    def toInternal(self, obj):
      +        return self.dataType.toInternal(obj)
      +
      +    def fromInternal(self, obj):
      +        return self.dataType.fromInternal(obj)
      +
       
       class StructType(DataType):
           """Struct type, consisting of a list of :class:`StructField`.
       
           This is the data type representing a :class:`Row`.
           """
      -
      -    def __init__(self, fields):
      +    def __init__(self, fields=None):
               """
               >>> struct1 = StructType([StructField("f1", StringType(), True)])
               >>> struct2 = StructType([StructField("f1", StringType(), True)])
      @@ -368,8 +455,61 @@ def __init__(self, fields):
               >>> struct1 == struct2
               False
               """
      -        assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
      -        self.fields = fields
      +        if not fields:
      +            self.fields = []
      +            self.names = []
      +        else:
      +            self.fields = fields
      +            self.names = [f.name for f in fields]
      +            assert all(isinstance(f, StructField) for f in fields),\
      +                "fields should be a list of StructField"
      +        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
      +
      +    def add(self, field, data_type=None, nullable=True, metadata=None):
      +        """
      +        Construct a StructType by adding new elements to it to define the schema. The method accepts
      +        either:
      +
      +            a) A single parameter which is a StructField object.
      +            b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
+               metadata (optional)). The data_type parameter may be either a String or a
      +               DataType object.
      +
      +        >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
      +        >>> struct2 = StructType([StructField("f1", StringType(), True),\
      +         StructField("f2", StringType(), True, None)])
      +        >>> struct1 == struct2
      +        True
      +        >>> struct1 = StructType().add(StructField("f1", StringType(), True))
      +        >>> struct2 = StructType([StructField("f1", StringType(), True)])
      +        >>> struct1 == struct2
      +        True
      +        >>> struct1 = StructType().add("f1", "string", True)
      +        >>> struct2 = StructType([StructField("f1", StringType(), True)])
      +        >>> struct1 == struct2
      +        True
      +
      +        :param field: Either the name of the field or a StructField object
      +        :param data_type: If present, the DataType of the StructField to create
      +        :param nullable: Whether the field to add should be nullable (default True)
      +        :param metadata: Any additional metadata (default None)
      +        :return: a new updated StructType
      +        """
      +        if isinstance(field, StructField):
      +            self.fields.append(field)
      +            self.names.append(field.name)
      +        else:
      +            if isinstance(field, str) and data_type is None:
      +                raise ValueError("Must specify DataType if passing name of struct_field to create.")
      +
      +            if isinstance(data_type, str):
      +                data_type_f = _parse_datatype_json_value(data_type)
      +            else:
      +                data_type_f = data_type
      +            self.fields.append(StructField(field, data_type_f, nullable, metadata))
      +            self.names.append(field)
      +        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
      +        return self
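
A hedged usage sketch of the new builder: the same schema can be assembled from StructField objects or from (name, DataType) pairs, mirroring the doctests above:

from pyspark.sql.types import StructType, StructField, StringType, LongType

s1 = (StructType()
      .add(StructField("name", StringType(), True))
      .add("age", LongType(), False))
s2 = StructType([StructField("name", StringType(), True),
                 StructField("age", LongType(), False)])
assert s1 == s2 and s1.names == ["name", "age"]
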
       
           def simpleString(self):
               return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
      @@ -386,6 +526,47 @@ def jsonValue(self):
           def fromJson(cls, json):
               return StructType([StructField.fromJson(f) for f in json["fields"]])
       
      +    def needConversion(self):
+        # We need to convert Row()/namedtuple into tuple()
      +        return True
      +
      +    def toInternal(self, obj):
      +        if obj is None:
      +            return
      +
      +        if self._needSerializeAnyField:
      +            if isinstance(obj, dict):
      +                return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
      +            elif isinstance(obj, (tuple, list)):
      +                return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
      +            elif hasattr(obj, "__dict__"):
      +                d = obj.__dict__
      +                return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields))
      +            else:
      +                raise ValueError("Unexpected tuple %r with StructType" % obj)
      +        else:
      +            if isinstance(obj, dict):
      +                return tuple(obj.get(n) for n in self.names)
      +            elif isinstance(obj, (list, tuple)):
      +                return tuple(obj)
      +            elif hasattr(obj, "__dict__"):
      +                d = obj.__dict__
      +                return tuple(d.get(n) for n in self.names)
      +            else:
      +                raise ValueError("Unexpected tuple %r with StructType" % obj)
      +
      +    def fromInternal(self, obj):
      +        if obj is None:
      +            return
      +        if isinstance(obj, Row):
      +            # it's already converted by pickler
      +            return obj
      +        if self._needSerializeAnyField:
      +            values = [f.fromInternal(v) for f, v in zip(self.fields, obj)]
      +        else:
      +            values = obj
      +        return _create_row(self.names, values)
      +
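
A small sketch (not part of the patch) of the Row-to-internal-tuple conversion added above; only public classes are used, and the DateType field exercises the per-field conversion path:

import datetime
from pyspark.sql.types import StructType, StringType, DateType

schema = (StructType()
          .add("name", StringType(), True)
          .add("born", DateType(), True))

internal = schema.toInternal({"name": "Alice", "born": datetime.date(1970, 1, 2)})
print(internal)                        # ('Alice', 1): the date becomes days since epoch
print(schema.fromInternal(internal))   # Row(name='Alice', born=datetime.date(1970, 1, 2))
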
       
       class UserDefinedType(DataType):
           """User-defined type (UDT).
      @@ -414,21 +595,40 @@ def module(cls):
           @classmethod
           def scalaUDT(cls):
               """
      -        The class name of the paired Scala UDT.
+        The class name of the paired Scala UDT (may be '' if there is
+        no corresponding one).
      +        """
      +        return ''
      +
      +    def needConversion(self):
      +        return True
      +
      +    @classmethod
      +    def _cachedSqlType(cls):
      +        """
+        Cache the sqlType() on the class, because it is heavily used in `toInternal`.
               """
      -        raise NotImplementedError("UDT must have a paired Scala UDT.")
      +        if not hasattr(cls, "_cached_sql_type"):
      +            cls._cached_sql_type = cls.sqlType()
      +        return cls._cached_sql_type
      +
      +    def toInternal(self, obj):
      +        return self._cachedSqlType().toInternal(self.serialize(obj))
      +
      +    def fromInternal(self, obj):
      +        return self.deserialize(self._cachedSqlType().fromInternal(obj))
       
           def serialize(self, obj):
               """
               Converts the a user-type object into a SQL datum.
               """
      -        raise NotImplementedError("UDT must implement serialize().")
      +        raise NotImplementedError("UDT must implement toInternal().")
       
           def deserialize(self, datum):
               """
               Converts a SQL datum into a user-type object.
               """
      -        raise NotImplementedError("UDT must implement deserialize().")
      +        raise NotImplementedError("UDT must implement fromInternal().")
       
           def simpleString(self):
               return 'udt'
      @@ -437,22 +637,37 @@ def json(self):
               return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
       
           def jsonValue(self):
      -        schema = {
      -            "type": "udt",
      -            "class": self.scalaUDT(),
      -            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
      -            "sqlType": self.sqlType().jsonValue()
      -        }
      +        if self.scalaUDT():
      +            assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
      +            schema = {
      +                "type": "udt",
      +                "class": self.scalaUDT(),
      +                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
      +                "sqlType": self.sqlType().jsonValue()
      +            }
      +        else:
      +            ser = CloudPickleSerializer()
      +            b = ser.dumps(type(self))
      +            schema = {
      +                "type": "udt",
      +                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
      +                "serializedClass": base64.b64encode(b).decode('utf8'),
      +                "sqlType": self.sqlType().jsonValue()
      +            }
               return schema
       
           @classmethod
           def fromJson(cls, json):
      -        pyUDT = json["pyClass"]
      +        pyUDT = str(json["pyClass"])  # convert unicode to str
               split = pyUDT.rfind(".")
               pyModule = pyUDT[:split]
               pyClass = pyUDT[split+1:]
               m = __import__(pyModule, globals(), locals(), [pyClass])
      -        UDT = getattr(m, pyClass)
      +        if not hasattr(m, pyClass):
      +            s = base64.b64decode(json['serializedClass'].encode('utf-8'))
      +            UDT = CloudPickleSerializer().loads(s)
      +        else:
      +            UDT = getattr(m, pyClass)
               return UDT()
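
The fallback branch above means a pure-Python UDT (scalaUDT() == '') is serialized by cloudpickling its class into the JSON schema, so it can be rebuilt even when the class is not importable by name. A hypothetical, minimal UDT to illustrate (class and module names are made up):

from pyspark.sql.types import UserDefinedType, ArrayType, DoubleType

class PointUDT(UserDefinedType):                 # illustrative only
    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "my_points"                       # assumed module name

    def serialize(self, obj):
        return [obj[0], obj[1]]

    def deserialize(self, datum):
        return (datum[0], datum[1])

schema = PointUDT().jsonValue()
assert "serializedClass" in schema               # no Scala UDT, so the class is pickled
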
       
           def __eq__(self, other):
      @@ -460,7 +675,7 @@ def __eq__(self, other):
       
       
       _atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
      -                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType]
      +                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, NullType]
       _all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
       _all_complex_types = dict((v.typeName(), v)
                                 for v in [ArrayType, MapType, StructType])
      @@ -511,11 +726,6 @@ def _parse_datatype_json_string(json_string):
           >>> complex_maptype = MapType(complex_structtype,
           ...                           complex_arraytype, False)
           >>> check_datatype(complex_maptype)
      -
      -    >>> check_datatype(ExamplePointUDT())
      -    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
      -    ...                                   StructField("point", ExamplePointUDT(), False)])
      -    >>> check_datatype(structtype_with_udt)
           """
           return _parse_datatype_json_value(json.loads(json_string))
       
      @@ -567,10 +777,6 @@ def _parse_datatype_json_value(json_value):
       
       def _infer_type(obj):
           """Infer the DataType from obj
      -
      -    >>> p = ExamplePoint(1.0, 2.0)
      -    >>> _infer_type(p)
      -    ExamplePointUDT
           """
           if obj is None:
               return NullType()
      @@ -579,7 +785,10 @@ def _infer_type(obj):
               return obj.__UDT__
       
           dataType = _type_mappings.get(type(obj))
      -    if dataType is not None:
      +    if dataType is DecimalType:
      +        # the precision and scale of `obj` may be different from row to row.
      +        return DecimalType(38, 18)
      +    elif dataType is not None:
               return dataType()
       
           if isinstance(obj, dict):
      @@ -625,112 +834,6 @@ def _infer_schema(row):
           return StructType(fields)
       
       
      -def _need_python_to_sql_conversion(dataType):
      -    """
      -    Checks whether we need python to sql conversion for the given type.
      -    For now, only UDTs need this conversion.
      -
      -    >>> _need_python_to_sql_conversion(DoubleType())
      -    False
      -    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
      -    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
      -    >>> _need_python_to_sql_conversion(schema0)
      -    False
      -    >>> _need_python_to_sql_conversion(ExamplePointUDT())
      -    True
      -    >>> schema1 = ArrayType(ExamplePointUDT(), False)
      -    >>> _need_python_to_sql_conversion(schema1)
      -    True
      -    >>> schema2 = StructType([StructField("label", DoubleType(), False),
      -    ...                       StructField("point", ExamplePointUDT(), False)])
      -    >>> _need_python_to_sql_conversion(schema2)
      -    True
      -    """
      -    if isinstance(dataType, StructType):
      -        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
      -    elif isinstance(dataType, ArrayType):
      -        return _need_python_to_sql_conversion(dataType.elementType)
      -    elif isinstance(dataType, MapType):
      -        return _need_python_to_sql_conversion(dataType.keyType) or \
      -            _need_python_to_sql_conversion(dataType.valueType)
      -    elif isinstance(dataType, UserDefinedType):
      -        return True
      -    elif isinstance(dataType, (DateType, TimestampType)):
      -        return True
      -    else:
      -        return False
      -
      -
      -EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
      -
      -
      -def _python_to_sql_converter(dataType):
      -    """
      -    Returns a converter that converts a Python object into a SQL datum for the given type.
      -
      -    >>> conv = _python_to_sql_converter(DoubleType())
      -    >>> conv(1.0)
      -    1.0
      -    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
      -    >>> conv([1.0, 2.0])
      -    [1.0, 2.0]
      -    >>> conv = _python_to_sql_converter(ExamplePointUDT())
      -    >>> conv(ExamplePoint(1.0, 2.0))
      -    [1.0, 2.0]
      -    >>> schema = StructType([StructField("label", DoubleType(), False),
      -    ...                      StructField("point", ExamplePointUDT(), False)])
      -    >>> conv = _python_to_sql_converter(schema)
      -    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
      -    (1.0, [1.0, 2.0])
      -    """
      -    if not _need_python_to_sql_conversion(dataType):
      -        return lambda x: x
      -
      -    if isinstance(dataType, StructType):
      -        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
      -        converters = [_python_to_sql_converter(t) for t in types]
      -
      -        def converter(obj):
      -            if isinstance(obj, dict):
      -                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
      -            elif isinstance(obj, tuple):
      -                if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
      -                    return tuple(c(v) for c, v in zip(converters, obj))
      -                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
      -                    d = dict(obj)
      -                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
      -                else:
      -                    return tuple(c(v) for c, v in zip(converters, obj))
      -            elif obj is not None:
      -                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
      -        return converter
      -    elif isinstance(dataType, ArrayType):
      -        element_converter = _python_to_sql_converter(dataType.elementType)
      -        return lambda a: a and [element_converter(v) for v in a]
      -    elif isinstance(dataType, MapType):
      -        key_converter = _python_to_sql_converter(dataType.keyType)
      -        value_converter = _python_to_sql_converter(dataType.valueType)
      -        return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
      -
      -    elif isinstance(dataType, UserDefinedType):
      -        return lambda obj: obj and dataType.serialize(obj)
      -
      -    elif isinstance(dataType, DateType):
      -        return lambda d: d and d.toordinal() - EPOCH_ORDINAL
      -
      -    elif isinstance(dataType, TimestampType):
      -
      -        def to_posix_timstamp(dt):
      -            if dt:
      -                seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
      -                           else time.mktime(dt.timetuple()))
      -                return int(seconds * 1e7 + dt.microsecond * 10)
      -        return to_posix_timstamp
      -
      -    else:
      -        raise ValueError("Unexpected type %r" % dataType)
      -
      -
       def _has_nulltype(dt):
           """ Return whether there is NullType in `dt` or not """
           if isinstance(dt, StructType):
      @@ -1008,29 +1111,31 @@ def _verify_type(obj, dataType):
           Traceback (most recent call last):
               ...
           ValueError:...
      -    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
      -    >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
      -    Traceback (most recent call last):
      -        ...
      -    ValueError:...
           """
           # all objects are nullable
           if obj is None:
               return
       
+    # StringType can accept an object of any type
      +    if isinstance(dataType, StringType):
      +        return
      +
           if isinstance(dataType, UserDefinedType):
               if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                   raise ValueError("%r is not an instance of type %r" % (obj, dataType))
      -        _verify_type(dataType.serialize(obj), dataType.sqlType())
      +        _verify_type(dataType.toInternal(obj), dataType.sqlType())
               return
       
           _type = type(dataType)
           assert _type in _acceptable_types, "unknown datatype: %s" % dataType
       
      -    # subclass of them can not be deserialized in JVM
      -    if type(obj) not in _acceptable_types[_type]:
      -        raise TypeError("%s can not accept object in type %s"
      -                        % (dataType, type(obj)))
      +    if _type is StructType:
      +        if not isinstance(obj, (tuple, list)):
      +            raise TypeError("StructType can not accept object in type %s" % type(obj))
      +    else:
+        # subclasses of these can not be deserialized in the JVM
      +        if type(obj) not in _acceptable_types[_type]:
      +            raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
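
A hedged sketch of the relaxed checks above (note that _verify_type is a private helper): StringType now accepts any object, and a StructType accepts a tuple or a list but still rejects other containers:

from pyspark.sql.types import _verify_type, StringType, StructType, StructField, LongType

_verify_type(42, StringType())                       # anything can be rendered as a string
schema = StructType([StructField("a", LongType(), True)])
_verify_type((1,), schema)                           # tuple is fine
_verify_type([1], schema)                            # list is fine
try:
    _verify_type({"a": 1}, schema)                   # dict is rejected for StructType
except TypeError as e:
    print(e)
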
       
           if isinstance(dataType, ArrayType):
               for i in obj:
      @@ -1048,159 +1153,10 @@ def _verify_type(obj, dataType):
               for v, f in zip(obj, dataType.fields):
                   _verify_type(v, f.dataType)
       
      -_cached_cls = weakref.WeakValueDictionary()
      -
      -
      -def _restore_object(dataType, obj):
      -    """ Restore object during unpickling. """
      -    # use id(dataType) as key to speed up lookup in dict
      -    # Because of batched pickling, dataType will be the
      -    # same object in most cases.
      -    k = id(dataType)
      -    cls = _cached_cls.get(k)
      -    if cls is None or cls.__datatype is not dataType:
      -        # use dataType as key to avoid create multiple class
      -        cls = _cached_cls.get(dataType)
      -        if cls is None:
      -            cls = _create_cls(dataType)
      -            _cached_cls[dataType] = cls
      -        cls.__datatype = dataType
      -        _cached_cls[k] = cls
      -    return cls(obj)
      -
      -
      -def _create_object(cls, v):
      -    """ Create an customized object with class `cls`. """
      -    # datetime.date would be deserialized as datetime.datetime
      -    # from java type, so we need to set it back.
      -    if cls is datetime.date and isinstance(v, datetime.datetime):
      -        return v.date()
      -    return cls(v) if v is not None else v
      -
      -
      -def _create_getter(dt, i):
      -    """ Create a getter for item `i` with schema """
      -    cls = _create_cls(dt)
      -
      -    def getter(self):
      -        return _create_object(cls, self[i])
      -
      -    return getter
      -
      -
      -def _has_struct_or_date(dt):
      -    """Return whether `dt` is or has StructType/DateType in it"""
      -    if isinstance(dt, StructType):
      -        return True
      -    elif isinstance(dt, ArrayType):
      -        return _has_struct_or_date(dt.elementType)
      -    elif isinstance(dt, MapType):
      -        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
      -    elif isinstance(dt, DateType):
      -        return True
      -    elif isinstance(dt, UserDefinedType):
      -        return True
      -    return False
      -
      -
      -def _create_properties(fields):
      -    """Create properties according to fields"""
      -    ps = {}
      -    for i, f in enumerate(fields):
      -        name = f.name
      -        if (name.startswith("__") and name.endswith("__")
      -                or keyword.iskeyword(name)):
      -            warnings.warn("field name %s can not be accessed in Python,"
      -                          "use position to access it instead" % name)
      -        if _has_struct_or_date(f.dataType):
      -            # delay creating object until accessing it
      -            getter = _create_getter(f.dataType, i)
      -        else:
      -            getter = itemgetter(i)
      -        ps[name] = property(getter)
      -    return ps
      -
       
      -def _create_cls(dataType):
      -    """
      -    Create an class by dataType
      -
      -    The created class is similar to namedtuple, but can have nested schema.
      -
      -    >>> schema = _parse_schema_abstract("a b c")
      -    >>> row = (1, 1.0, "str")
      -    >>> schema = _infer_schema_type(row, schema)
      -    >>> obj = _create_cls(schema)(row)
      -    >>> import pickle
      -    >>> pickle.loads(pickle.dumps(obj))
      -    Row(a=1, b=1.0, c='str')
      -
      -    >>> row = [[1], {"key": (1, 2.0)}]
      -    >>> schema = _parse_schema_abstract("a[] b{c d}")
      -    >>> schema = _infer_schema_type(row, schema)
      -    >>> obj = _create_cls(schema)(row)
      -    >>> pickle.loads(pickle.dumps(obj))
      -    Row(a=[1], b={'key': Row(c=1, d=2.0)})
      -    >>> pickle.loads(pickle.dumps(obj.a))
      -    [1]
      -    >>> pickle.loads(pickle.dumps(obj.b))
      -    {'key': Row(c=1, d=2.0)}
      -    """
      -
      -    if isinstance(dataType, ArrayType):
      -        cls = _create_cls(dataType.elementType)
      -
      -        def List(l):
      -            if l is None:
      -                return
      -            return [_create_object(cls, v) for v in l]
      -
      -        return List
      -
      -    elif isinstance(dataType, MapType):
      -        kcls = _create_cls(dataType.keyType)
      -        vcls = _create_cls(dataType.valueType)
      -
      -        def Dict(d):
      -            if d is None:
      -                return
      -            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
      -
      -        return Dict
      -
      -    elif isinstance(dataType, DateType):
      -        return datetime.date
      -
      -    elif isinstance(dataType, UserDefinedType):
      -        return lambda datum: dataType.deserialize(datum)
      -
      -    elif not isinstance(dataType, StructType):
      -        # no wrapper for atomic types
      -        return lambda x: x
      -
      -    class Row(tuple):
      -
      -        """ Row in DataFrame """
      -        __datatype = dataType
      -        __fields__ = tuple(f.name for f in dataType.fields)
      -        __slots__ = ()
      -
      -        # create property for fast access
      -        locals().update(_create_properties(dataType.fields))
      -
      -        def asDict(self):
      -            """ Return as a dict """
      -            return dict((n, getattr(self, n)) for n in self.__fields__)
      -
      -        def __repr__(self):
      -            # call collect __repr__ for nested objects
      -            return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
      -                                          for n in self.__fields__))
      -
      -        def __reduce__(self):
      -            return (_restore_object, (self.__datatype, tuple(self)))
      -
      -    return Row
+# This is used to unpickle a Row coming back from the JVM
      +def _create_row_inbound_converter(dataType):
      +    return lambda *a: dataType.fromInternal(a)
       
       
       def _create_row(fields, values):
      @@ -1220,6 +1176,8 @@ class Row(tuple):
           >>> row = Row(name="Alice", age=11)
           >>> row
           Row(age=11, name='Alice')
      +    >>> row['name'], row['age']
      +    ('Alice', 11)
           >>> row.name, row.age
           ('Alice', 11)
       
      @@ -1251,19 +1209,55 @@ def __new__(self, *args, **kwargs):
               else:
                   raise ValueError("No args or kwargs")
       
      -    def asDict(self):
      +    def asDict(self, recursive=False):
               """
               Return as an dict
      +
+        :param recursive: turns the nested Rows into dicts as well (default: False).
      +
      +        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
      +        True
      +        >>> row = Row(key=1, value=Row(name='a', age=2))
      +        >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')}
      +        True
      +        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
      +        True
               """
               if not hasattr(self, "__fields__"):
                   raise TypeError("Cannot convert a Row class into dict")
      -        return dict(zip(self.__fields__, self))
      +
      +        if recursive:
      +            def conv(obj):
      +                if isinstance(obj, Row):
      +                    return obj.asDict(True)
      +                elif isinstance(obj, list):
      +                    return [conv(o) for o in obj]
      +                elif isinstance(obj, dict):
      +                    return dict((k, conv(v)) for k, v in obj.items())
      +                else:
      +                    return obj
      +            return dict(zip(self.__fields__, (conv(o) for o in self)))
      +        else:
      +            return dict(zip(self.__fields__, self))
       
           # let object acts like class
           def __call__(self, *args):
               """create new Row object"""
               return _create_row(self, args)
       
      +    def __getitem__(self, item):
      +        if isinstance(item, (int, slice)):
      +            return super(Row, self).__getitem__(item)
      +        try:
      +            # it will be slow when it has many fields,
      +            # but this will not be used in normal cases
      +            idx = self.__fields__.index(item)
      +            return super(Row, self).__getitem__(idx)
      +        except IndexError:
      +            raise KeyError(item)
      +        except ValueError:
      +            raise ValueError(item)
      +
           def __getattr__(self, item):
               if item.startswith("__"):
                   raise AttributeError(item)
      @@ -1277,6 +1271,11 @@ def __getattr__(self, item):
               except ValueError:
                   raise AttributeError(item)
       
      +    def __setattr__(self, key, value):
      +        if key != '__fields__':
      +            raise Exception("Row is read-only")
      +        self.__dict__[key] = value
      +
           def __reduce__(self):
               """Returns a tuple so Python knows how to pickle Row."""
               if hasattr(self, "__fields__"):
      @@ -1308,8 +1307,11 @@ def can_convert(self, obj):
       
           def convert(self, obj, gateway_client):
               Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
      -        return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)
      -
      +        seconds = (calendar.timegm(obj.utctimetuple()) if obj.tzinfo
      +                   else time.mktime(obj.timetuple()))
      +        t = Timestamp(int(seconds) * 1000)
      +        t.setNanos(obj.microsecond * 1000)
      +        return t
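
A hedged, pure-Python restatement of the conversion above: whole seconds go into the java.sql.Timestamp constructor (in milliseconds) and the sub-second part travels separately as nanoseconds, so microseconds are preserved where the old code truncated to milliseconds:

import calendar, time, datetime

dt = datetime.datetime(2015, 7, 1, 12, 0, 0, 123456)
seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
           else time.mktime(dt.timetuple()))
constructor_millis = int(seconds) * 1000    # passed to java.sql.Timestamp(...)
nanos = dt.microsecond * 1000               # passed to setNanos(...)
print(constructor_millis, nanos)
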
       
       # datetime is a subclass of date, we should register DatetimeConverter first
       register_input_converter(DatetimeConverter())
      @@ -1319,18 +1321,12 @@ def convert(self, obj, gateway_client):
       def _test():
           import doctest
           from pyspark.context import SparkContext
      -    # let doctest run in pyspark.sql.types, so DataTypes can be picklable
      -    import pyspark.sql.types
      -    from pyspark.sql import Row, SQLContext
      -    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
      -    globs = pyspark.sql.types.__dict__.copy()
      +    from pyspark.sql import SQLContext
      +    globs = globals()
           sc = SparkContext('local[4]', 'PythonTest')
           globs['sc'] = sc
           globs['sqlContext'] = SQLContext(sc)
      -    globs['ExamplePoint'] = ExamplePoint
      -    globs['ExamplePointUDT'] = ExamplePointUDT
      -    (failure_count, test_count) = doctest.testmod(
      -        pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
      +    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
           globs['sc'].stop()
           if failure_count:
               exit(-1)
      diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
      new file mode 100644
      index 000000000000..0f795ca35b38
      --- /dev/null
      +++ b/python/pyspark/sql/utils.py
      @@ -0,0 +1,62 @@
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +import py4j
      +
      +
      +class AnalysisException(Exception):
      +    """
      +    Failed to analyze a SQL query plan.
      +    """
      +
      +
      +class IllegalArgumentException(Exception):
      +    """
      +    Passed an illegal or inappropriate argument.
      +    """
      +
      +
      +def capture_sql_exception(f):
      +    def deco(*a, **kw):
      +        try:
      +            return f(*a, **kw)
      +        except py4j.protocol.Py4JJavaError as e:
      +            s = e.java_exception.toString()
      +            if s.startswith('org.apache.spark.sql.AnalysisException: '):
      +                raise AnalysisException(s.split(': ', 1)[1])
      +            if s.startswith('java.lang.IllegalArgumentException: '):
      +                raise IllegalArgumentException(s.split(': ', 1)[1])
      +            raise
      +    return deco
      +
      +
      +def install_exception_handler():
      +    """
      +    Hook an exception handler into Py4j, which could capture some SQL exceptions in Java.
      +
+    When calling a Java API, Py4j calls `get_return_value` to parse the returned object.
+    If an exception is thrown in the JVM, the result is a Java exception object and
+    py4j.protocol.Py4JJavaError is raised. We replace the original `get_return_value` with one
+    that captures the Java exception and raises a Python one (with the same error message).
      +
+    It is idempotent and can safely be called multiple times.
      +    """
      +    original = py4j.protocol.get_return_value
+    # The original `get_return_value` itself is never patched, so this stays idempotent.
      +    patched = capture_sql_exception(original)
+    # only patch the one used in py4j.java_gateway (calls into the Java API)
      +    py4j.java_gateway.get_return_value = patched
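
A hedged usage sketch of the new module: after the handler is installed (the SQL entry points are expected to do this; it is called explicitly here to be safe), a bad query surfaces as the Python AnalysisException defined above rather than a raw Py4JJavaError:

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.utils import AnalysisException, install_exception_handler

sc = SparkContext("local[2]", "utils-example")
sqlContext = SQLContext(sc)
install_exception_handler()                  # idempotent
df = sqlContext.createDataFrame([(1, "a")], ["i", "s"])
df.registerTempTable("t")
try:
    sqlContext.sql("SELECT no_such_column FROM t")
except AnalysisException as e:
    print("failed to analyze query:", e)
sc.stop()
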
      diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
      index c74745c726a0..57bbe340bbd4 100644
      --- a/python/pyspark/sql/window.py
      +++ b/python/pyspark/sql/window.py
      @@ -17,8 +17,7 @@
       
       import sys
       
      -from pyspark import SparkContext
      -from pyspark.sql import since
      +from pyspark import since, SparkContext
       from pyspark.sql.column import _to_seq, _to_java_column
       
       __all__ = ["Window", "WindowSpec"]
      @@ -64,7 +63,7 @@ def orderBy(*cols):
               Creates a :class:`WindowSpec` with the partitioning defined.
               """
               sc = SparkContext._active_spark_context
      -        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols))
      +        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
               return WindowSpec(jspec)
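
With the fix above, Window.orderBy really builds an ORDER BY spec (previously it delegated to partitionBy; the new test earlier in this patch exercises the corrected behaviour). A hedged usage sketch; window functions in this version require a HiveContext, so Spark must be built with Hive support:

from pyspark import SparkContext
from pyspark.sql import HiveContext, functions as F
from pyspark.sql.window import Window

sc = SparkContext("local[2]", "window-example")
sqlContext = HiveContext(sc)
df = sqlContext.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["key", "value"])
w = Window.orderBy("key").rowsBetween(0, 1)
df.select(df.key, F.max("key").over(w).alias("next_max")).show()
sc.stop()
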
       
       
      diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
      index 944fa414b0c0..0fee3b209682 100644
      --- a/python/pyspark/statcounter.py
      +++ b/python/pyspark/statcounter.py
      @@ -30,7 +30,9 @@
       
       class StatCounter(object):
       
      -    def __init__(self, values=[]):
      +    def __init__(self, values=None):
      +        if values is None:
      +            values = list()
               self.n = 0    # Running count of our values
               self.mu = 0.0  # Running mean of our values
               self.m2 = 0.0  # Running variance numerator (sum of (x - mean)^2)
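
A short illustration (not related to any Spark API) of why the default changed: a mutable default argument is evaluated once and shared across calls, so state can leak between them:

def bad_append(x, acc=[]):       # the anti-pattern the fix removes
    acc.append(x)
    return acc

print(bad_append(1))   # [1]
print(bad_append(2))   # [1, 2] -- same list as before

def good_append(x, acc=None):    # the pattern StatCounter now follows
    if acc is None:
        acc = []
    acc.append(x)
    return acc
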
      diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
      index ac5ba69e8dbb..4069d7a14998 100644
      --- a/python/pyspark/streaming/context.py
      +++ b/python/pyspark/streaming/context.py
      @@ -86,6 +86,9 @@ class StreamingContext(object):
           """
           _transformerSerializer = None
       
      +    # Reference to a currently active StreamingContext
      +    _activeContext = None
      +
           def __init__(self, sparkContext, batchDuration=None, jssc=None):
               """
               Create a new StreamingContext.
      @@ -142,34 +145,84 @@ def getOrCreate(cls, checkpointPath, setupFunc):
               Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
               If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
               recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
      -        will be used to create a JavaStreamingContext.
      +        will be used to create a new context.
       
      -        @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program
      -        @param setupFunc:      Function to create a new JavaStreamingContext and setup DStreams
      +        @param checkpointPath: Checkpoint directory used in an earlier streaming program
      +        @param setupFunc:      Function to create a new context and setup DStreams
               """
      -        # TODO: support checkpoint in HDFS
      -        if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
      +        cls._ensure_initialized()
      +        gw = SparkContext._gateway
      +
      +        # Check whether valid checkpoint information exists in the given path
      +        if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty():
                   ssc = setupFunc()
                   ssc.checkpoint(checkpointPath)
                   return ssc
       
      -        cls._ensure_initialized()
      -        gw = SparkContext._gateway
      -
               try:
                   jssc = gw.jvm.JavaStreamingContext(checkpointPath)
               except Exception:
                   print("failed to load StreamingContext from checkpoint", file=sys.stderr)
                   raise
       
      -        jsc = jssc.sparkContext()
      -        conf = SparkConf(_jconf=jsc.getConf())
      -        sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
      +        # If there is already an active instance of Python SparkContext use it, or create a new one
      +        if not SparkContext._active_spark_context:
      +            jsc = jssc.sparkContext()
      +            conf = SparkConf(_jconf=jsc.getConf())
      +            SparkContext(conf=conf, gateway=gw, jsc=jsc)
      +
      +        sc = SparkContext._active_spark_context
      +
               # update ctx in serializer
      -        SparkContext._active_spark_context = sc
               cls._transformerSerializer.ctx = sc
               return StreamingContext(sc, None, jssc)
       
      +    @classmethod
      +    def getActive(cls):
      +        """
      +        Return either the currently active StreamingContext (i.e., if there is a context started
      +        but not stopped) or None.
      +        """
      +        activePythonContext = cls._activeContext
      +        if activePythonContext is not None:
+            # Verify that the currently running Java StreamingContext is active and is the same one
      +            # backing the supposedly active Python context
      +            activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode()
      +            activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive()
      +
      +            if activeJvmContextOption.isEmpty():
      +                cls._activeContext = None
      +            elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId:
      +                cls._activeContext = None
      +                raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext "
+                                "backing the active Python StreamingContext. This is unexpected.")
      +        return cls._activeContext
      +
      +    @classmethod
      +    def getActiveOrCreate(cls, checkpointPath, setupFunc):
      +        """
      +        Either return the active StreamingContext (i.e. currently started but not stopped),
      +        or recreate a StreamingContext from checkpoint data or create a new StreamingContext
      +        using the provided setupFunc function. If the checkpointPath is None or does not contain
      +        valid checkpoint data, then setupFunc will be called to create a new context and setup
      +        DStreams.
      +
      +        @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be
      +                               None if the intention is to always create a new context when there
      +                               is no active context.
+        @param setupFunc:      Function to create a new context and setup DStreams
      +        """
      +
      +        if setupFunc is None:
      +            raise Exception("setupFunc cannot be None")
      +        activeContext = cls.getActive()
      +        if activeContext is not None:
      +            return activeContext
      +        elif checkpointPath is not None:
      +            return cls.getOrCreate(checkpointPath, setupFunc)
      +        else:
      +            return setupFunc()
      +
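
A hedged usage sketch of the new lifecycle helpers: getActiveOrCreate returns the running context if there is one, otherwise recovers from the checkpoint directory if it holds valid data, and otherwise calls setupFunc. The socket source, port, and checkpoint path below are placeholders:

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

def setup():
    sc = SparkContext("local[2]", "streaming-example")
    ssc = StreamingContext(sc, 1)
    ssc.socketTextStream("localhost", 9999).pprint()   # placeholder source + output op
    ssc.checkpoint("/tmp/streaming-checkpoint")        # placeholder directory
    return ssc

ssc = StreamingContext.getActiveOrCreate("/tmp/streaming-checkpoint", setup)
ssc.start()
ssc.awaitTermination(10)
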
           @property
           def sparkContext(self):
               """
      @@ -182,6 +235,7 @@ def start(self):
               Start the execution of the streams.
               """
               self._jssc.start()
      +        StreamingContext._activeContext = self
       
           def awaitTermination(self, timeout=None):
               """
      @@ -212,6 +266,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
                                     of all received data to be completed
               """
               self._jssc.stop(stopSparkContext, stopGraceFully)
      +        StreamingContext._activeContext = None
               if stopSparkContext:
                   self._sc.stop()
       
      diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
      index 8dcb9645cdc6..698336cfce18 100644
      --- a/python/pyspark/streaming/dstream.py
      +++ b/python/pyspark/streaming/dstream.py
      @@ -610,7 +610,10 @@ def __init__(self, prev, func):
               self.is_checkpointed = False
               self._jdstream_val = None
       
      -        if (isinstance(prev, TransformedDStream) and
+        # Use type() so that functions are only folded and DStreams compacted for objects that
+        # are strictly TransformedDStream, not subclasses of it.
+        # This avoids a bug in KafkaTransformedDStream when calling offsetRanges().
      +        if (type(prev) is TransformedDStream and
                       not prev.is_cached and not prev.is_checkpointed):
                   prev_func = prev.func
                   self.func = lambda t, rdd: func(t, prev_func(t, rdd))
      diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py
      new file mode 100644
      index 000000000000..c0cdc50d8d42
      --- /dev/null
      +++ b/python/pyspark/streaming/flume.py
      @@ -0,0 +1,149 @@
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +import sys
      +if sys.version >= "3":
      +    from io import BytesIO
      +else:
      +    from StringIO import StringIO
      +from py4j.java_gateway import Py4JJavaError
      +
      +from pyspark.storagelevel import StorageLevel
      +from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int
      +from pyspark.streaming import DStream
      +
      +__all__ = ['FlumeUtils', 'utf8_decoder']
      +
      +
      +def utf8_decoder(s):
      +    """ Decode the unicode as UTF-8 """
      +    if s is None:
      +        return None
      +    return s.decode('utf-8')
      +
      +
      +class FlumeUtils(object):
      +
      +    @staticmethod
      +    def createStream(ssc, hostname, port,
      +                     storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
      +                     enableDecompression=False,
      +                     bodyDecoder=utf8_decoder):
      +        """
      +        Create an input stream that pulls events from Flume.
      +
      +        :param ssc:  StreamingContext object
      +        :param hostname:  Hostname of the slave machine to which the flume data will be sent
      +        :param port:  Port of the slave machine to which the flume data will be sent
      +        :param storageLevel:  Storage level to use for storing the received objects
      +        :param enableDecompression:  Should netty server decompress input stream
      +        :param bodyDecoder:  A function used to decode body (default is utf8_decoder)
      +        :return: A DStream object
      +        """
      +        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
      +
      +        try:
      +            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
      +                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
      +            helper = helperClass.newInstance()
      +            jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
      +        except Py4JJavaError as e:
      +            if 'ClassNotFoundException' in str(e.java_exception):
      +                FlumeUtils._printErrorMsg(ssc.sparkContext)
      +            raise e
      +
      +        return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
      +
      +    @staticmethod
      +    def createPollingStream(ssc, addresses,
      +                            storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
      +                            maxBatchSize=1000,
      +                            parallelism=5,
      +                            bodyDecoder=utf8_decoder):
      +        """
      +        Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent.
      +        This stream will poll the sink for data and will pull events as they are available.
      +
      +        :param ssc:  StreamingContext object
      +        :param addresses:  List of (host, port)s on which the Spark Sink is running.
      +        :param storageLevel:  Storage level to use for storing the received objects
      +        :param maxBatchSize:  The maximum number of events to be pulled from the Spark sink
      +                              in a single RPC call
      +        :param parallelism:  Number of concurrent requests this stream should send to the sink.
      +                             Note that having a higher number of requests concurrently being pulled
      +                             will result in this stream using more threads
      +        :param bodyDecoder:  A function used to decode body (default is utf8_decoder)
      +        :return: A DStream object
      +        """
      +        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
      +        hosts = []
      +        ports = []
      +        for (host, port) in addresses:
      +            hosts.append(host)
      +            ports.append(port)
      +
      +        try:
      +            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
      +                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
      +            helper = helperClass.newInstance()
      +            jstream = helper.createPollingStream(
      +                ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
      +        except Py4JJavaError as e:
      +            if 'ClassNotFoundException' in str(e.java_exception):
      +                FlumeUtils._printErrorMsg(ssc.sparkContext)
      +            raise e
      +
      +        return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
      +
      +    @staticmethod
      +    def _toPythonDStream(ssc, jstream, bodyDecoder):
      +        ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
      +        stream = DStream(jstream, ssc, ser)
      +
      +        def func(event):
      +            headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0])
      +            headers = {}
      +            strSer = UTF8Deserializer()
      +            for i in range(0, read_int(headersBytes)):
      +                key = strSer.loads(headersBytes)
      +                value = strSer.loads(headersBytes)
      +                headers[key] = value
      +            body = bodyDecoder(event[1])
      +            return (headers, body)
      +        return stream.map(func)
      +
      +    @staticmethod
      +    def _printErrorMsg(sc):
      +        print("""
      +________________________________________________________________________________________________
      +
      +  Spark Streaming's Flume libraries not found in class path. Try one of the following.
      +
      +  1. Include the Flume library and its dependencies with in the
      +     spark-submit command as
      +
      +     $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ...
      +
      +  2. Download the JAR of the artifact from Maven Central http://search.maven.org/,
      +     Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s.
      +     Then, include the jar in the spark-submit command as
      +
      +     $ bin/spark-submit --jars  ...
      +
      +________________________________________________________________________________________________
      +
      +""" % (sc.version, sc.version))
      diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
      index 10a859a532e2..8a814c64c042 100644
      --- a/python/pyspark/streaming/kafka.py
      +++ b/python/pyspark/streaming/kafka.py
      @@ -21,19 +21,23 @@
       from pyspark.storagelevel import StorageLevel
       from pyspark.serializers import PairDeserializer, NoOpSerializer
       from pyspark.streaming import DStream
      +from pyspark.streaming.dstream import TransformedDStream
      +from pyspark.streaming.util import TransformFunction
       
       __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
       
       
       def utf8_decoder(s):
           """ Decode the unicode as UTF-8 """
      -    return s and s.decode('utf-8')
      +    if s is None:
      +        return None
      +    return s.decode('utf-8')
       
       
       class KafkaUtils(object):
       
           @staticmethod
      -    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
      +    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None,
                            storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
                            keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
               """
      @@ -50,6 +54,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
               :param valueDecoder:  A function used to decode value (default is utf8_decoder)
               :return: A DStream object
               """
      +        if kafkaParams is None:
      +            kafkaParams = dict()
               kafkaParams.update({
                   "zookeeper.connect": zkQuorum,
                   "group.id": groupId,
      @@ -75,7 +81,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
               return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
       
           @staticmethod
      -    def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
      +    def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
                                  keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
               """
               .. note:: Experimental
      @@ -103,6 +109,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
               :param valueDecoder:  A function used to decode value (default is utf8_decoder).
               :return: A DStream object
               """
      +        if fromOffsets is None:
      +            fromOffsets = dict()
               if not isinstance(topics, list):
                   raise TypeError("topics should be list")
               if not isinstance(kafkaParams, dict):
      @@ -122,11 +130,12 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
                   raise e
       
               ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
      -        stream = DStream(jstream, ssc, ser)
      -        return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
      +        stream = DStream(jstream, ssc, ser) \
      +            .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
      +        return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
       
           @staticmethod
      -    def createRDD(sc, kafkaParams, offsetRanges, leaders={},
      +    def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
                         keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
               """
               .. note:: Experimental
      @@ -142,6 +151,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
               :param valueDecoder:  A function used to decode value (default is utf8_decoder)
               :return: A RDD object
               """
      +        if leaders is None:
      +            leaders = dict()
               if not isinstance(kafkaParams, dict):
                   raise TypeError("kafkaParams should be dict")
               if not isinstance(offsetRanges, list):
      @@ -161,8 +172,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
                   raise e
       
               ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
      -        rdd = RDD(jrdd, sc, ser)
      -        return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
      +        rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
      +        return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
       
           @staticmethod
           def _printErrorMsg(sc):
      @@ -200,14 +211,30 @@ def __init__(self, topic, partition, fromOffset, untilOffset):
               :param fromOffset: Inclusive starting offset.
               :param untilOffset: Exclusive ending offset.
               """
      -        self._topic = topic
      -        self._partition = partition
      -        self._fromOffset = fromOffset
      -        self._untilOffset = untilOffset
      +        self.topic = topic
      +        self.partition = partition
      +        self.fromOffset = fromOffset
      +        self.untilOffset = untilOffset
      +
      +    def __eq__(self, other):
      +        if isinstance(other, self.__class__):
      +            return (self.topic == other.topic
      +                    and self.partition == other.partition
      +                    and self.fromOffset == other.fromOffset
      +                    and self.untilOffset == other.untilOffset)
      +        else:
      +            return False
      +
      +    def __ne__(self, other):
      +        return not self.__eq__(other)
      +
      +    def __str__(self):
      +        return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \
      +               % (self.topic, self.partition, self.fromOffset, self.untilOffset)
       
           def _jOffsetRange(self, helper):
      -        return helper.createOffsetRange(self._topic, self._partition, self._fromOffset,
      -                                        self._untilOffset)
      +        return helper.createOffsetRange(self.topic, self.partition, self.fromOffset,
      +                                        self.untilOffset)
       
       
       class TopicAndPartition(object):
      @@ -244,3 +271,87 @@ def __init__(self, host, port):
       
           def _jBroker(self, helper):
               return helper.createBroker(self._host, self._port)
      +
      +
      +class KafkaRDD(RDD):
      +    """
+    A Python wrapper of KafkaRDD, exposing Kafka offset-range information on top of a normal RDD.
      +    """
      +
      +    def __init__(self, jrdd, ctx, jrdd_deserializer):
      +        RDD.__init__(self, jrdd, ctx, jrdd_deserializer)
      +
      +    def offsetRanges(self):
      +        """
+        Get the OffsetRanges of this KafkaRDD.
+
+        :return: A list of OffsetRange
      +        """
      +        try:
      +            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
      +                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
      +            helper = helperClass.newInstance()
      +            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
      +        except Py4JJavaError as e:
      +            if 'ClassNotFoundException' in str(e.java_exception):
      +                KafkaUtils._printErrorMsg(self.ctx)
      +            raise e
      +
      +        ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
      +                  for o in joffsetRanges]
      +        return ranges
      +
      +
      +class KafkaDStream(DStream):
      +    """
      +    A Python wrapper of KafkaDStream
      +    """
      +
      +    def __init__(self, jdstream, ssc, jrdd_deserializer):
      +        DStream.__init__(self, jdstream, ssc, jrdd_deserializer)
      +
      +    def foreachRDD(self, func):
      +        """
      +        Apply a function to each RDD in this DStream.
      +        """
      +        if func.__code__.co_argcount == 1:
      +            old_func = func
      +            func = lambda r, rdd: old_func(rdd)
      +        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \
      +            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
      +        api = self._ssc._jvm.PythonDStream
      +        api.callForeachRDD(self._jdstream, jfunc)
      +
      +    def transform(self, func):
      +        """
      +        Return a new DStream in which each RDD is generated by applying a function
      +        on each RDD of this DStream.
      +
+        `func` can take a single argument, `rdd`, or two arguments,
+        (`time`, `rdd`).
      +        """
      +        if func.__code__.co_argcount == 1:
      +            oldfunc = func
      +            func = lambda t, rdd: oldfunc(rdd)
      +        assert func.__code__.co_argcount == 2, "func should take one or two arguments"
      +
      +        return KafkaTransformedDStream(self, func)
      +
      +
      +class KafkaTransformedDStream(TransformedDStream):
      +    """
      +    Kafka specific wrapper of TransformedDStream to transform on Kafka RDD.
      +    """
      +
      +    def __init__(self, prev, func):
      +        TransformedDStream.__init__(self, prev, func)
      +
      +    @property
      +    def _jdstream(self):
      +        if self._jdstream_val is not None:
      +            return self._jdstream_val
      +
      +        jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \
      +            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
      +        dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
      +        self._jdstream_val = dstream.asJavaDStream()
      +        return self._jdstream_val
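A hedged sketch of how the direct stream and the new offset-range plumbing above fit together; the broker address and topic are placeholders. Because foreachRDD and transform on KafkaDStream wrap the Java RDD in KafkaRDD, offsetRanges() is available inside the callback:

    # Illustrative only; adjust the broker list and topic to your cluster.
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils

    sc = SparkContext(appName="KafkaDirectSketch")
    ssc = StreamingContext(sc, 2)

    stream = KafkaUtils.createDirectStream(
        ssc, ["my-topic"], {"metadata.broker.list": "localhost:9092"})

    def report_offsets(time, rdd):
        # rdd is a KafkaRDD here, so offsetRanges() is available.
        for o in rdd.offsetRanges():
            print("%s %s" % (time, o))

    stream.foreachRDD(report_offsets)
    stream.map(lambda kv: kv[1]).pprint()

    ssc.start()
    ssc.awaitTermination()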
      diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
      new file mode 100644
      index 000000000000..34be5880e170
      --- /dev/null
      +++ b/python/pyspark/streaming/kinesis.py
      @@ -0,0 +1,114 @@
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +from py4j.java_gateway import Py4JJavaError
      +
      +from pyspark.serializers import PairDeserializer, NoOpSerializer
      +from pyspark.storagelevel import StorageLevel
      +from pyspark.streaming import DStream
      +
      +__all__ = ['KinesisUtils', 'InitialPositionInStream', 'utf8_decoder']
      +
      +
      +def utf8_decoder(s):
      +    """ Decode the unicode as UTF-8 """
      +    if s is None:
      +        return None
      +    return s.decode('utf-8')
      +
      +
      +class KinesisUtils(object):
      +
      +    @staticmethod
      +    def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName,
      +                     initialPositionInStream, checkpointInterval,
      +                     storageLevel=StorageLevel.MEMORY_AND_DISK_2,
      +                     awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder):
      +        """
      +        Create an input stream that pulls messages from a Kinesis stream. This uses the
      +        Kinesis Client Library (KCL) to pull messages from Kinesis.
      +
      +        Note: The given AWS credentials will get saved in DStream checkpoints if checkpointing is
      +        enabled. Make sure that your checkpoint directory is secure.
      +
      +        :param ssc:  StreamingContext object
      +        :param kinesisAppName:  Kinesis application name used by the Kinesis Client Library (KCL) to
      +                                update DynamoDB
      +        :param streamName:  Kinesis stream name
      +        :param endpointUrl:  Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
      +        :param regionName:  Name of region used by the Kinesis Client Library (KCL) to update
      +                            DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
      +        :param initialPositionInStream:  In the absence of Kinesis checkpoint info, this is the
      +                                         worker's initial starting position in the stream. The
      +                                         values are either the beginning of the stream per Kinesis'
      +                                         limit of 24 hours (InitialPositionInStream.TRIM_HORIZON) or
      +                                         the tip of the stream (InitialPositionInStream.LATEST).
      +        :param checkpointInterval:  Checkpoint interval for Kinesis checkpointing. See the Kinesis
      +                                    Spark Streaming documentation for more details on the different
      +                                    types of checkpoints.
      +        :param storageLevel:  Storage level to use for storing the received objects (default is
      +                              StorageLevel.MEMORY_AND_DISK_2)
      +        :param awsAccessKeyId:  AWS AccessKeyId (default is None. If None, will use
      +                                DefaultAWSCredentialsProviderChain)
      +        :param awsSecretKey:  AWS SecretKey (default is None. If None, will use
      +                              DefaultAWSCredentialsProviderChain)
      +        :param decoder:  A function used to decode value (default is utf8_decoder)
      +        :return: A DStream object
      +        """
      +        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
      +        jduration = ssc._jduration(checkpointInterval)
      +
      +        try:
      +            # Use KinesisUtilsPythonHelper to access Scala's KinesisUtils
      +            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
      +                .loadClass("org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper")
      +            helper = helperClass.newInstance()
      +            jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl,
      +                                          regionName, initialPositionInStream, jduration, jlevel,
      +                                          awsAccessKeyId, awsSecretKey)
      +        except Py4JJavaError as e:
      +            if 'ClassNotFoundException' in str(e.java_exception):
      +                KinesisUtils._printErrorMsg(ssc.sparkContext)
      +            raise e
      +        stream = DStream(jstream, ssc, NoOpSerializer())
      +        return stream.map(lambda v: decoder(v))
      +
      +    @staticmethod
      +    def _printErrorMsg(sc):
      +        print("""
      +________________________________________________________________________________________________
      +
      +  Spark Streaming's Kinesis libraries not found in class path. Try one of the following.
      +
+  1. Include the Kinesis library and its dependencies in the
      +     spark-submit command as
      +
      +     $ bin/spark-submit --packages org.apache.spark:spark-streaming-kinesis-asl:%s ...
      +
      +  2. Download the JAR of the artifact from Maven Central http://search.maven.org/,
      +     Group Id = org.apache.spark, Artifact Id = spark-streaming-kinesis-asl-assembly, Version = %s.
      +     Then, include the jar in the spark-submit command as
      +
      +     $ bin/spark-submit --jars  ...
      +
      +________________________________________________________________________________________________
      +
      +""" % (sc.version, sc.version))
      +
      +
      +class InitialPositionInStream(object):
      +    LATEST, TRIM_HORIZON = (0, 1)
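A minimal sketch of calling the Kinesis API defined above. The app name, stream name, endpoint, and region are placeholders, and omitting the access key and secret falls back to DefaultAWSCredentialsProviderChain as the docstring notes:

    # Placeholders throughout; requires the spark-streaming-kinesis-asl assembly on the classpath.
    from pyspark import SparkContext
    from pyspark.storagelevel import StorageLevel
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream

    sc = SparkContext(appName="KinesisSketch")
    ssc = StreamingContext(sc, 2)

    records = KinesisUtils.createStream(
        ssc, "myKinesisApp", "myStream",
        "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
        InitialPositionInStream.LATEST, checkpointInterval=2,
        storageLevel=StorageLevel.MEMORY_AND_DISK_2)
    records.pprint()

    ssc.start()
    ssc.awaitTermination()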
      diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py
      new file mode 100644
      index 000000000000..f06598971c54
      --- /dev/null
      +++ b/python/pyspark/streaming/mqtt.py
      @@ -0,0 +1,72 @@
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +from py4j.java_gateway import Py4JJavaError
      +
      +from pyspark.storagelevel import StorageLevel
      +from pyspark.serializers import UTF8Deserializer
      +from pyspark.streaming import DStream
      +
      +__all__ = ['MQTTUtils']
      +
      +
      +class MQTTUtils(object):
      +
      +    @staticmethod
      +    def createStream(ssc, brokerUrl, topic,
      +                     storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
      +        """
+        Create an input stream that pulls messages from an MQTT broker.
+
+        :param ssc:  StreamingContext object
+        :param brokerUrl:  URL of the remote MQTT broker
+        :param topic:  topic name to subscribe to
      +        :param storageLevel:  RDD storage level.
      +        :return: A DStream object
      +        """
      +        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
      +
      +        try:
      +            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
      +                .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper")
      +            helper = helperClass.newInstance()
      +            jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel)
      +        except Py4JJavaError as e:
      +            if 'ClassNotFoundException' in str(e.java_exception):
      +                MQTTUtils._printErrorMsg(ssc.sparkContext)
      +            raise e
      +
      +        return DStream(jstream, ssc, UTF8Deserializer())
      +
      +    @staticmethod
      +    def _printErrorMsg(sc):
      +        print("""
      +________________________________________________________________________________________________
      +
      +  Spark Streaming's MQTT libraries not found in class path. Try one of the following.
      +
+  1. Include the MQTT library and its dependencies in the
      +     spark-submit command as
      +
      +     $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ...
      +
      +  2. Download the JAR of the artifact from Maven Central http://search.maven.org/,
      +     Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s.
      +     Then, include the jar in the spark-submit command as
      +
      +     $ bin/spark-submit --jars  ...
      +________________________________________________________________________________________________
      +""" % (sc.version, sc.version))
      diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
      index 57049beea4db..cfea95b0dec7 100644
      --- a/python/pyspark/streaming/tests.py
      +++ b/python/pyspark/streaming/tests.py
      @@ -15,6 +15,7 @@
       # limitations under the License.
       #
       
      +import glob
       import os
       import sys
       from itertools import chain
      @@ -23,6 +24,7 @@
       import tempfile
       import random
       import struct
      +import shutil
       from functools import reduce
       
       if sys.version_info[:2] <= (2, 6):
      @@ -35,8 +37,12 @@
           import unittest
       
       from pyspark.context import SparkConf, SparkContext, RDD
      +from pyspark.storagelevel import StorageLevel
       from pyspark.streaming.context import StreamingContext
       from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
      +from pyspark.streaming.flume import FlumeUtils
      +from pyspark.streaming.mqtt import MQTTUtils
      +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
       
       
       class PySparkStreamingTestCase(unittest.TestCase):
      @@ -54,12 +60,21 @@ def setUpClass(cls):
           @classmethod
           def tearDownClass(cls):
               cls.sc.stop()
+        # Clean up in the JVM just in case there have been some issues in the Python API
      +        jSparkContextOption = SparkContext._jvm.SparkContext.get()
      +        if jSparkContextOption.nonEmpty():
      +            jSparkContextOption.get().stop()
       
           def setUp(self):
               self.ssc = StreamingContext(self.sc, self.duration)
       
           def tearDown(self):
      -        self.ssc.stop(False)
      +        if self.ssc is not None:
      +            self.ssc.stop(False)
+        # Clean up in the JVM just in case there have been some issues in the Python API
+        jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
+        if jStreamingContextOption.nonEmpty():
+            jStreamingContextOption.get().stop(False)
       
           def wait_for(self, result, n):
               start_time = time.time()
      @@ -437,6 +452,7 @@ def test_reduce_by_invalid_window(self):
       class StreamingContextTests(PySparkStreamingTestCase):
       
           duration = 0.1
      +    setupCalled = False
       
           def _add_input_stream(self):
               inputs = [range(1, x) for x in range(101)]
      @@ -510,10 +526,89 @@ def func(rdds):
       
               self.assertEqual([2, 3, 1], self._take(dstream, 3))
       
      +    def test_get_active(self):
      +        self.assertEqual(StreamingContext.getActive(), None)
      +
      +        # Verify that getActive() returns the active context
      +        self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
      +        self.ssc.start()
      +        self.assertEqual(StreamingContext.getActive(), self.ssc)
      +
      +        # Verify that getActive() returns None
      +        self.ssc.stop(False)
      +        self.assertEqual(StreamingContext.getActive(), None)
      +
      +        # Verify that if the Java context is stopped, then getActive() returns None
      +        self.ssc = StreamingContext(self.sc, self.duration)
      +        self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
      +        self.ssc.start()
      +        self.assertEqual(StreamingContext.getActive(), self.ssc)
      +        self.ssc._jssc.stop(False)
      +        self.assertEqual(StreamingContext.getActive(), None)
      +
      +    def test_get_active_or_create(self):
      +        # Test StreamingContext.getActiveOrCreate() without checkpoint data
      +        # See CheckpointTests for tests with checkpoint data
      +        self.ssc = None
      +        self.assertEqual(StreamingContext.getActive(), None)
      +
      +        def setupFunc():
      +            ssc = StreamingContext(self.sc, self.duration)
      +            ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
      +            self.setupCalled = True
      +            return ssc
      +
      +        # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
      +        self.assertTrue(self.setupCalled)
      +
+        # Verify that getActiveOrCreate() returns active context and does not call the setupFunc
      +        self.ssc.start()
      +        self.setupCalled = False
      +        self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc)
      +        self.assertFalse(self.setupCalled)
      +
      +        # Verify that getActiveOrCreate() calls setupFunc after active context is stopped
      +        self.ssc.stop(False)
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
      +        self.assertTrue(self.setupCalled)
      +
      +        # Verify that if the Java context is stopped, then getActive() returns None
      +        self.ssc = StreamingContext(self.sc, self.duration)
      +        self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
      +        self.ssc.start()
      +        self.assertEqual(StreamingContext.getActive(), self.ssc)
      +        self.ssc._jssc.stop(False)
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
      +        self.assertTrue(self.setupCalled)
      +
       
       class CheckpointTests(unittest.TestCase):
       
      -    def test_get_or_create(self):
      +    setupCalled = False
      +
      +    @staticmethod
      +    def tearDownClass():
+        # Clean up in the JVM just in case there have been some issues in the Python API
+        jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
      +        if jStreamingContextOption.nonEmpty():
      +            jStreamingContextOption.get().stop()
      +        jSparkContextOption = SparkContext._jvm.SparkContext.get()
      +        if jSparkContextOption.nonEmpty():
      +            jSparkContextOption.get().stop()
      +
+    def setUp(self):
+        self.ssc = None
+        self.sc = None
+        self.cpd = None
+
+    def tearDown(self):
      +        if self.ssc is not None:
      +            self.ssc.stop(True)
      +        if self.sc is not None:
      +            self.sc.stop()
      +        if self.cpd is not None:
      +            shutil.rmtree(self.cpd)
      +
      +    def test_get_or_create_and_get_active_or_create(self):
               inputd = tempfile.mkdtemp()
               outputd = tempfile.mkdtemp() + "/"
       
      @@ -528,11 +623,16 @@ def setup():
                   wc = dstream.updateStateByKey(updater)
                   wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
                   wc.checkpoint(.5)
      +            self.setupCalled = True
                   return ssc
       
      -        cpd = tempfile.mkdtemp("test_streaming_cps")
      -        ssc = StreamingContext.getOrCreate(cpd, setup)
      -        ssc.start()
      +        # Verify that getOrCreate() calls setup() in absence of checkpoint files
      +        self.cpd = tempfile.mkdtemp("test_streaming_cps")
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
+        self.assertTrue(self.setupCalled)
      +
      +        self.ssc.start()
       
               def check_output(n):
                   while not os.listdir(outputd):
      @@ -547,7 +647,7 @@ def check_output(n):
                           # not finished
                           time.sleep(0.01)
                           continue
      -                ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
      +                ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
                       d = ordd.values().map(int).collect()
                       if not d:
                           time.sleep(0.01)
      @@ -563,13 +663,58 @@ def check_output(n):
       
               check_output(1)
               check_output(2)
      -        ssc.stop(True, True)
       
      +        # Verify the getOrCreate() recovers from checkpoint files
      +        self.ssc.stop(True, True)
               time.sleep(1)
      -        ssc = StreamingContext.getOrCreate(cpd, setup)
      -        ssc.start()
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
      +        self.assertFalse(self.setupCalled)
      +        self.ssc.start()
               check_output(3)
      -        ssc.stop(True, True)
      +
      +        # Verify that getOrCreate() uses existing SparkContext
      +        self.ssc.stop(True, True)
      +        time.sleep(1)
      +        sc = SparkContext(SparkConf())
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
      +        self.assertFalse(self.setupCalled)
      +        self.assertTrue(self.ssc.sparkContext == sc)
      +
      +        # Verify the getActiveOrCreate() recovers from checkpoint files
      +        self.ssc.stop(True, True)
      +        time.sleep(1)
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
      +        self.assertFalse(self.setupCalled)
      +        self.ssc.start()
      +        check_output(4)
      +
      +        # Verify that getActiveOrCreate() returns active context
      +        self.setupCalled = False
      +        self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc)
      +        self.assertFalse(self.setupCalled)
      +
      +        # Verify that getActiveOrCreate() uses existing SparkContext
      +        self.ssc.stop(True, True)
      +        time.sleep(1)
      +        self.sc = SparkContext(SparkConf())
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
      +        self.assertFalse(self.setupCalled)
+        self.assertTrue(self.ssc.sparkContext == self.sc)
      +
      +        # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files
      +        self.ssc.stop(True, True)
      +        shutil.rmtree(self.cpd)  # delete checkpoint directory
      +        time.sleep(1)
      +        self.setupCalled = False
      +        self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup)
      +        self.assertTrue(self.setupCalled)
      +
      +        # Stop everything
      +        self.ssc.stop(True, True)
       
       
       class KafkaStreamTests(PySparkStreamingTestCase):
      @@ -676,5 +821,480 @@ def test_kafka_rdd_with_leaders(self):
               rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
               self._validateRddResult(sendData, rdd)
       
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
      +    def test_kafka_rdd_get_offsetRanges(self):
      +        """Test Python direct Kafka RDD get OffsetRanges."""
      +        topic = self._randomTopic()
      +        sendData = {"a": 3, "b": 4, "c": 5}
      +        offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
      +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
      +
      +        self._kafkaTestUtils.createTopic(topic)
      +        self._kafkaTestUtils.sendMessages(topic, sendData)
      +        rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
      +        self.assertEqual(offsetRanges, rdd.offsetRanges())
      +
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
      +    def test_kafka_direct_stream_foreach_get_offsetRanges(self):
      +        """Test the Python direct Kafka stream foreachRDD get offsetRanges."""
      +        topic = self._randomTopic()
      +        sendData = {"a": 1, "b": 2, "c": 3}
      +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
      +                       "auto.offset.reset": "smallest"}
      +
      +        self._kafkaTestUtils.createTopic(topic)
      +        self._kafkaTestUtils.sendMessages(topic, sendData)
      +
      +        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
      +
      +        offsetRanges = []
      +
      +        def getOffsetRanges(_, rdd):
      +            for o in rdd.offsetRanges():
      +                offsetRanges.append(o)
      +
      +        stream.foreachRDD(getOffsetRanges)
      +        self.ssc.start()
      +        self.wait_for(offsetRanges, 1)
      +
      +        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
      +
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
      +    def test_kafka_direct_stream_transform_get_offsetRanges(self):
      +        """Test the Python direct Kafka stream transform get offsetRanges."""
      +        topic = self._randomTopic()
      +        sendData = {"a": 1, "b": 2, "c": 3}
      +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
      +                       "auto.offset.reset": "smallest"}
      +
      +        self._kafkaTestUtils.createTopic(topic)
      +        self._kafkaTestUtils.sendMessages(topic, sendData)
      +
      +        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
      +
      +        offsetRanges = []
      +
      +        def transformWithOffsetRanges(rdd):
      +            for o in rdd.offsetRanges():
      +                offsetRanges.append(o)
      +            return rdd
      +
+        # Test that mixing KafkaTransformedDStream and TransformedDStream works;
+        # only the TransformedDStreams can be folded together.
      +        stream.transform(transformWithOffsetRanges).map(lambda kv: kv[1]).count().pprint()
      +        self.ssc.start()
      +        self.wait_for(offsetRanges, 1)
      +
      +        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
      +
      +
      +class FlumeStreamTests(PySparkStreamingTestCase):
      +    timeout = 20  # seconds
      +    duration = 1
      +
      +    def setUp(self):
      +        super(FlumeStreamTests, self).setUp()
      +
      +        utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
      +            .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils")
      +        self._utils = utilsClz.newInstance()
      +
      +    def tearDown(self):
      +        if self._utils is not None:
      +            self._utils.close()
      +            self._utils = None
      +
      +        super(FlumeStreamTests, self).tearDown()
      +
      +    def _startContext(self, n, compressed):
      +        # Start the StreamingContext and also collect the result
      +        dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(),
      +                                          enableDecompression=compressed)
      +        result = []
      +
      +        def get_output(_, rdd):
      +            for event in rdd.collect():
      +                if len(result) < n:
      +                    result.append(event)
      +        dstream.foreachRDD(get_output)
      +        self.ssc.start()
      +        return result
      +
      +    def _validateResult(self, input, result):
      +        # Validate both the header and the body
      +        header = {"test": "header"}
      +        self.assertEqual(len(input), len(result))
      +        for i in range(0, len(input)):
      +            self.assertEqual(header, result[i][0])
      +            self.assertEqual(input[i], result[i][1])
      +
      +    def _writeInput(self, input, compressed):
      +        # Try to write input to the receiver until success or timeout
      +        start_time = time.time()
      +        while True:
      +            try:
      +                self._utils.writeInput(input, compressed)
      +                break
      +            except:
      +                if time.time() - start_time < self.timeout:
      +                    time.sleep(0.01)
      +                else:
      +                    raise
      +
      +    def test_flume_stream(self):
      +        input = [str(i) for i in range(1, 101)]
      +        result = self._startContext(len(input), False)
      +        self._writeInput(input, False)
      +        self.wait_for(result, len(input))
      +        self._validateResult(input, result)
      +
      +    def test_compressed_flume_stream(self):
      +        input = [str(i) for i in range(1, 101)]
      +        result = self._startContext(len(input), True)
      +        self._writeInput(input, True)
      +        self.wait_for(result, len(input))
      +        self._validateResult(input, result)
      +
      +
      +class FlumePollingStreamTests(PySparkStreamingTestCase):
      +    timeout = 20  # seconds
      +    duration = 1
      +    maxAttempts = 5
      +
      +    def setUp(self):
      +        utilsClz = \
      +            self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
      +                .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils")
      +        self._utils = utilsClz.newInstance()
      +
      +    def tearDown(self):
      +        if self._utils is not None:
      +            self._utils.close()
      +            self._utils = None
      +
      +    def _writeAndVerify(self, ports):
      +        # Set up the streaming context and input streams
      +        ssc = StreamingContext(self.sc, self.duration)
      +        try:
      +            addresses = [("localhost", port) for port in ports]
      +            dstream = FlumeUtils.createPollingStream(
      +                ssc,
      +                addresses,
      +                maxBatchSize=self._utils.eventsPerBatch(),
      +                parallelism=5)
      +            outputBuffer = []
      +
      +            def get_output(_, rdd):
      +                for e in rdd.collect():
      +                    outputBuffer.append(e)
      +
      +            dstream.foreachRDD(get_output)
      +            ssc.start()
      +            self._utils.sendDatAndEnsureAllDataHasBeenReceived()
      +
      +            self.wait_for(outputBuffer, self._utils.getTotalEvents())
      +            outputHeaders = [event[0] for event in outputBuffer]
      +            outputBodies = [event[1] for event in outputBuffer]
      +            self._utils.assertOutput(outputHeaders, outputBodies)
      +        finally:
      +            ssc.stop(False)
      +
      +    def _testMultipleTimes(self, f):
      +        attempt = 0
      +        while True:
      +            try:
      +                f()
      +                break
      +            except:
      +                attempt += 1
      +                if attempt >= self.maxAttempts:
      +                    raise
      +                else:
      +                    import traceback
      +                    traceback.print_exc()
      +
      +    def _testFlumePolling(self):
      +        try:
      +            port = self._utils.startSingleSink()
      +            self._writeAndVerify([port])
      +            self._utils.assertChannelsAreEmpty()
      +        finally:
      +            self._utils.close()
      +
      +    def _testFlumePollingMultipleHosts(self):
      +        try:
      +            port = self._utils.startSingleSink()
      +            self._writeAndVerify([port])
      +            self._utils.assertChannelsAreEmpty()
      +        finally:
      +            self._utils.close()
      +
      +    def test_flume_polling(self):
      +        self._testMultipleTimes(self._testFlumePolling)
      +
      +    def test_flume_polling_multiple_hosts(self):
      +        self._testMultipleTimes(self._testFlumePollingMultipleHosts)
      +
      +
      +class MQTTStreamTests(PySparkStreamingTestCase):
      +    timeout = 20  # seconds
      +    duration = 1
      +
      +    def setUp(self):
      +        super(MQTTStreamTests, self).setUp()
      +
      +        MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
      +            .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils")
      +        self._MQTTTestUtils = MQTTTestUtilsClz.newInstance()
      +        self._MQTTTestUtils.setup()
      +
      +    def tearDown(self):
      +        if self._MQTTTestUtils is not None:
      +            self._MQTTTestUtils.teardown()
      +            self._MQTTTestUtils = None
      +
      +        super(MQTTStreamTests, self).tearDown()
      +
      +    def _randomTopic(self):
      +        return "topic-%d" % random.randint(0, 10000)
      +
      +    def _startContext(self, topic):
      +        # Start the StreamingContext and also collect the result
      +        stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
      +        result = []
      +
      +        def getOutput(_, rdd):
      +            for data in rdd.collect():
      +                result.append(data)
      +
      +        stream.foreachRDD(getOutput)
      +        self.ssc.start()
      +        return result
      +
      +    def test_mqtt_stream(self):
      +        """Test the Python MQTT stream API."""
      +        sendData = "MQTT demo for spark streaming"
      +        topic = self._randomTopic()
      +        result = self._startContext(topic)
      +
      +        def retry():
      +            self._MQTTTestUtils.publishData(topic, sendData)
      +            # Because "publishData" sends duplicate messages, here we should use > 0
      +            self.assertTrue(len(result) > 0)
      +            self.assertEqual(sendData, result[0])
      +
      +        # Retry it because we don't know when the receiver will start.
      +        self._retry_or_timeout(retry)
      +
      +    def _retry_or_timeout(self, test_func):
      +        start_time = time.time()
      +        while True:
      +            try:
      +                test_func()
      +                break
      +            except:
      +                if time.time() - start_time > self.timeout:
      +                    raise
      +                time.sleep(0.01)
      +
      +
      +class KinesisStreamTests(PySparkStreamingTestCase):
      +
      +    def test_kinesis_stream_api(self):
      +        # Don't start the StreamingContext because we cannot test it in Jenkins
      +        kinesisStream1 = KinesisUtils.createStream(
      +            self.ssc, "myAppNam", "mySparkStream",
      +            "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
      +            InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2)
      +        kinesisStream2 = KinesisUtils.createStream(
      +            self.ssc, "myAppNam", "mySparkStream",
      +            "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
      +            InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2,
      +            "awsAccessKey", "awsSecretKey")
      +
      +    def test_kinesis_stream(self):
      +        if not are_kinesis_tests_enabled:
      +            sys.stderr.write(
      +                "Skipped test_kinesis_stream (enable by setting environment variable %s=1"
      +                % kinesis_test_environ_var)
      +            return
      +
      +        import random
      +        kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000)))
      +        kinesisTestUtilsClz = \
      +            self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
      +                .loadClass("org.apache.spark.streaming.kinesis.KinesisTestUtils")
      +        kinesisTestUtils = kinesisTestUtilsClz.newInstance()
      +        try:
      +            kinesisTestUtils.createStream()
      +            aWSCredentials = kinesisTestUtils.getAWSCredentials()
      +            stream = KinesisUtils.createStream(
      +                self.ssc, kinesisAppName, kinesisTestUtils.streamName(),
      +                kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(),
      +                InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY,
      +                aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey())
      +
      +            outputBuffer = []
      +
      +            def get_output(_, rdd):
      +                for e in rdd.collect():
      +                    outputBuffer.append(e)
      +
      +            stream.foreachRDD(get_output)
      +            self.ssc.start()
      +
      +            testData = [i for i in range(1, 11)]
      +            expectedOutput = set([str(i) for i in testData])
      +            start_time = time.time()
      +            while time.time() - start_time < 120:
      +                kinesisTestUtils.pushData(testData)
      +                if expectedOutput == set(outputBuffer):
      +                    break
      +                time.sleep(10)
      +            self.assertEqual(expectedOutput, set(outputBuffer))
      +        except:
      +            import traceback
      +            traceback.print_exc()
      +            raise
      +        finally:
      +            self.ssc.stop(False)
      +            kinesisTestUtils.deleteStream()
      +            kinesisTestUtils.deleteDynamoDBTable(kinesisAppName)
      +
      +
+# Search for the jar in the project directory using the given name prefix, covering both the
+# sbt build and the maven build because the artifact jars end up in different directories.
      +def search_jar(dir, name_prefix):
      +    # We should ignore the following jars
      +    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")
      +    jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) +  # sbt build
      +            glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar")))  # maven build
      +    return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)]
      +
      +
      +def search_kafka_assembly_jar():
      +    SPARK_HOME = os.environ["SPARK_HOME"]
      +    kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly")
      +    jars = search_jar(kafka_assembly_dir, "spark-streaming-kafka-assembly")
      +    if not jars:
      +        raise Exception(
      +            ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) +
      +            "You need to build Spark with "
      +            "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or "
      +            "'build/mvn package' before running this test.")
      +    elif len(jars) > 1:
      +        raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please "
      +                         "remove all but one") % (", ".join(jars)))
      +    else:
      +        return jars[0]
      +
      +
      +def search_flume_assembly_jar():
      +    SPARK_HOME = os.environ["SPARK_HOME"]
      +    flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly")
      +    jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly")
      +    if not jars:
      +        raise Exception(
      +            ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) +
      +            "You need to build Spark with "
      +            "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or "
      +            "'build/mvn package' before running this test.")
      +    elif len(jars) > 1:
      +        raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please "
      +                        "remove all but one") % (", ".join(jars)))
      +    else:
      +        return jars[0]
      +
      +
      +def search_mqtt_assembly_jar():
      +    SPARK_HOME = os.environ["SPARK_HOME"]
      +    mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly")
      +    jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly")
      +    if not jars:
      +        raise Exception(
      +            ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) +
      +            "You need to build Spark with "
      +            "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or "
      +            "'build/mvn package' before running this test")
      +    elif len(jars) > 1:
      +        raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please "
      +                         "remove all but one") % (", ".join(jars)))
      +    else:
      +        return jars[0]
      +
      +
      +def search_mqtt_test_jar():
      +    SPARK_HOME = os.environ["SPARK_HOME"]
      +    mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt")
      +    jars = glob.glob(
      +        os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar"))
      +    if not jars:
      +        raise Exception(
      +            ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) +
      +            "You need to build Spark with "
      +            "'build/sbt assembly/assembly streaming-mqtt/test:assembly'")
      +    elif len(jars) > 1:
      +        raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please "
      +                         "remove all but one") % (", ".join(jars)))
      +    else:
      +        return jars[0]
      +
      +
      +def search_kinesis_asl_assembly_jar():
      +    SPARK_HOME = os.environ["SPARK_HOME"]
      +    kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly")
      +    jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly")
      +    if not jars:
      +        return None
      +    elif len(jars) > 1:
      +        raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please "
      +                         "remove all but one") % (", ".join(jars)))
      +    else:
      +        return jars[0]
      +
      +
      +# Must be same as the variable and condition defined in KinesisTestUtils.scala
      +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS"
      +are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1'
      +
       if __name__ == "__main__":
      -    unittest.main()
      +    kafka_assembly_jar = search_kafka_assembly_jar()
      +    flume_assembly_jar = search_flume_assembly_jar()
      +    mqtt_assembly_jar = search_mqtt_assembly_jar()
      +    mqtt_test_jar = search_mqtt_test_jar()
      +    kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar()
      +
      +    if kinesis_asl_assembly_jar is None:
      +        kinesis_jar_present = False
      +        jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar,
      +                                mqtt_test_jar)
      +    else:
      +        kinesis_jar_present = True
      +        jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar,
      +                                   mqtt_test_jar, kinesis_asl_assembly_jar)
      +
      +    os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
      +    testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests,
      +                 KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests]
      +
      +    if kinesis_jar_present is True:
      +        testcases.append(KinesisStreamTests)
      +    elif are_kinesis_tests_enabled is False:
      +        sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was "
      +                         "not compiled into a JAR. To run these tests, "
      +                         "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly "
      +                         "streaming-kinesis-asl-assembly/assembly' or "
      +                         "'build/mvn -Pkinesis-asl package' before running this test.")
      +    else:
      +        raise Exception(
      +            ("Failed to find Spark Streaming Kinesis assembly jar in %s. "
+             % os.path.join(os.environ["SPARK_HOME"], "extras/kinesis-asl-assembly")) +
+            "You need to build Spark with 'build/sbt -Pkinesis-asl "
+            "assembly/assembly streaming-kinesis-asl-assembly/assembly' "
      +            "or 'build/mvn -Pkinesis-asl package' before running this test.")
      +
      +    sys.stderr.write("Running tests: %s \n" % (str(testcases)))
      +    for testcase in testcases:
      +        sys.stderr.write("[Running %s]\n" % (testcase))
      +        tests = unittest.TestLoader().loadTestsFromTestCase(testcase)
      +        unittest.TextTestRunner(verbosity=3).run(tests)
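The StreamingContextTests and CheckpointTests added above revolve around the getActive()/getOrCreate()/getActiveOrCreate() lifecycle; a hedged sketch of the driver-side pattern those tests exercise (the checkpoint path, batch interval, and queueStream input are illustrative):

    # Illustrative recovery pattern; the checkpoint directory is a placeholder.
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    checkpoint_dir = "/tmp/example-streaming-checkpoint"

    def create_context():
        sc = SparkContext(appName="GetActiveOrCreateSketch")
        ssc = StreamingContext(sc, 2)
        ssc.checkpoint(checkpoint_dir)
        ssc.queueStream([[1, 2, 3]]).foreachRDD(lambda rdd: rdd.count())
        return ssc

    # Reuse the active context if one exists, otherwise recover from the
    # checkpoint directory, otherwise call create_context().
    ssc = StreamingContext.getActiveOrCreate(checkpoint_dir, create_context)
    ssc.start()
    ssc.awaitTermination()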
      diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
      index 34291f30a565..b20613b1283b 100644
      --- a/python/pyspark/streaming/util.py
      +++ b/python/pyspark/streaming/util.py
      @@ -37,6 +37,11 @@ def __init__(self, ctx, func, *deserializers):
               self.ctx = ctx
               self.func = func
               self.deserializers = deserializers
      +        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
      +
      +    def rdd_wrapper(self, func):
      +        self._rdd_wrapper = func
      +        return self
       
           def call(self, milliseconds, jrdds):
               try:
      @@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds):
                   if len(sers) < len(jrdds):
                       sers += (sers[0],) * (len(jrdds) - len(sers))
       
      -            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
      +            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
                           for jrdd, ser in zip(jrdds, sers)]
                   t = datetime.fromtimestamp(milliseconds / 1000.0)
                   r = self.func(t, *rdds)
      @@ -125,4 +130,6 @@ def rddToFileName(prefix, suffix, timestamp):
       
       if __name__ == "__main__":
           import doctest
      -    doctest.testmod()
      +    (failure_count, test_count) = doctest.testmod()
      +    if failure_count:
      +        exit(-1)
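The rdd_wrapper hook added here is what lets KafkaDStream hand KafkaRDD instances to user callbacks (see kafka.py earlier in this patch). A minimal sketch with a hypothetical TaggedRDD, only to show how the wrapper replaces the default RDD constructor:

    # TaggedRDD is hypothetical, used purely to illustrate the hook.
    from pyspark import SparkContext
    from pyspark.rdd import RDD
    from pyspark.serializers import UTF8Deserializer
    from pyspark.streaming.util import TransformFunction

    class TaggedRDD(RDD):
        """RDD subclass carrying an extra attribute, analogous to KafkaRDD."""
        def __init__(self, jrdd, ctx, jrdd_deserializer):
            RDD.__init__(self, jrdd, ctx, jrdd_deserializer)
            self.tag = "wrapped"

    def handler(time, rdd):
        # Each Java RDD handed to call() is wrapped via _rdd_wrapper, so rdd is a TaggedRDD.
        print("%s %s" % (time, rdd.tag))

    sc = SparkContext(appName="RddWrapperSketch")
    jfunc = TransformFunction(sc, handler, UTF8Deserializer()) \
        .rdd_wrapper(lambda jrdd, ctx, ser: TaggedRDD(jrdd, ctx, ser))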
      diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
      index 78265423682b..647504c32f15 100644
      --- a/python/pyspark/tests.py
      +++ b/python/pyspark/tests.py
      @@ -218,6 +218,11 @@ def test_namedtuple(self):
               p2 = loads(dumps(p1, 2))
               self.assertEqual(p1, p2)
       
      +        from pyspark.cloudpickle import dumps
      +        P2 = loads(dumps(P))
      +        p3 = P2(1, 3)
      +        self.assertEqual(p1, p3)
      +
           def test_itemgetter(self):
               from operator import itemgetter
               ser = CloudPickleSerializer()
      @@ -529,10 +534,127 @@ def test_deleting_input_files(self):
       
           def test_sampling_default_seed(self):
               # Test for SPARK-3995 (default seed setting)
      -        data = self.sc.parallelize(range(1000), 1)
      +        data = self.sc.parallelize(xrange(1000), 1)
               subset = data.takeSample(False, 10)
               self.assertEqual(len(subset), 10)
       
      +    def test_aggregate_mutable_zero_value(self):
      +        # Test for SPARK-9021; uses aggregate and treeAggregate to build dict
      +        # representing a counter of ints
      +        # NOTE: dict is used instead of collections.Counter for Python 2.6
      +        # compatibility
      +        from collections import defaultdict
      +
      +        # Show that single or multiple partitions work
      +        data1 = self.sc.range(10, numSlices=1)
      +        data2 = self.sc.range(10, numSlices=2)
      +
      +        def seqOp(x, y):
      +            x[y] += 1
      +            return x
      +
      +        def comboOp(x, y):
      +            for key, val in y.items():
      +                x[key] += val
      +            return x
      +
      +        counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
      +        counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
      +        counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
      +        counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
      +
      +        ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
      +        self.assertEqual(counts1, ground_truth)
      +        self.assertEqual(counts2, ground_truth)
      +        self.assertEqual(counts3, ground_truth)
      +        self.assertEqual(counts4, ground_truth)
      +
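The four tests starting here cover SPARK-9021, where a mutable zero value shared across partitions could accumulate state it should not. A stand-alone illustration in plain Python (no Spark API) of why each partition needs its own copy of the zero value:

    # Plain-Python illustration, not Spark code.
    from collections import defaultdict
    from functools import reduce

    partitions = [[1, 2], [2, 3]]

    def seq_op(acc, x):
        acc[x] += 1
        return acc

    def comb_op(a, b):
        for k, v in b.items():
            a[k] += v
        return a

    # One shared mutable zero value: both partitions mutate the same dict,
    # so combining doubles every count: {1: 2, 2: 4, 3: 2}.
    zero = defaultdict(int)
    shared = [reduce(seq_op, part, zero) for part in partitions]
    print(reduce(comb_op, shared))

    # A fresh zero value per partition gives the expected {1: 1, 2: 2, 3: 1}.
    per_partition = [reduce(seq_op, part, defaultdict(int)) for part in partitions]
    print(reduce(comb_op, per_partition, defaultdict(int)))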
      +    def test_aggregate_by_key_mutable_zero_value(self):
      +        # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
      +        # contains lists of all values for each key in the original RDD
      +
      +        # list(range(...)) for Python 3.x compatibility (can't use * operator
      +        # on a range object)
      +        # list(zip(...)) for Python 3.x compatibility (want to parallelize a
      +        # collection, not a zip object)
      +        tuples = list(zip(list(range(10))*2, [1]*20))
      +        # Show that single or multiple partitions work
      +        data1 = self.sc.parallelize(tuples, 1)
      +        data2 = self.sc.parallelize(tuples, 2)
      +
      +        def seqOp(x, y):
      +            x.append(y)
      +            return x
      +
      +        def comboOp(x, y):
      +            x.extend(y)
      +            return x
      +
      +        values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
      +        values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
      +        # Sort lists to ensure clean comparison with ground_truth
      +        values1.sort()
      +        values2.sort()
      +
      +        ground_truth = [(i, [1]*2) for i in range(10)]
      +        self.assertEqual(values1, ground_truth)
      +        self.assertEqual(values2, ground_truth)
      +
      +    def test_fold_mutable_zero_value(self):
      +        # Test for SPARK-9021; uses fold to merge an RDD of dict counters into
      +        # a single dict
      +        # NOTE: dict is used instead of collections.Counter for Python 2.6
      +        # compatibility
      +        from collections import defaultdict
      +
      +        counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
      +        counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
      +        counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
      +        counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
      +        all_counts = [counts1, counts2, counts3, counts4]
      +        # Show that single or multiple partitions work
      +        data1 = self.sc.parallelize(all_counts, 1)
      +        data2 = self.sc.parallelize(all_counts, 2)
      +
      +        def comboOp(x, y):
      +            for key, val in y.items():
      +                x[key] += val
      +            return x
      +
      +        fold1 = data1.fold(defaultdict(int), comboOp)
      +        fold2 = data2.fold(defaultdict(int), comboOp)
      +
      +        ground_truth = defaultdict(int)
      +        for counts in all_counts:
      +            for key, val in counts.items():
      +                ground_truth[key] += val
      +        self.assertEqual(fold1, ground_truth)
      +        self.assertEqual(fold2, ground_truth)
      +
      +    def test_fold_by_key_mutable_zero_value(self):
      +        # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
      +        # lists of all values for each key in the original RDD
      +
      +        tuples = [(i, range(i)) for i in range(10)]*2
      +        # Show that single or multiple partitions work
      +        data1 = self.sc.parallelize(tuples, 1)
      +        data2 = self.sc.parallelize(tuples, 2)
      +
      +        def comboOp(x, y):
      +            x.extend(y)
      +            return x
      +
      +        values1 = data1.foldByKey([], comboOp).collect()
      +        values2 = data2.foldByKey([], comboOp).collect()
      +        # Sort lists to ensure clean comparison with ground_truth
      +        values1.sort()
      +        values2.sort()
      +
      +        # list(range(...)) for Python 3.x compatibility
      +        ground_truth = [(i, list(range(i))*2) for i in range(10)]
      +        self.assertEqual(values1, ground_truth)
      +        self.assertEqual(values2, ground_truth)
      +
           def test_aggregate_by_key(self):
               data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
       
      @@ -624,8 +746,8 @@ def test_zip_with_different_serializers(self):
       
           def test_zip_with_different_object_sizes(self):
               # regress test for SPARK-5973
      -        a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i)
      -        b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i)
      +        a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
      +        b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
               self.assertEqual(10000, a.zip(b).count())
       
           def test_zip_with_different_number_of_items(self):
      @@ -647,7 +769,7 @@ def test_zip_with_different_number_of_items(self):
                   self.assertRaises(Exception, lambda: a.zip(b).count())
       
           def test_count_approx_distinct(self):
      -        rdd = self.sc.parallelize(range(1000))
      +        rdd = self.sc.parallelize(xrange(1000))
               self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
               self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
               self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
      @@ -777,7 +899,7 @@ def test_distinct(self):
           def test_external_group_by_key(self):
               self.sc._conf.set("spark.python.worker.memory", "1m")
               N = 200001
      -        kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
      +        kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
               gkv = kv.groupByKey().cache()
               self.assertEqual(3, gkv.count())
               filtered = gkv.filter(lambda kv: kv[0] == 1)
      @@ -871,7 +993,7 @@ def test_narrow_dependency_in_join(self):
       
           # Regression test for SPARK-6294
           def test_take_on_jrdd(self):
      -        rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x))
      +        rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
               rdd._jrdd.first()
       
           def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
      @@ -885,6 +1007,19 @@ def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
                   for size in sizes:
                       self.assertGreater(size, 0)
       
      +    def test_pipe_functions(self):
      +        data = ['1', '2', '3']
      +        rdd = self.sc.parallelize(data)
      +        with QuietTest(self.sc):
      +            self.assertEqual([], rdd.pipe('cc').collect())
      +            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
      +        result = rdd.pipe('cat').collect()
      +        result.sort()
      +        for x, y in zip(data, result):
      +            self.assertEqual(x, y)
      +        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
      +        self.assertEqual([], rdd.pipe('grep 4').collect())
      +
       
       class ProfilerTests(PySparkTestCase):
       
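
The new test_pipe_functions above pins down the checkCode flag added to RDD.pipe. As a rough
usage sketch (assumes an already-running SparkContext bound to sc; output values taken from
the test expectations):

    rdd = sc.parallelize(['1', '2', '3'])

    rdd.pipe('cat').collect()     # passes elements through the command: ['1', '2', '3']
    rdd.pipe('grep 4').collect()  # no matches; grep's exit code 1 is ignored -> []

    # With checkCode=True, a non-zero exit status from the piped command is
    # surfaced as a job failure instead of being swallowed.
    rdd.pipe('grep 4', checkCode=True).collect()  # raises Py4JJavaError
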
      @@ -1421,7 +1556,8 @@ def do_termination_test(self, terminator):
       
               # start daemon
               daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
      -        daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
      +        python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
      +        daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
       
               # read the port number
               port = read_int(daemon.stdout)
      @@ -1503,13 +1639,13 @@ def run():
                   self.fail("daemon had been killed")
       
               # run a normal job
      -        rdd = self.sc.parallelize(range(100), 1)
      +        rdd = self.sc.parallelize(xrange(100), 1)
               self.assertEqual(100, rdd.map(str).count())
       
           def test_after_exception(self):
               def raise_exception(_):
                   raise Exception()
      -        rdd = self.sc.parallelize(range(100), 1)
      +        rdd = self.sc.parallelize(xrange(100), 1)
               with QuietTest(self.sc):
                   self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
               self.assertEqual(100, rdd.map(str).count())
      @@ -1525,22 +1661,22 @@ def test_after_jvm_exception(self):
               with QuietTest(self.sc):
                   self.assertRaises(Exception, lambda: filtered_data.count())
       
      -        rdd = self.sc.parallelize(range(100), 1)
      +        rdd = self.sc.parallelize(xrange(100), 1)
               self.assertEqual(100, rdd.map(str).count())
       
           def test_accumulator_when_reuse_worker(self):
               from pyspark.accumulators import INT_ACCUMULATOR_PARAM
               acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
      -        self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
      +        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
               self.assertEqual(sum(range(100)), acc1.value)
       
               acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
      -        self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
      +        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
               self.assertEqual(sum(range(100)), acc2.value)
               self.assertEqual(sum(range(100)), acc1.value)
       
           def test_reuse_worker_after_take(self):
      -        rdd = self.sc.parallelize(range(100000), 1)
      +        rdd = self.sc.parallelize(xrange(100000), 1)
               self.assertEqual(0, rdd.first())
       
               def count():
      @@ -1692,7 +1828,7 @@ def test_module_dependency_on_cluster(self):
                   |    return x + 1
                   """)
               proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master",
      -                                "local-cluster[1,1,512]", script],
      +                                "local-cluster[1,1,1024]", script],
                                       stdout=subprocess.PIPE)
               out, err = proc.communicate()
               self.assertEqual(0, proc.returncode)
      @@ -1726,7 +1862,7 @@ def test_package_dependency_on_cluster(self):
               self.create_spark_package("a:mylib:0.1")
               proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
                                        "file:" + self.programDir, "--master",
      -                                 "local-cluster[1,1,512]", script], stdout=subprocess.PIPE)
      +                                 "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE)
               out, err = proc.communicate()
               self.assertEqual(0, proc.returncode)
               self.assertIn("[2, 3, 4]", out.decode('utf-8'))
      @@ -1745,7 +1881,7 @@ def test_single_script_on_cluster(self):
               # this will fail if you have different spark.executor.memory
               # in conf/spark-defaults.conf
               proc = subprocess.Popen(
      -            [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
      +            [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script],
                   stdout=subprocess.PIPE)
               out, err = proc.communicate()
               self.assertEqual(0, proc.returncode)
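
Several of the tests added above (test_aggregate_mutable_zero_value and friends) guard
SPARK-9021: aggregate, treeAggregate, fold, aggregateByKey and foldByKey must hand each
partition its own copy of a mutable zero value instead of sharing one object. A Spark-free
sketch of why that matters; the seqOp/comboOp mirror the tests, and copy.deepcopy stands in
for the fresh per-partition copy of the zero value:

    from collections import defaultdict
    from functools import reduce
    import copy

    partitions = [[0, 1, 2], [2, 3, 3]]

    def seqOp(acc, value):
        acc[value] += 1
        return acc

    def comboOp(left, right):
        for key, val in right.items():
            left[key] += val
        return left

    zero = defaultdict(int)
    # Correct behaviour: every partition folds into a *fresh* copy of the zero
    # value. If a single shared dict were mutated by all partitions, the final
    # merge would double-count.
    per_partition = [reduce(seqOp, part, copy.deepcopy(zero)) for part in partitions]
    result = reduce(comboOp, per_partition, copy.deepcopy(zero))
    print(dict(result))  # {0: 1, 1: 1, 2: 2, 3: 2}
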
      diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
      index 93df9002be37..42c2f8b75933 100644
      --- a/python/pyspark/worker.py
      +++ b/python/pyspark/worker.py
      @@ -146,5 +146,5 @@ def process():
           java_port = int(sys.stdin.readline())
           sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
           sock.connect(("127.0.0.1", java_port))
      -    sock_file = sock.makefile("a+", 65536)
      +    sock_file = sock.makefile("rwb", 65536)
           main(sock_file, sock_file)
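
The worker.py change swaps the socket file mode from "a+" to "rwb" because Python 3's
socket.makefile() only accepts combinations of 'r', 'w' and 'b'; a buffered binary
read/write object is what the worker needs on both interpreter lines. A small standalone
illustration (POSIX only, using a socketpair in place of the real JVM connection):

    import socket

    s1, s2 = socket.socketpair()
    # "a+" raises ValueError on Python 3; "rwb" works on Python 2 and 3 and yields
    # a buffered binary file object usable for both reads and writes.
    f1 = s1.makefile("rwb", 65536)
    f2 = s2.makefile("rwb", 65536)

    f1.write(b"ping")
    f1.flush()
    print(f2.read(4))  # b'ping' on Python 3

    f1.close()
    f2.close()
    s1.close()
    s2.close()
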
      diff --git a/python/run-tests b/python/run-tests
      index 4468fdb3f267..24949657ed7a 100755
      --- a/python/run-tests
      +++ b/python/run-tests
      @@ -18,165 +18,7 @@
       #
       
       
      -# Figure out where the Spark framework is installed
      -FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)"
      +FWDIR="$(cd "`dirname $0`"/..; pwd)"
      +cd "$FWDIR"
       
      -. "$FWDIR"/bin/load-spark-env.sh
      -
      -# CD into the python directory to find things on the right path
      -cd "$FWDIR/python"
      -
      -FAILED=0
      -LOG_FILE=unit-tests.log
      -START=$(date +"%s")
      -
      -rm -f $LOG_FILE
      -
      -# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL
      -rm -rf metastore warehouse
      -
      -function run_test() {
      -    echo -en "Running test: $1 ... " | tee -a $LOG_FILE
      -    start=$(date +"%s")
      -    SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1
      -
      -    FAILED=$((PIPESTATUS[0]||$FAILED))
      -
      -    # Fail and exit on the first test failure.
      -    if [[ $FAILED != 0 ]]; then
      -        cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number.
      -        echo -en "\033[31m"  # Red
      -        echo "Had test failures; see logs."
      -        echo -en "\033[0m"  # No color
      -        exit -1
      -    else
      -        now=$(date +"%s")
      -        echo "ok ($(($now - $start))s)"
      -    fi
      -}
      -
      -function run_core_tests() {
      -    echo "Run core tests ..."
      -    run_test "pyspark.rdd"
      -    run_test "pyspark.context"
      -    run_test "pyspark.conf"
      -    run_test "pyspark.broadcast"
      -    run_test "pyspark.accumulators"
      -    run_test "pyspark.serializers"
      -    run_test "pyspark.profiler"
      -    run_test "pyspark.shuffle"
      -    run_test "pyspark.tests"
      -}
      -
      -function run_sql_tests() {
      -    echo "Run sql tests ..."
      -    run_test "pyspark.sql.types"
      -    run_test "pyspark.sql.context"
      -    run_test "pyspark.sql.column"
      -    run_test "pyspark.sql.dataframe"
      -    run_test "pyspark.sql.group"
      -    run_test "pyspark.sql.functions"
      -    run_test "pyspark.sql.readwriter"
      -    run_test "pyspark.sql.window"
      -    run_test "pyspark.sql.tests"
      -}
      -
      -function run_mllib_tests() {
      -    echo "Run mllib tests ..."
      -    run_test "pyspark.mllib.classification"
      -    run_test "pyspark.mllib.clustering"
      -    run_test "pyspark.mllib.evaluation"
      -    run_test "pyspark.mllib.feature"
      -    run_test "pyspark.mllib.fpm"
      -    run_test "pyspark.mllib.linalg"
      -    run_test "pyspark.mllib.random"
      -    run_test "pyspark.mllib.recommendation"
      -    run_test "pyspark.mllib.regression"
      -    run_test "pyspark.mllib.stat._statistics"
      -    run_test "pyspark.mllib.stat.KernelDensity"
      -    run_test "pyspark.mllib.tree"
      -    run_test "pyspark.mllib.util"
      -    run_test "pyspark.mllib.tests"
      -}
      -
      -function run_ml_tests() {
      -    echo "Run ml tests ..."
      -    run_test "pyspark.ml.feature"
      -    run_test "pyspark.ml.classification"
      -    run_test "pyspark.ml.recommendation"
      -    run_test "pyspark.ml.regression"
      -    run_test "pyspark.ml.tuning"
      -    run_test "pyspark.ml.tests"
      -    run_test "pyspark.ml.evaluation"
      -}
      -
      -function run_streaming_tests() {
      -    echo "Run streaming tests ..."
      -
      -    KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly
      -    JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}"
      -    for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do
      -      if [[ ! -e "$f" ]]; then
      -        echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2
      -        echo "You need to build Spark with " \
      -             "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \
      -             "'build/mvn package' before running this program" 1>&2
      -        exit 1
      -      fi
      -      KAFKA_ASSEMBLY_JAR="$f"
      -    done
      -
      -    export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
      -    run_test "pyspark.streaming.util"
      -    run_test "pyspark.streaming.tests"
      -}
      -
      -echo "Running PySpark tests. Output is in python/$LOG_FILE."
      -
      -export PYSPARK_PYTHON="python"
      -
      -# Try to test with Python 2.6, since that's the minimum version that we support:
      -if [ $(which python2.6) ]; then
      -    export PYSPARK_PYTHON="python2.6"
      -fi
      -
      -echo "Testing with Python version:"
      -$PYSPARK_PYTHON --version
      -
      -run_core_tests
      -run_sql_tests
      -run_mllib_tests
      -run_ml_tests
      -run_streaming_tests
      -
      -# Try to test with Python 3
      -if [ $(which python3.4) ]; then
      -    export PYSPARK_PYTHON="python3.4"
      -    echo "Testing with Python3.4 version:"
      -    $PYSPARK_PYTHON --version
      -
      -    run_core_tests
      -    run_sql_tests
      -    run_mllib_tests
      -    run_ml_tests
      -    run_streaming_tests
      -fi
      -
      -# Try to test with PyPy
      -if [ $(which pypy) ]; then
      -    export PYSPARK_PYTHON="pypy"
      -    echo "Testing with PyPy version:"
      -    $PYSPARK_PYTHON --version
      -
      -    run_core_tests
      -    run_sql_tests
      -    run_streaming_tests
      -fi
      -
      -if [[ $FAILED == 0 ]]; then
      -    now=$(date +"%s")
      -    echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds"
      -fi
      -
      -# TODO: in the long-run, it would be nice to use a test runner like `nose`.
      -# The doctest fixtures are the current barrier to doing this.
      +exec python -u ./python/run-tests.py "$@"

      diff --git a/python/run-tests.py b/python/run-tests.py
      new file mode 100755
      index 000000000000..fd56c7ab6e0e
      --- /dev/null
      +++ b/python/run-tests.py
      @@ -0,0 +1,214 @@
      +#!/usr/bin/env python
      +
      +#
      +# Licensed to the Apache Software Foundation (ASF) under one or more
      +# contributor license agreements.  See the NOTICE file distributed with
      +# this work for additional information regarding copyright ownership.
      +# The ASF licenses this file to You under the Apache License, Version 2.0
      +# (the "License"); you may not use this file except in compliance with
      +# the License.  You may obtain a copy of the License at
      +#
      +#    http://www.apache.org/licenses/LICENSE-2.0
      +#
      +# Unless required by applicable law or agreed to in writing, software
      +# distributed under the License is distributed on an "AS IS" BASIS,
      +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      +# See the License for the specific language governing permissions and
      +# limitations under the License.
      +#
      +
      +from __future__ import print_function
      +import logging
      +from optparse import OptionParser
      +import os
      +import re
      +import subprocess
      +import sys
      +import tempfile
      +from threading import Thread, Lock
      +import time
      +if sys.version < '3':
      +    import Queue
      +else:
      +    import queue as Queue
      +if sys.version_info >= (2, 7):
      +    subprocess_check_output = subprocess.check_output
      +else:
      +    # SPARK-8763
      +    # backported from subprocess module in Python 2.7
      +    def subprocess_check_output(*popenargs, **kwargs):
      +        if 'stdout' in kwargs:
      +            raise ValueError('stdout argument not allowed, it will be overridden.')
      +        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
      +        output, unused_err = process.communicate()
      +        retcode = process.poll()
      +        if retcode:
      +            cmd = kwargs.get("args")
      +            if cmd is None:
      +                cmd = popenargs[0]
      +            raise subprocess.CalledProcessError(retcode, cmd, output=output)
      +        return output
      +
      +
      +# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
      +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/"))
      +
      +
      +from sparktestsupport import SPARK_HOME  # noqa (suppress pep8 warnings)
      +from sparktestsupport.shellutils import which  # noqa
      +from sparktestsupport.modules import all_modules  # noqa
      +
      +
      +python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root')
      +
      +
      +def print_red(text):
      +    print('\033[31m' + text + '\033[0m')
      +
      +
      +LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
      +FAILURE_REPORTING_LOCK = Lock()
      +LOGGER = logging.getLogger()
      +
      +
      +def run_individual_python_test(test_name, pyspark_python):
      +    env = dict(os.environ)
      +    env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)})
      +    LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
      +    start_time = time.time()
      +    try:
      +        per_test_output = tempfile.TemporaryFile()
      +        retcode = subprocess.Popen(
      +            [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
      +            stderr=per_test_output, stdout=per_test_output, env=env).wait()
      +    except:
      +        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
      +        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
      +        # this code is invoked from a thread other than the main thread.
      +        os._exit(1)
      +    duration = time.time() - start_time
      +    # Exit on the first failure.
      +    if retcode != 0:
      +        try:
      +            with FAILURE_REPORTING_LOCK:
      +                with open(LOG_FILE, 'ab') as log_file:
      +                    per_test_output.seek(0)
      +                    log_file.writelines(per_test_output)
      +                per_test_output.seek(0)
      +                for line in per_test_output:
      +                    decoded_line = line.decode()
      +                    if not re.match('[0-9]+', decoded_line):
      +                        print(decoded_line, end='')
      +                per_test_output.close()
      +        except:
      +            LOGGER.exception("Got an exception while trying to print failed test output")
      +        finally:
      +            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
      +            # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
      +            # this code is invoked from a thread other than the main thread.
      +            os._exit(-1)
      +    else:
      +        per_test_output.close()
      +        LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
      +
      +
      +def get_default_python_executables():
      +    python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)]
      +    if "python2.6" not in python_execs:
      +        LOGGER.warning("Not testing against `python2.6` because it could not be found; falling"
      +                       " back to `python` instead")
      +        python_execs.insert(0, "python")
      +    return python_execs
      +
      +
      +def parse_opts():
      +    parser = OptionParser(
      +        prog="run-tests"
      +    )
      +    parser.add_option(
      +        "--python-executables", type="string", default=','.join(get_default_python_executables()),
      +        help="A comma-separated list of Python executables to test against (default: %default)"
      +    )
      +    parser.add_option(
      +        "--modules", type="string",
      +        default=",".join(sorted(python_modules.keys())),
      +        help="A comma-separated list of Python modules to test (default: %default)"
      +    )
      +    parser.add_option(
      +        "-p", "--parallelism", type="int", default=4,
      +        help="The number of suites to test in parallel (default %default)"
      +    )
      +    parser.add_option(
      +        "--verbose", action="store_true",
      +        help="Enable additional debug logging"
      +    )
      +
      +    (opts, args) = parser.parse_args()
      +    if args:
      +        parser.error("Unsupported arguments: %s" % ' '.join(args))
      +    if opts.parallelism < 1:
      +        parser.error("Parallelism cannot be less than 1")
      +    return opts
      +
      +
      +def main():
      +    opts = parse_opts()
      +    if (opts.verbose):
      +        log_level = logging.DEBUG
      +    else:
      +        log_level = logging.INFO
      +    logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
      +    LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
      +    if os.path.exists(LOG_FILE):
      +        os.remove(LOG_FILE)
      +    python_execs = opts.python_executables.split(',')
      +    modules_to_test = []
      +    for module_name in opts.modules.split(','):
      +        if module_name in python_modules:
      +            modules_to_test.append(python_modules[module_name])
      +        else:
      +            print("Error: unrecognized module %s" % module_name)
      +            sys.exit(-1)
      +    LOGGER.info("Will test against the following Python executables: %s", python_execs)
      +    LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
      +
      +    task_queue = Queue.Queue()
      +    for python_exec in python_execs:
      +        python_implementation = subprocess_check_output(
      +            [python_exec, "-c", "import platform; print(platform.python_implementation())"],
      +            universal_newlines=True).strip()
      +        LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
      +        LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output(
      +            [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
      +        for module in modules_to_test:
      +            if python_implementation not in module.blacklisted_python_implementations:
      +                for test_goal in module.python_test_goals:
      +                    task_queue.put((python_exec, test_goal))
      +
      +    def process_queue(task_queue):
      +        while True:
      +            try:
      +                (python_exec, test_goal) = task_queue.get_nowait()
      +            except Queue.Empty:
      +                break
      +            try:
      +                run_individual_python_test(test_goal, python_exec)
      +            finally:
      +                task_queue.task_done()
      +
      +    start_time = time.time()
      +    for _ in range(opts.parallelism):
      +        worker = Thread(target=process_queue, args=(task_queue,))
      +        worker.daemon = True
      +        worker.start()
      +    try:
      +        task_queue.join()
      +    except (KeyboardInterrupt, SystemExit):
      +        print_red("Exiting due to interrupt")
      +        sys.exit(-1)
      +    total_duration = time.time() - start_time
      +    LOGGER.info("Tests passed in %i seconds", total_duration)
      +
      +
      +if __name__ == "__main__":
      +    main()
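
At its core the new runner is a small thread pool draining a queue of
(python executable, test module) pairs, on top of the option parsing shown above
(--python-executables, --modules, --parallelism, --verbose). A stripped-down sketch of the
same producer/consumer pattern, with the actual bin/pyspark invocation replaced by a print:

    from threading import Thread
    try:
        import Queue           # Python 2
    except ImportError:
        import queue as Queue  # Python 3

    task_queue = Queue.Queue()
    for python_exec in ["python2.6", "python3.4"]:
        for test_goal in ["pyspark.rdd", "pyspark.conf"]:
            task_queue.put((python_exec, test_goal))

    def process_queue(q):
        while True:
            try:
                python_exec, test_goal = q.get_nowait()
            except Queue.Empty:
                break
            try:
                # The real script spawns `bin/pyspark <test_goal>` with
                # PYSPARK_PYTHON pointing at python_exec and checks the exit code.
                print("would run %s under %s" % (test_goal, python_exec))
            finally:
                q.task_done()

    for _ in range(4):
        worker = Thread(target=process_queue, args=(task_queue,))
        worker.daemon = True
        worker.start()
    task_queue.join()
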
      diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc
      new file mode 100644
      index 000000000000..3b7b044936a8
      Binary files /dev/null and b/python/test_support/sql/orc_partitioned/._SUCCESS.crc differ
      diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/python/test_support/sql/orc_partitioned/_SUCCESS
      old mode 100644
      new mode 100755
      similarity index 100%
      rename from sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964
      rename to python/test_support/sql/orc_partitioned/_SUCCESS
      diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc b/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc
      new file mode 100644
      index 000000000000..834cf0b7f227
      Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc differ
      diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc b/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc
      new file mode 100755
      index 000000000000..494380187335
      Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc differ
      diff --git a/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc b/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc
      new file mode 100644
      index 000000000000..693dceeee3ef
      Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc differ
      diff --git a/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc b/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc
      new file mode 100755
      index 000000000000..4cbb95ae0242
      Binary files /dev/null and b/python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc differ
      diff --git a/repl/pom.xml b/repl/pom.xml
      index 85f7bc8ac102..5cf416a4a544 100644
      --- a/repl/pom.xml
      +++ b/repl/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.5.0-SNAPSHOT</version>
+    <version>1.6.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
       
@@ -38,11 +38,6 @@
   </properties>
 
   <dependencies>
-    <dependency>
-      <groupId>${jline.groupid}</groupId>
-      <artifactId>jline</artifactId>
-      <version>${jline.version}</version>
-    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
@@ -93,7 +88,7 @@
     </dependency>
     <dependency>
       <groupId>org.mockito</groupId>
-      <artifactId>mockito-all</artifactId>
+      <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
 
@@ -138,7 +133,6 @@
             </goals>
             <configuration>
               <sources>
-                <source>src/main/scala</source>
                 <source>${extra.source.dir}</source>
               </sources>
             </configuration>
@@ -151,7 +145,6 @@
             </goals>
             <configuration>
               <sources>
-                <source>src/test/scala</source>
                 <source>${extra.testsource.dir}</source>
               </sources>
             </configuration>
@@ -161,6 +154,20 @@
     </plugins>
   </build>
   <profiles>
+    <profile>
+      <id>scala-2.10</id>
+      <activation>
+        <property><name>!scala-2.11</name></property>
+      </activation>
+      <dependencies>
+        <dependency>
+          <groupId>${jline.groupid}</groupId>
+          <artifactId>jline</artifactId>
+          <version>${jline.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
+
     <profile>
       <id>scala-2.11</id>
       <activation>
      diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
      index 6480e2d24e04..24fbbc12c08d 100644
      --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
      +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
      @@ -39,6 +39,8 @@ class SparkCommandLine(args: List[String], override val settings: Settings)
         }
       
         def this(args: List[String]) {
      +    // scalastyle:off println
           this(args, str => Console.println("Error: " + str))
      +    // scalastyle:on println
         }
       }
      diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
      index 2b235525250c..304b1e8cdbed 100644
      --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
      +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
      @@ -981,7 +981,7 @@ class SparkILoop(
           // which spins off a separate thread, then print the prompt and try
           // our best to look ready.  The interlocking lazy vals tend to
           // inter-deadlock, so we break the cycle with a single asynchronous
      -    // message to an actor.
      +    // message to an rpcEndpoint.
           if (isAsync) {
             intp initialize initializedCallback()
             createAsyncListener() // listens for signal to run postInitialization
      @@ -1008,9 +1008,9 @@ class SparkILoop(
           val jars = SparkILoop.getAddedJars
           val conf = new SparkConf()
             .setMaster(getMaster())
      -      .setAppName("Spark shell")
             .setJars(jars)
             .set("spark.repl.class.uri", intp.classServerUri)
      +      .setIfMissing("spark.app.name", "Spark shell")
           if (execUri != null) {
             conf.set("spark.executor.uri", execUri)
           }
      @@ -1101,7 +1101,9 @@ object SparkILoop extends Logging {
                   val s = super.readLine()
                   // helping out by printing the line being interpreted.
                   if (s != null)
      +              // scalastyle:off println
                     output.println(s)
      +              // scalastyle:on println
                   s
                 }
               }
      diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
      index 05faef8786d2..bd3314d94eed 100644
      --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
      +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
      @@ -80,11 +80,13 @@ private[repl] trait SparkILoopInit {
           if (!initIsComplete)
             withLock { while (!initIsComplete) initLoopCondition.await() }
           if (initError != null) {
      +      // scalastyle:off println
             println("""
               |Failed to initialize the REPL due to an unexpected error.
               |This is a bug, please, report it along with the error diagnostics printed below.
               |%s.""".stripMargin.format(initError)
             )
      +      // scalastyle:on println
             false
           } else true
         }
      diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
      index 35fb62564502..4ee605fd7f11 100644
      --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
      +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
      @@ -1079,8 +1079,10 @@ import org.apache.spark.annotation.DeveloperApi
             throw new EvalException("Failed to load '" + path + "': " + ex.getMessage, ex)
       
           private def load(path: String): Class[_] = {
      +      // scalastyle:off classforname
             try Class.forName(path, true, classLoader)
             catch { case ex: Throwable => evalError(path, unwrap(ex)) }
      +      // scalastyle:on classforname
           }
       
           lazy val evalClass = load(evalPath)
      @@ -1761,7 +1763,9 @@ object SparkIMain {
               if (intp.totalSilence) ()
               else super.printMessage(msg)
             }
      +      // scalastyle:off println
             else Console.println(msg)
      +      // scalastyle:on println
           }
         }
       }
      diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
      index f150fec7db94..5674dcd669be 100644
      --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
      +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
      @@ -211,7 +211,7 @@ class ReplSuite extends SparkFunSuite {
         }
       
         test("local-cluster mode") {
      -    val output = runInterpreter("local-cluster[1,1,512]",
      +    val output = runInterpreter("local-cluster[1,1,1024]",
             """
               |var v = 7
               |def getV() = v
      @@ -233,7 +233,7 @@ class ReplSuite extends SparkFunSuite {
         }
       
         test("SPARK-1199 two instances of same class don't type check.") {
      -    val output = runInterpreter("local-cluster[1,1,512]",
      +    val output = runInterpreter("local-cluster[1,1,1024]",
             """
               |case class Sum(exp: String, exp2: String)
               |val a = Sum("A", "B")
      @@ -256,7 +256,7 @@ class ReplSuite extends SparkFunSuite {
       
         test("SPARK-2576 importing SQLContext.implicits._") {
           // We need to use local-cluster to test this case.
      -    val output = runInterpreter("local-cluster[1,1,512]",
      +    val output = runInterpreter("local-cluster[1,1,1024]",
             """
               |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
               |import sqlContext.implicits._
      @@ -325,9 +325,9 @@ class ReplSuite extends SparkFunSuite {
           assertDoesNotContain("Exception", output)
           assertContains("ret: Array[Foo] = Array(Foo(1),", output)
         }
      -  
      +
         test("collecting objects of class defined in repl - shuffling") {
      -    val output = runInterpreter("local-cluster[1,1,512]",
      +    val output = runInterpreter("local-cluster[1,1,1024]",
             """
               |case class Foo(i: Int)
               |val list = List((1, Foo(1)), (1, Foo(2)))
      diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
      index f4f4b626988e..627148df80c1 100644
      --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
      +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
      @@ -17,13 +17,14 @@
       
       package org.apache.spark.repl
       
      +import java.io.File
      +
      +import scala.tools.nsc.Settings
      +
       import org.apache.spark.util.Utils
       import org.apache.spark._
       import org.apache.spark.sql.SQLContext
       
      -import scala.tools.nsc.Settings
      -import scala.tools.nsc.interpreter.SparkILoop
      -
       object Main extends Logging {
       
         val conf = new SparkConf()
      @@ -32,8 +33,10 @@ object Main extends Logging {
         val outputDir = Utils.createTempDir(rootDir)
         val s = new Settings()
         s.processArguments(List("-Yrepl-class-based",
      -    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true)
      -  val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
      +    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
      +    "-classpath", getAddedJars.mkString(File.pathSeparator)), true)
      +  // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed
      +  lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
         var sparkContext: SparkContext = _
         var sqlContext: SQLContext = _
         var interp = new SparkILoop // this is a public var because tests reset it.
      @@ -48,7 +51,6 @@ object Main extends Logging {
           Option(sparkContext).map(_.stop)
         }
       
      -
         def getAddedJars: Array[String] = {
           val envJars = sys.env.get("ADD_JARS")
           if (envJars.isDefined) {
      @@ -64,9 +66,9 @@ object Main extends Logging {
           val jars = getAddedJars
           val conf = new SparkConf()
             .setMaster(getMaster)
      -      .setAppName("Spark shell")
             .setJars(jars)
             .set("spark.repl.class.uri", classServer.uri)
      +      .setIfMissing("spark.app.name", "Spark shell")
           logInfo("Spark class server started at " + classServer.uri)
           if (execUri != null) {
             conf.set("spark.executor.uri", execUri)
      @@ -84,10 +86,9 @@ object Main extends Logging {
           val loader = Utils.getContextOrSparkClassLoader
           try {
             sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext])
      -        .newInstance(sparkContext).asInstanceOf[SQLContext] 
      +        .newInstance(sparkContext).asInstanceOf[SQLContext]
             logInfo("Created sql context (with Hive support)..")
      -    }
      -    catch {
      +    } catch {
             case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError =>
               sqlContext = new SQLContext(sparkContext)
               logInfo("Created sql context..")
      diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
      deleted file mode 100644
      index 8e519fa67f64..000000000000
      --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
      +++ /dev/null
      @@ -1,86 +0,0 @@
      -/* NSC -- new Scala compiler
      - * Copyright 2005-2013 LAMP/EPFL
      - * @author  Paul Phillips
      - */
      -
      -package scala.tools.nsc
      -package interpreter
      -
      -import scala.tools.nsc.ast.parser.Tokens.EOF
      -
      -trait SparkExprTyper {
      -  val repl: SparkIMain
      -
      -  import repl._
      -  import global.{ reporter => _, Import => _, _ }
      -  import naming.freshInternalVarName
      -
      -  def symbolOfLine(code: String): Symbol = {
      -    def asExpr(): Symbol = {
      -      val name  = freshInternalVarName()
      -      // Typing it with a lazy val would give us the right type, but runs
      -      // into compiler bugs with things like existentials, so we compile it
      -      // behind a def and strip the NullaryMethodType which wraps the expr.
      -      val line = "def " + name + " = " + code
      -
      -      interpretSynthetic(line) match {
      -        case IR.Success =>
      -          val sym0 = symbolOfTerm(name)
      -          // drop NullaryMethodType
      -          sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
      -        case _          => NoSymbol
      -      }
      -    }
      -    def asDefn(): Symbol = {
      -      val old = repl.definedSymbolList.toSet
      -
      -      interpretSynthetic(code) match {
      -        case IR.Success =>
      -          repl.definedSymbolList filterNot old match {
      -            case Nil        => NoSymbol
      -            case sym :: Nil => sym
      -            case syms       => NoSymbol.newOverloaded(NoPrefix, syms)
      -          }
      -        case _ => NoSymbol
      -      }
      -    }
      -    def asError(): Symbol = {
      -      interpretSynthetic(code)
      -      NoSymbol
      -    }
      -    beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
      -  }
      -
      -  private var typeOfExpressionDepth = 0
      -  def typeOfExpression(expr: String, silent: Boolean = true): Type = {
      -    if (typeOfExpressionDepth > 2) {
      -      repldbg("Terminating typeOfExpression recursion for expression: " + expr)
      -      return NoType
      -    }
      -    typeOfExpressionDepth += 1
      -    // Don't presently have a good way to suppress undesirable success output
      -    // while letting errors through, so it is first trying it silently: if there
      -    // is an error, and errors are desired, then it re-evaluates non-silently
      -    // to induce the error message.
      -    try beSilentDuring(symbolOfLine(expr).tpe) match {
      -      case NoType if !silent => symbolOfLine(expr).tpe // generate error
      -      case tpe               => tpe
      -    }
      -    finally typeOfExpressionDepth -= 1
      -  }
      -
      -  // This only works for proper types.
      -  def typeOfTypeString(typeString: String): Type = {
      -    def asProperType(): Option[Type] = {
      -      val name = freshInternalVarName()
      -      val line = "def %s: %s = ???" format (name, typeString)
      -      interpretSynthetic(line) match {
      -        case IR.Success =>
      -          val sym0 = symbolOfTerm(name)
      -          Some(sym0.asMethod.returnType)
      -        case _          => None
      -      }
      -    }
      -    beSilentDuring(asProperType()) getOrElse NoType
      -  }
      -}
      diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
      index 7a5e94da5cbf..33d262558b1f 100644
      --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
      +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
      @@ -1,88 +1,64 @@
      -/* NSC -- new Scala compiler
      - * Copyright 2005-2013 LAMP/EPFL
      - * @author Alexander Spoon
      +/*
      + * Licensed to the Apache Software Foundation (ASF) under one or more
      + * contributor license agreements.  See the NOTICE file distributed with
      + * this work for additional information regarding copyright ownership.
      + * The ASF licenses this file to You under the Apache License, Version 2.0
      + * (the "License"); you may not use this file except in compliance with
      + * the License.  You may obtain a copy of the License at
      + *
      + *    http://www.apache.org/licenses/LICENSE-2.0
      + *
      + * Unless required by applicable law or agreed to in writing, software
      + * distributed under the License is distributed on an "AS IS" BASIS,
      + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      + * See the License for the specific language governing permissions and
      + * limitations under the License.
        */
       
      -package scala
      -package tools.nsc
      -package interpreter
      +package org.apache.spark.repl
       
      -import scala.language.{ implicitConversions, existentials }
      -import scala.annotation.tailrec
      -import Predef.{ println => _, _ }
      -import interpreter.session._
      -import StdReplTags._
      -import scala.reflect.api.{Mirror, Universe, TypeCreator}
      -import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName }
      -import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
      -import scala.reflect.{ClassTag, classTag}
      -import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader }
      -import ScalaClassLoader._
      -import scala.reflect.io.{ File, Directory }
      -import scala.tools.util._
      -import scala.collection.generic.Clearable
      -import scala.concurrent.{ ExecutionContext, Await, Future, future }
      -import ExecutionContext.Implicits._
      -import java.io.{ BufferedReader, FileReader }
      +import java.io.{BufferedReader, FileReader}
       
      -/** The Scala interactive shell.  It provides a read-eval-print loop
      -  *  around the Interpreter class.
      -  *  After instantiation, clients should call the main() method.
      -  *
      -  *  If no in0 is specified, then input will come from the console, and
      -  *  the class will attempt to provide input editing feature such as
      -  *  input history.
      -  *
      -  *  @author Moez A. Abdel-Gawad
      -  *  @author  Lex Spoon
      -  *  @version 1.2
      -  */
      -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
      -  extends AnyRef
      -  with LoopCommands
      -{
      -  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
      -  def this() = this(None, new JPrintWriter(Console.out, true))
      -//
      -//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
      -//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i
      -
      -  var in: InteractiveReader = _   // the input stream from which commands come
      -  var settings: Settings = _
      -  var intp: SparkIMain = _
      +import Predef.{println => _, _}
      +import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName}
       
      -  var globalFuture: Future[Boolean] = _
      +import scala.tools.nsc.interpreter.{JPrintWriter, ILoop}
      +import scala.tools.nsc.Settings
      +import scala.tools.nsc.util.stringFromStream
       
      -  protected def asyncMessage(msg: String) {
      -    if (isReplInfo || isReplPower)
      -      echoAndRefresh(msg)
      -  }
      +/**
      + *  A Spark-specific interactive shell.
      + */
      +class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
      +    extends ILoop(in0, out) {
      +  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
      +  def this() = this(None, new JPrintWriter(Console.out, true))
       
         def initializeSpark() {
           intp.beQuietDuring {
      -      command( """
      +      processLine("""
                @transient val sc = {
                  val _sc = org.apache.spark.repl.Main.createSparkContext()
                  println("Spark context available as sc.")
                  _sc
                }
               """)
      -      command( """
      +      processLine("""
                @transient val sqlContext = {
                  val _sqlContext = org.apache.spark.repl.Main.createSQLContext()
                  println("SQL context available as sqlContext.")
                  _sqlContext
                }
               """)
      -      command("import org.apache.spark.SparkContext._")
      -      command("import sqlContext.implicits._")
      -      command("import sqlContext.sql")
      -      command("import org.apache.spark.sql.functions._")
      +      processLine("import org.apache.spark.SparkContext._")
      +      processLine("import sqlContext.implicits._")
      +      processLine("import sqlContext.sql")
      +      processLine("import org.apache.spark.sql.functions._")
           }
         }
       
         /** Print a welcome message */
      -  def printWelcome() {
      +  override def printWelcome() {
           import org.apache.spark.SPARK_VERSION
           echo("""Welcome to
             ____              __
      @@ -98,875 +74,42 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
           echo("Type :help for more information.")
         }
       
      -  override def echoCommandMessage(msg: String) {
      -    intp.reporter printUntruncatedMessage msg
      -  }
      -
      -  // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
      -  def history = in.history
      -
      -  // classpath entries added via :cp
      -  var addedClasspath: String = ""
      -
      -  /** A reverse list of commands to replay if the user requests a :replay */
      -  var replayCommandStack: List[String] = Nil
      -
      -  /** A list of commands to replay if the user requests a :replay */
      -  def replayCommands = replayCommandStack.reverse
      -
      -  /** Record a command for replay should the user request a :replay */
      -  def addReplay(cmd: String) = replayCommandStack ::= cmd
      -
      -  def savingReplayStack[T](body: => T): T = {
      -    val saved = replayCommandStack
      -    try body
      -    finally replayCommandStack = saved
      -  }
      -  def savingReader[T](body: => T): T = {
      -    val saved = in
      -    try body
      -    finally in = saved
      -  }
      -
      -  /** Close the interpreter and set the var to null. */
      -  def closeInterpreter() {
      -    if (intp ne null) {
      -      intp.close()
      -      intp = null
      -    }
      -  }
      -
      -  class SparkILoopInterpreter extends SparkIMain(settings, out) {
      -    outer =>
      -
      -    override lazy val formatting = new Formatting {
      -      def prompt = SparkILoop.this.prompt
      -    }
      -    override protected def parentClassLoader =
      -      settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader )
      -  }
      -
      -  /** Create a new interpreter. */
      -  def createInterpreter() {
      -    if (addedClasspath != "")
      -      settings.classpath append addedClasspath
      -
      -    intp = new SparkILoopInterpreter
      -  }
      -
      -  /** print a friendly help message */
      -  def helpCommand(line: String): Result = {
      -    if (line == "") helpSummary()
      -    else uniqueCommand(line) match {
      -      case Some(lc) => echo("\n" + lc.help)
      -      case _        => ambiguousError(line)
      -    }
      -  }
      -  private def helpSummary() = {
      -    val usageWidth  = commands map (_.usageMsg.length) max
      -    val formatStr   = "%-" + usageWidth + "s %s"
      -
      -    echo("All commands can be abbreviated, e.g. :he instead of :help.")
      -
      -    commands foreach { cmd =>
      -      echo(formatStr.format(cmd.usageMsg, cmd.help))
      -    }
      -  }
      -  private def ambiguousError(cmd: String): Result = {
      -    matchingCommands(cmd) match {
      -      case Nil  => echo(cmd + ": no such command.  Type :help for help.")
      -      case xs   => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
      -    }
      -    Result(keepRunning = true, None)
      -  }
      -  private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
      -  private def uniqueCommand(cmd: String): Option[LoopCommand] = {
      -    // this lets us add commands willy-nilly and only requires enough command to disambiguate
      -    matchingCommands(cmd) match {
      -      case List(x)  => Some(x)
      -      // exact match OK even if otherwise appears ambiguous
      -      case xs       => xs find (_.name == cmd)
      -    }
      -  }
      -
      -  /** Show the history */
      -  lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
      -    override def usage = "[num]"
      -    def defaultLines = 20
      -
      -    def apply(line: String): Result = {
      -      if (history eq NoHistory)
      -        return "No history available."
      -
      -      val xs      = words(line)
      -      val current = history.index
      -      val count   = try xs.head.toInt catch { case _: Exception => defaultLines }
      -      val lines   = history.asStrings takeRight count
      -      val offset  = current - lines.size + 1
      -
      -      for ((line, index) <- lines.zipWithIndex)
      -        echo("%3d  %s".format(index + offset, line))
      -    }
      -  }
      -
      -  // When you know you are most likely breaking into the middle
      -  // of a line being typed.  This softens the blow.
      -  protected def echoAndRefresh(msg: String) = {
      -    echo("\n" + msg)
      -    in.redrawLine()
      -  }
      -  protected def echo(msg: String) = {
      -    out println msg
      -    out.flush()
      -  }
      -
      -  /** Search the history */
      -  def searchHistory(_cmdline: String) {
      -    val cmdline = _cmdline.toLowerCase
      -    val offset  = history.index - history.size + 1
      -
      -    for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
      -      echo("%d %s".format(index + offset, line))
      -  }
      -
      -  private val currentPrompt = Properties.shellPromptString
      -
      -  /** Prompt to print when awaiting input */
      -  def prompt = currentPrompt
      -
         import LoopCommand.{ cmd, nullary }
       
      -  /** Standard commands **/
      -  lazy val standardCommands = List(
      -    cmd("cp", "", "add a jar or directory to the classpath", addClasspath),
      -    cmd("edit", "|", "edit history", editCommand),
      -    cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
      -    historyCommand,
      -    cmd("h?", "", "search the history", searchHistory),
      -    cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
      -    //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand),
      -    cmd("javap", "", "disassemble a file or class name", javapCommand),
      -    cmd("line", "|", "place line(s) at the end of history", lineCommand),
      -    cmd("load", "", "interpret lines in a file", loadCommand),
      -    cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand),
      -    // nullary("power", "enable power user mode", powerCmd),
      -    nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)),
      -    nullary("replay", "reset execution and replay all previous commands", replay),
      -    nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
      -    cmd("save", "", "save replayable session to a file", saveCommand),
      -    shCommand,
      -    cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings),
      -    nullary("silent", "disable/enable automatic printing of results", verbosity),
      -//    cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand),
      -//    cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand),
      -    nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
      -  )
      -
      -  /** Power user commands */
      -//  lazy val powerCommands: List[LoopCommand] = List(
      -//    cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
      -//  )
      -
      -  private def importsCommand(line: String): Result = {
      -    val tokens    = words(line)
      -    val handlers  = intp.languageWildcardHandlers ++ intp.importHandlers
      -
      -    handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach {
      -      case (handler, idx) =>
      -        val (types, terms) = handler.importedSymbols partition (_.name.isTypeName)
      -        val imps           = handler.implicitSymbols
      -        val found          = tokens filter (handler importsSymbolNamed _)
      -        val typeMsg        = if (types.isEmpty) "" else types.size + " types"
      -        val termMsg        = if (terms.isEmpty) "" else terms.size + " terms"
      -        val implicitMsg    = if (imps.isEmpty) "" else imps.size + " are implicit"
      -        val foundMsg       = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
      -        val statsMsg       = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
      -
      -        intp.reporter.printMessage("%2d) %-30s %s%s".format(
      -          idx + 1,
      -          handler.importString,
      -          statsMsg,
      -          foundMsg
      -        ))
      -    }
      -  }
      -
      -  private def findToolsJar() = PathResolver.SupplementalLocations.platformTools
      +  private val blockedCommands = Set("implicits", "javap", "power", "type", "kind")
       
      -  private def addToolsJarToLoader() = {
      -    val cl = findToolsJar() match {
      -      case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
      -      case _           => intp.classLoader
      -    }
      -    if (Javap.isAvailable(cl)) {
      -      repldbg(":javap available.")
      -      cl
      -    }
      -    else {
      -      repldbg(":javap unavailable: no tools.jar at " + jdkHome)
      -      intp.classLoader
      -    }
      -  }
      -//
      -//  protected def newJavap() =
      -//    JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp))
      -//
      -//  private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
      -
      -  // Still todo: modules.
      -//  private def typeCommand(line0: String): Result = {
      -//    line0.trim match {
      -//      case "" => ":type [-v] "
      -//      case s  => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
      -//    }
      -//  }
      -
      -//  private def kindCommand(expr: String): Result = {
      -//    expr.trim match {
      -//      case "" => ":kind [-v] "
      -//      case s  => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
      -//    }
      -//  }
      -
      -  private def warningsCommand(): Result = {
      -    if (intp.lastWarnings.isEmpty)
      -      "Can't find any cached warnings."
      -    else
      -      intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
      -  }
      -
      -  private def changeSettings(args: String): Result = {
      -    def showSettings() = {
      -      for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString)
      -    }
      -    def updateSettings() = {
      -      // put aside +flag options
      -      val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+"))
      -      val tmps = new Settings
      -      val (ok, leftover) = tmps.processArguments(rest, processAll = true)
      -      if (!ok) echo("Bad settings request.")
      -      else if (leftover.nonEmpty) echo("Unprocessed settings.")
      -      else {
      -        // boolean flags set-by-user on tmp copy should be off, not on
      -        val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting])
      -        val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg))
      -        // update non-flags
      -        settings.processArguments(nonbools, processAll = true)
      -        // also snag multi-value options for clearing, e.g. -Ylog: and -language:
      -        for {
      -          s <- settings.userSetSettings
      -          if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting]
      -          if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init))
      -        } s match {
      -          case c: Clearable => c.clear()
      -          case _ =>
      -        }
      -        def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = {
      -          for (b <- bs)
      -            settings.lookupSetting(name(b)) match {
      -              case Some(s) =>
      -                if (s.isInstanceOf[Settings#BooleanSetting]) setter(s)
      -                else echo(s"Not a boolean flag: $b")
      -              case _ =>
      -                echo(s"Not an option: $b")
      -            }
      -        }
      -        update(minuses, identity, _.tryToSetFromPropertyValue("false"))  // turn off
      -        update(pluses, "-" + _.drop(1), _.tryToSet(Nil))                 // turn on
      -      }
      -    }
      -    if (args.isEmpty) showSettings() else updateSettings()
      -  }
      -
      -  private def javapCommand(line: String): Result = {
      -//    if (javap == null)
      -//      ":javap unavailable, no tools.jar at %s.  Set JDK_HOME.".format(jdkHome)
      -//    else if (line == "")
      -//      ":javap [-lcsvp] [path1 path2 ...]"
      -//    else
      -//      javap(words(line)) foreach { res =>
      -//        if (res.isError) return "Failed: " + res.value
      -//        else res.show()
      -//      }
      -  }
      -
      -  private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent"
      -
      -  private def phaseCommand(name: String): Result = {
      -//    val phased: Phased = power.phased
      -//    import phased.NoPhaseName
      -//
      -//    if (name == "clear") {
      -//      phased.set(NoPhaseName)
      -//      intp.clearExecutionWrapper()
      -//      "Cleared active phase."
      -//    }
      -//    else if (name == "") phased.get match {
      -//      case NoPhaseName => "Usage: :phase  (e.g. typer, erasure.next, erasure+3)"
      -//      case ph          => "Active phase is '%s'.  (To clear, :phase clear)".format(phased.get)
      -//    }
      -//    else {
      -//      val what = phased.parse(name)
      -//      if (what.isEmpty || !phased.set(what))
      -//        "'" + name + "' does not appear to represent a valid phase."
      -//      else {
      -//        intp.setExecutionWrapper(pathToPhaseWrapper)
      -//        val activeMessage =
      -//          if (what.toString.length == name.length) "" + what
      -//          else "%s (%s)".format(what, name)
      -//
      -//        "Active phase is now: " + activeMessage
      -//      }
      -//    }
      -  }
+  /** Standard commands */
      +  lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] =
      +    standardCommands.filter(cmd => !blockedCommands(cmd.name))
       
         /** Available commands */
      -  def commands: List[LoopCommand] = standardCommands ++ (
      -    // if (isReplPower)
      -    //  powerCommands
      -    // else
      -      Nil
      -    )
      -
      -  val replayQuestionMessage =
      -    """|That entry seems to have slain the compiler.  Shall I replay
      -      |your session? I can re-run each line except the last one.
      -      |[y/n]
      -    """.trim.stripMargin
      -
      -  private val crashRecovery: PartialFunction[Throwable, Boolean] = {
      -    case ex: Throwable =>
      -      val (err, explain) = (
      -        if (intp.isInitializeComplete)
      -          (intp.global.throwableAsString(ex), "")
      -        else
      -          (ex.getMessage, "The compiler did not initialize.\n")
      -        )
      -      echo(err)
      -
      -      ex match {
      -        case _: NoSuchMethodError | _: NoClassDefFoundError =>
      -          echo("\nUnrecoverable error.")
      -          throw ex
      -        case _  =>
      -          def fn(): Boolean =
      -            try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
      -            catch { case _: RuntimeException => false }
      -
      -          if (fn()) replay()
      -          else echo("\nAbandoning crashed session.")
      -      }
      -      true
      -  }
      -
      -  // return false if repl should exit
      -  def processLine(line: String): Boolean = {
      -    import scala.concurrent.duration._
      -    Await.ready(globalFuture, 60.seconds)
      -
      -    (line ne null) && (command(line) match {
      -      case Result(false, _)      => false
      -      case Result(_, Some(line)) => addReplay(line) ; true
      -      case _                     => true
      -    })
      -  }
      -
      -  private def readOneLine() = {
      -    out.flush()
      -    in readLine prompt
      -  }
      -
      -  /** The main read-eval-print loop for the repl.  It calls
      -    *  command() for each line of input, and stops when
      -    *  command() returns false.
      -    */
      -  @tailrec final def loop() {
      -    if ( try processLine(readOneLine()) catch crashRecovery )
      -      loop()
      -  }
      -
      -  /** interpret all lines from a specified file */
      -  def interpretAllFrom(file: File) {
      -    savingReader {
      -      savingReplayStack {
      -        file applyReader { reader =>
      -          in = SimpleReader(reader, out, interactive = false)
      -          echo("Loading " + file + "...")
      -          loop()
      -        }
      -      }
      -    }
      -  }
      -
      -  /** create a new interpreter and replay the given commands */
      -  def replay() {
      -    reset()
      -    if (replayCommandStack.isEmpty)
      -      echo("Nothing to replay.")
      -    else for (cmd <- replayCommands) {
      -      echo("Replaying: " + cmd)  // flush because maybe cmd will have its own output
      -      command(cmd)
      -      echo("")
      -    }
      -  }
      -  def resetCommand() {
      -    echo("Resetting interpreter state.")
      -    if (replayCommandStack.nonEmpty) {
      -      echo("Forgetting this session history:\n")
      -      replayCommands foreach echo
      -      echo("")
      -      replayCommandStack = Nil
      -    }
      -    if (intp.namedDefinedTerms.nonEmpty)
      -      echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
      -    if (intp.definedTypes.nonEmpty)
      -      echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
      -
      -    reset()
      -  }
      -  def reset() {
      -    intp.reset()
      -    unleashAndSetPhase()
      -  }
      -
      -  def lineCommand(what: String): Result = editCommand(what, None)
      -
      -  // :edit id or :edit line
      -  def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR"))
      -
      -  def editCommand(what: String, editor: Option[String]): Result = {
      -    def diagnose(code: String) = {
      -      echo("The edited code is incomplete!\n")
      -      val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}")
      -      if (errless) echo("The compiler reports no errors.")
      -    }
      -    def historicize(text: String) = history match {
      -      case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true
      -      case _ => false
      -    }
      -    def edit(text: String): Result = editor match {
      -      case Some(ed) =>
      -        val tmp = File.makeTemp()
      -        tmp.writeAll(text)
      -        try {
      -          val pr = new ProcessResult(s"$ed ${tmp.path}")
      -          pr.exitCode match {
      -            case 0 =>
      -              tmp.safeSlurp() match {
      -                case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.")
      -                case Some(edited) =>
      -                  echo(edited.lines map ("+" + _) mkString "\n")
      -                  val res = intp interpret edited
      -                  if (res == IR.Incomplete) diagnose(edited)
      -                  else {
      -                    historicize(edited)
      -                    Result(lineToRecord = Some(edited), keepRunning = true)
      -                  }
      -                case None => echo("Can't read edited text. Did you delete it?")
      -              }
      -            case x => echo(s"Error exit from $ed ($x), ignoring")
      -          }
      -        } finally {
      -          tmp.delete()
      -        }
      -      case None =>
      -        if (historicize(text)) echo("Placing text in recent history.")
      -        else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text")
      -    }
      -
      -    // if what is a number, use it as a line number or range in history
      -    def isNum = what forall (c => c.isDigit || c == '-' || c == '+')
      -    // except that "-" means last value
      -    def isLast = (what == "-")
      -    if (isLast || !isNum) {
      -      val name = if (isLast) intp.mostRecentVar else what
      -      val sym = intp.symbolOfIdent(name)
      -      intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match {
      -        case Some(req) => edit(req.line)
      -        case None      => echo(s"No symbol in scope: $what")
      -      }
      -    } else try {
      -      val s = what
      -      // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)
      -      val (start, len) =
      -        if ((s indexOf '+') > 0) {
      -          val (a,b) = s splitAt (s indexOf '+')
      -          (a.toInt, b.drop(1).toInt)
      -        } else {
      -          (s indexOf '-') match {
      -            case -1 => (s.toInt, 1)
      -            case 0  => val n = s.drop(1).toInt ; (history.index - n, n)
      -            case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n)
      -            case i  => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n)
      -          }
      -        }
      -      import scala.collection.JavaConverters._
      -      val index = (start - 1) max 0
      -      val text = history match {
      -        case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n"
      -        case _ => history.asStrings.slice(index, index + len) mkString "\n"
      -      }
      -      edit(text)
      -    } catch {
      -      case _: NumberFormatException => echo(s"Bad range '$what'")
      -        echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)")
      -    }
      -  }
      -
      -  /** fork a shell and run a command */
      -  lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
      -    override def usage = ""
      -    def apply(line: String): Result = line match {
      -      case ""   => showUsage()
      -      case _    =>
      -        val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})"
      -        intp interpret toRun
      -        ()
      -    }
      -  }
      -
      -  def withFile[A](filename: String)(action: File => A): Option[A] = {
      -    val res = Some(File(filename)) filter (_.exists) map action
      -    if (res.isEmpty) echo("That file does not exist")  // courtesy side-effect
      -    res
      -  }
      -
      -  def loadCommand(arg: String) = {
      -    var shouldReplay: Option[String] = None
      -    withFile(arg)(f => {
      -      interpretAllFrom(f)
      -      shouldReplay = Some(":load " + arg)
      -    })
      -    Result(keepRunning = true, shouldReplay)
      -  }
      -
      -  def saveCommand(filename: String): Result = (
      -    if (filename.isEmpty) echo("File name is required.")
      -    else if (replayCommandStack.isEmpty) echo("No replay commands in session")
      -    else File(filename).printlnAll(replayCommands: _*)
      -    )
      -
      -  def addClasspath(arg: String): Unit = {
      -    val f = File(arg).normalize
      -    if (f.exists) {
      -      addedClasspath = ClassPath.join(addedClasspath, f.path)
      -      val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
      -      echo("Added '%s'.  Your new classpath is:\n\"%s\"".format(f.path, totalClasspath))
      -      replay()
      -    }
      -    else echo("The path '" + f + "' doesn't seem to exist.")
      -  }
      -
      -  def powerCmd(): Result = {
      -    if (isReplPower) "Already in power mode."
      -    else enablePowerMode(isDuringInit = false)
      -  }
      -  def enablePowerMode(isDuringInit: Boolean) = {
      -    replProps.power setValue true
      -    unleashAndSetPhase()
      -    // asyncEcho(isDuringInit, power.banner)
      -  }
      -  private def unleashAndSetPhase() {
      -    if (isReplPower) {
      -    //  power.unleash()
      -      // Set the phase to "typer"
      -      // intp beSilentDuring phaseCommand("typer")
      -    }
      -  }
      -
      -  def asyncEcho(async: Boolean, msg: => String) {
      -    if (async) asyncMessage(msg)
      -    else echo(msg)
      -  }
      -
      -  def verbosity() = {
      -    val old = intp.printResults
      -    intp.printResults = !old
      -    echo("Switched " + (if (old) "off" else "on") + " result printing.")
      -  }
      -
      -  /** Run one command submitted by the user.  Two values are returned:
      -    * (1) whether to keep running, (2) the line to record for replay,
      -    * if any. */
      -  def command(line: String): Result = {
      -    if (line startsWith ":") {
      -      val cmd = line.tail takeWhile (x => !x.isWhitespace)
      -      uniqueCommand(cmd) match {
      -        case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace))
      -        case _        => ambiguousError(cmd)
      -      }
      -    }
      -    else if (intp.global == null) Result(keepRunning = false, None)  // Notice failure to create compiler
      -    else Result(keepRunning = true, interpretStartingWith(line))
      -  }
      -
      -  private def readWhile(cond: String => Boolean) = {
      -    Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
      -  }
      -
      -  def pasteCommand(arg: String): Result = {
      -    var shouldReplay: Option[String] = None
      -    def result = Result(keepRunning = true, shouldReplay)
      -    val (raw, file) =
      -      if (arg.isEmpty) (false, None)
      -      else {
      -        val r = """(-raw)?(\s+)?([^\-]\S*)?""".r
      -        arg match {
      -          case r(flag, sep, name) =>
      -            if (flag != null && name != null && sep == null)
      -              echo(s"""I assume you mean "$flag $name"?""")
      -            (flag != null, Option(name))
      -          case _ =>
      -            echo("usage: :paste -raw file")
      -            return result
      -        }
      -      }
      -    val code = file match {
      -      case Some(name) =>
      -        withFile(name)(f => {
      -          shouldReplay = Some(s":paste $arg")
      -          val s = f.slurp.trim
      -          if (s.isEmpty) echo(s"File contains no code: $f")
      -          else echo(s"Pasting file $f...")
      -          s
      -        }) getOrElse ""
      -      case None =>
      -        echo("// Entering paste mode (ctrl-D to finish)\n")
      -        val text = (readWhile(_ => true) mkString "\n").trim
      -        if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n")
      -        else echo("\n// Exiting paste mode, now interpreting.\n")
      -        text
      -    }
      -    def interpretCode() = {
      -      val res = intp interpret code
      -      // if input is incomplete, let the compiler try to say why
      -      if (res == IR.Incomplete) {
      -        echo("The pasted code is incomplete!\n")
      -        // Remembrance of Things Pasted in an object
      -        val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}")
      -        if (errless) echo("...but compilation found no error? Good luck with that.")
      -      }
      -    }
      -    def compileCode() = {
      -      val errless = intp compileSources new BatchSourceFile("", code)
      -      if (!errless) echo("There were compilation errors!")
      -    }
      -    if (code.nonEmpty) {
      -      if (raw) compileCode() else interpretCode()
      -    }
      -    result
      -  }
      -
      -  private object paste extends Pasted {
      -    val ContinueString = "     | "
      -    val PromptString   = "scala> "
      -
      -    def interpret(line: String): Unit = {
      -      echo(line.trim)
      -      intp interpret line
      -      echo("")
      -    }
      -
      -    def transcript(start: String) = {
      -      echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
      -      apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
      -    }
      -  }
      -  import paste.{ ContinueString, PromptString }
      -
      -  /** Interpret expressions starting with the first line.
      -    * Read lines until a complete compilation unit is available
      -    * or until a syntax error has been seen.  If a full unit is
      -    * read, go ahead and interpret it.  Return the full string
      -    * to be recorded for replay, if any.
      -    */
      -  def interpretStartingWith(code: String): Option[String] = {
      -    // signal completion non-completion input has been received
      -    in.completion.resetVerbosity()
      -
      -    def reallyInterpret = {
      -      val reallyResult = intp.interpret(code)
      -      (reallyResult, reallyResult match {
      -        case IR.Error       => None
      -        case IR.Success     => Some(code)
      -        case IR.Incomplete  =>
      -          if (in.interactive && code.endsWith("\n\n")) {
      -            echo("You typed two blank lines.  Starting a new command.")
      -            None
      -          }
      -          else in.readLine(ContinueString) match {
      -            case null =>
      -              // we know compilation is going to fail since we're at EOF and the
      -              // parser thinks the input is still incomplete, but since this is
      -              // a file being read non-interactively we want to fail.  So we send
      -              // it straight to the compiler for the nice error message.
      -              intp.compileString(code)
      -              None
      -
      -            case line => interpretStartingWith(code + "\n" + line)
      -          }
      -      })
      -    }
      -
      -    /** Here we place ourselves between the user and the interpreter and examine
      -      *  the input they are ostensibly submitting.  We intervene in several cases:
      -      *
      -      *  1) If the line starts with "scala> " it is assumed to be an interpreter paste.
      -      *  2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
      -      *     on the previous result.
      -      *  3) If the Completion object's execute returns Some(_), we inject that value
      -      *     and avoid the interpreter, as it's likely not valid scala code.
      -      */
      -    if (code == "") None
      -    else if (!paste.running && code.trim.startsWith(PromptString)) {
      -      paste.transcript(code)
      -      None
      -    }
      -    else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
      -      interpretStartingWith(intp.mostRecentVar + code)
      -    }
      -    else if (code.trim startsWith "//") {
      -      // line comment, do nothing
      -      None
      -    }
      -    else
      -      reallyInterpret._2
      -  }
      -
      -  // runs :load `file` on any files passed via -i
      -  def loadFiles(settings: Settings) = settings match {
      -    case settings: GenericRunnerSettings =>
      -      for (filename <- settings.loadfiles.value) {
      -        val cmd = ":load " + filename
      -        command(cmd)
      -        addReplay(cmd)
      -        echo("")
      -      }
      -    case _ =>
      -  }
      -
      -  /** Tries to create a JLineReader, falling back to SimpleReader:
      -    *  unless settings or properties are such that it should start
      -    *  with SimpleReader.
      -    */
      -  def chooseReader(settings: Settings): InteractiveReader = {
      -    if (settings.Xnojline || Properties.isEmacsShell)
      -      SimpleReader()
      -    else try new JLineReader(
      -      if (settings.noCompletion) NoCompletion
      -      else new SparkJLineCompletion(intp)
      -    )
      -    catch {
      -      case ex @ (_: Exception | _: NoClassDefFoundError) =>
      -        echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.")
      -        SimpleReader()
      -    }
      -  }
      -  protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
      -    u.TypeTag[T](
      -      m,
      -      new TypeCreator {
      -        def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type =
      -          m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
      -      })
      -
      -  private def loopPostInit() {
      -    // Bind intp somewhere out of the regular namespace where
      -    // we can get at it in generated code.
      -    intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain]))
      -    // Auto-run code via some setting.
      -    ( replProps.replAutorunCode.option
      -      flatMap (f => io.File(f).safeSlurp())
      -      foreach (intp quietRun _)
      -      )
      -    // classloader and power mode setup
      -    intp.setContextClassLoader()
      -    if (isReplPower) {
      -     // replProps.power setValue true
      -     // unleashAndSetPhase()
      -     // asyncMessage(power.banner)
      -    }
      -    // SI-7418 Now, and only now, can we enable TAB completion.
      -    in match {
      -      case x: JLineReader => x.consoleReader.postInit
      -      case _              =>
      -    }
      -  }
      -  def process(settings: Settings): Boolean = savingContextLoader {
      -    this.settings = settings
      -    createInterpreter()
      -
      -    // sets in to some kind of reader depending on environmental cues
      -    in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true))
      -    globalFuture = future {
      -      intp.initializeSynchronous()
      -      loopPostInit()
      -      !intp.reporter.hasErrors
      -    }
      -    import scala.concurrent.duration._
      -    Await.ready(globalFuture, 10 seconds)
      -    printWelcome()
      +  override def commands: List[LoopCommand] = sparkStandardCommands
      +
      +  /** 
      +   * We override `loadFiles` because we need to initialize Spark *before* the REPL
      +   * sees any files, so that the Spark context is visible in those files. This is a bit of a
      +   * hack, but there isn't another hook available to us at this point.
      +   */
      +  override def loadFiles(settings: Settings): Unit = {
           initializeSpark()
      -    loadFiles(settings)
      -
      -    try loop()
      -    catch AbstractOrMissingHandler()
      -    finally closeInterpreter()
      -
      -    true
      +    super.loadFiles(settings)
         }
      -
      -  @deprecated("Use `process` instead", "2.9.0")
      -  def main(settings: Settings): Unit = process(settings) //used by sbt
       }
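
The two small additions in this class — dropping unsupported commands from `standardCommands` and overriding `loadFiles` so Spark is initialized before any `-i` files run — are easy to model outside the REPL machinery. Below is a minimal, hedged sketch of both ideas; `BaseLoop`, `LoopCommand`, and the file handling are simplified stand-ins invented for illustration and are not the real `scala.tools.nsc.interpreter` types.

```scala
// Simplified, self-contained model of the two overrides above (illustration only).
object CommandFilterSketch {

  final case class LoopCommand(name: String, help: String)

  class BaseLoop {
    // Stand-in for ILoop.standardCommands.
    def standardCommands: List[LoopCommand] = List(
      LoopCommand("help", "print this summary"),
      LoopCommand("javap", "disassemble a file or class name"),
      LoopCommand("load", "interpret lines in a file"),
      LoopCommand("power", "enable power user mode")
    )
    def commands: List[LoopCommand] = standardCommands

    // Stand-in for ILoop.loadFiles: processes any files passed via -i.
    def loadFiles(files: Seq[String]): Unit =
      files.foreach(f => println(s"interpreting $f"))
  }

  class SparkLikeLoop extends BaseLoop {
    private val blockedCommands = Set("javap", "power")

    // Same shape as sparkStandardCommands: drop commands the Spark REPL does not support.
    override def commands: List[LoopCommand] =
      standardCommands.filter(cmd => !blockedCommands(cmd.name))

    private def initializeSpark(): Unit = println("initializing SparkContext ...")

    // Same ordering trick as the override above: initialize Spark before any -i files
    // are interpreted, so those files can already refer to the Spark bindings.
    override def loadFiles(files: Seq[String]): Unit = {
      initializeSpark()
      super.loadFiles(files)
    }
  }

  def main(args: Array[String]): Unit = {
    val repl = new SparkLikeLoop
    println(repl.commands.map(_.name).mkString(", "))  // prints: help, load
    repl.loadFiles(Seq("init.scala"))                  // initializes Spark first
  }
}
```
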
       
       object SparkILoop {
      -  implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
       
      -  // Designed primarily for use by test code: take a String with a
      -  // bunch of code, and prints out a transcript of what it would look
      -  // like if you'd just typed it into the repl.
      -  def runForTranscript(code: String, settings: Settings): String = {
      -    import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
      -
      -    stringFromStream { ostream =>
      -      Console.withOut(ostream) {
      -        val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
      -          override def write(str: String) = {
      -            // completely skip continuation lines
      -            if (str forall (ch => ch.isWhitespace || ch == '|')) ()
      -            else super.write(str)
      -          }
      -        }
      -        val input = new BufferedReader(new StringReader(code.trim + "\n")) {
      -          override def readLine(): String = {
      -            val s = super.readLine()
      -            // helping out by printing the line being interpreted.
      -            if (s != null)
      -              output.println(s)
      -            s
      -          }
      -        }
      -        val repl = new SparkILoop(input, output)
      -        if (settings.classpath.isDefault)
      -          settings.classpath.value = sys.props("java.class.path")
      -
      -        repl process settings
      -      }
      -    }
      -  }
      -
      -  /** Creates an interpreter loop with default settings and feeds
      -    *  the given code to it as input.
      -    */
      +  /** 
      +   * Creates an interpreter loop with default settings and feeds
      +   * the given code to it as input.
      +   */
         def run(code: String, sets: Settings = new Settings): String = {
           import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
       
           stringFromStream { ostream =>
             Console.withOut(ostream) {
      -        val input    = new BufferedReader(new StringReader(code))
      -        val output   = new JPrintWriter(new OutputStreamWriter(ostream), true)
      -        val repl     = new SparkILoop(input, output)
      +        val input = new BufferedReader(new StringReader(code))
      +        val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
      +        val repl = new SparkILoop(input, output)
       
               if (sets.classpath.isDefault)
                 sets.classpath.value = sys.props("java.class.path")
      @@ -975,5 +118,5 @@ object SparkILoop {
             }
           }
         }
      -  def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
      +  def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString)
       }
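
The retained `run` helpers feed a block of code to a fresh interpreter and return everything it printed, which is mainly useful in tests. A hypothetical usage sketch follows; the package of `SparkILoop` and the exact REPL transcript format (e.g. `x: Int = 1`) are assumptions and may differ between Scala versions.

```scala
// Hypothetical test-style use of SparkILoop.run (assumed to live in org.apache.spark.repl).
import org.apache.spark.repl.SparkILoop

object SparkILoopRunExample {
  def main(args: Array[String]): Unit = {
    // Feed a couple of lines to a fresh interpreter and capture everything it prints.
    val output: String = SparkILoop.run(List(
      "val x = 1",
      "x + 1"
    ))
    // The transcript format is an assumption; adjust the expectation to your Scala version.
    assert(output.contains("x: Int = 1"), s"unexpected REPL transcript:\n$output")
    println(output)
  }
}
```
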
      diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
      deleted file mode 100644
      index 1cb910f37606..000000000000
      --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
      +++ /dev/null
      @@ -1,1319 +0,0 @@
      -/* NSC -- new Scala compiler
      - * Copyright 2005-2013 LAMP/EPFL
      - * @author  Martin Odersky
      - */
      -
      -package scala
      -package tools.nsc
      -package interpreter
      -
      -import PartialFunction.cond
      -import scala.language.implicitConversions
      -import scala.beans.BeanProperty
      -import scala.collection.mutable
      -import scala.concurrent.{ Future, ExecutionContext }
      -import scala.reflect.runtime.{ universe => ru }
      -import scala.reflect.{ ClassTag, classTag }
      -import scala.reflect.internal.util.{ BatchSourceFile, SourceFile }
      -import scala.tools.util.PathResolver
      -import scala.tools.nsc.io.AbstractFile
      -import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings }
      -import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps }
      -import scala.tools.nsc.util.Exceptional.unwrap
      -import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable}
      -
      -/** An interpreter for Scala code.
      -  *
      -  *  The main public entry points are compile(), interpret(), and bind().
      -  *  The compile() method loads a complete Scala file.  The interpret() method
      -  *  executes one line of Scala code at the request of the user.  The bind()
      -  *  method binds an object to a variable that can then be used by later
      -  *  interpreted code.
      -  *
      -  *  The overall approach is based on compiling the requested code and then
      -  *  using a Java classloader and Java reflection to run the code
      -  *  and access its results.
      -  *
      -  *  In more detail, a single compiler instance is used
      -  *  to accumulate all successfully compiled or interpreted Scala code.  To
      -  *  "interpret" a line of code, the compiler generates a fresh object that
      -  *  includes the line of code and which has public member(s) to export
      -  *  all variables defined by that code.  To extract the result of an
      -  *  interpreted line to show the user, a second "result object" is created
      -  *  which imports the variables exported by the above object and then
      -  *  exports members called "$eval" and "$print". To accomodate user expressions
      -  *  that read from variables or methods defined in previous statements, "import"
      -  *  statements are used.
      -  *
      -  *  This interpreter shares the strengths and weaknesses of using the
      -  *  full compiler-to-Java.  The main strength is that interpreted code
      -  *  behaves exactly as does compiled code, including running at full speed.
      -  *  The main weakness is that redefining classes and methods is not handled
      -  *  properly, because rebinding at the Java level is technically difficult.
      -  *
      -  *  @author Moez A. Abdel-Gawad
      -  *  @author Lex Spoon
      -  */
      -class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings,
      -  protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports {
      -  imain =>
      -
      -  setBindings(createBindings, ScriptContext.ENGINE_SCOPE)
      -  object replOutput extends ReplOutput(settings.Yreploutdir) { }
      -
      -  @deprecated("Use replOutput.dir instead", "2.11.0")
      -  def virtualDirectory = replOutput.dir
      -  // Used in a test case.
      -  def showDirectory() = replOutput.show(out)
      -
      -  private[nsc] var printResults               = true      // whether to print result lines
      -  private[nsc] var totalSilence               = false     // whether to print anything
      -  private var _initializeComplete             = false     // compiler is initialized
      -  private var _isInitialized: Future[Boolean] = null      // set up initialization future
      -  private var bindExceptions                  = true      // whether to bind the lastException variable
      -  private var _executionWrapper               = ""        // code to be wrapped around all lines
      -
      -  /** We're going to go to some trouble to initialize the compiler asynchronously.
      -    *  It's critical that nothing call into it until it's been initialized or we will
      -    *  run into unrecoverable issues, but the perceived repl startup time goes
      -    *  through the roof if we wait for it.  So we initialize it with a future and
      -    *  use a lazy val to ensure that any attempt to use the compiler object waits
      -    *  on the future.
      -    */
      -  private var _classLoader: util.AbstractFileClassLoader = null                              // active classloader
      -  private val _compiler: ReplGlobal                 = newCompiler(settings, reporter)   // our private compiler
      -
      -  def compilerClasspath: Seq[java.net.URL] = (
      -    if (isInitializeComplete) global.classPath.asURLs
      -    else new PathResolver(settings).result.asURLs  // the compiler's classpath
      -    )
      -  def settings = initialSettings
      -  // Run the code body with the given boolean settings flipped to true.
      -  def withoutWarnings[T](body: => T): T = beQuietDuring {
      -    val saved = settings.nowarn.value
      -    if (!saved)
      -      settings.nowarn.value = true
      -
      -    try body
      -    finally if (!saved) settings.nowarn.value = false
      -  }
      -
      -  /** construct an interpreter that reports to Console */
      -  def this(settings: Settings, out: JPrintWriter) = this(null, settings, out)
      -  def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true))
      -  def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
      -  def this(factory: ScriptEngineFactory) = this(factory, new Settings())
      -  def this() = this(new Settings())
      -
      -  lazy val formatting: Formatting = new Formatting {
      -    val prompt = Properties.shellPromptString
      -  }
      -  lazy val reporter: SparkReplReporter = new SparkReplReporter(this)
      -
      -  import formatting._
      -  import reporter.{ printMessage, printUntruncatedMessage }
      -
      -  // This exists mostly because using the reporter too early leads to deadlock.
      -  private def echo(msg: String) { Console println msg }
      -  private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }"))
      -  private def _initialize() = {
      -    try {
      -      // if this crashes, REPL will hang its head in shame
      -      val run = new _compiler.Run()
      -      assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
      -      run compileSources _initSources
      -      _initializeComplete = true
      -      true
      -    }
      -    catch AbstractOrMissingHandler()
      -  }
      -  private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
      -  private val logScope = scala.sys.props contains "scala.repl.scope"
      -  private def scopelog(msg: String) = if (logScope) Console.err.println(msg)
      -
      -  // argument is a thunk to execute after init is done
      -  def initialize(postInitSignal: => Unit) {
      -    synchronized {
      -      if (_isInitialized == null) {
      -        _isInitialized =
      -          Future(try _initialize() finally postInitSignal)(ExecutionContext.global)
      -      }
      -    }
      -  }
      -  def initializeSynchronous(): Unit = {
      -    if (!isInitializeComplete) {
      -      _initialize()
      -      assert(global != null, global)
      -    }
      -  }
      -  def isInitializeComplete = _initializeComplete
      -
      -  lazy val global: Global = {
      -    if (!isInitializeComplete) _initialize()
      -    _compiler
      -  }
      -
      -  import global._
      -  import definitions.{ ObjectClass, termMember, dropNullaryMethod}
      -
      -  lazy val runtimeMirror = ru.runtimeMirror(classLoader)
      -
      -  private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol }
      -
      -  def getClassIfDefined(path: String)  = (
      -    noFatal(runtimeMirror staticClass path)
      -      orElse noFatal(rootMirror staticClass path)
      -    )
      -  def getModuleIfDefined(path: String) = (
      -    noFatal(runtimeMirror staticModule path)
      -      orElse noFatal(rootMirror staticModule path)
      -    )
      -
      -  implicit class ReplTypeOps(tp: Type) {
      -    def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
      -  }
      -
      -  // TODO: If we try to make naming a lazy val, we run into big time
      -  // scalac unhappiness with what look like cycles.  It has not been easy to
      -  // reduce, but name resolution clearly takes different paths.
      -  object naming extends {
      -    val global: imain.global.type = imain.global
      -  } with Naming {
      -    // make sure we don't overwrite their unwisely named res3 etc.
      -    def freshUserTermName(): TermName = {
      -      val name = newTermName(freshUserVarName())
      -      if (replScope containsName name) freshUserTermName()
      -      else name
      -    }
      -    def isInternalTermName(name: Name) = isInternalVarName("" + name)
      -  }
      -  import naming._
      -
      -  object deconstruct extends {
      -    val global: imain.global.type = imain.global
      -  } with StructuredTypeStrings
      -
      -  lazy val memberHandlers = new {
      -    val intp: imain.type = imain
      -  } with SparkMemberHandlers
      -  import memberHandlers._
      -
      -  /** Temporarily be quiet */
      -  def beQuietDuring[T](body: => T): T = {
      -    val saved = printResults
      -    printResults = false
      -    try body
      -    finally printResults = saved
      -  }
      -  def beSilentDuring[T](operation: => T): T = {
      -    val saved = totalSilence
      -    totalSilence = true
      -    try operation
      -    finally totalSilence = saved
      -  }
      -
      -  def quietRun[T](code: String) = beQuietDuring(interpret(code))
      -
      -  /** takes AnyRef because it may be binding a Throwable or an Exceptional */
      -  private def withLastExceptionLock[T](body: => T, alt: => T): T = {
      -    assert(bindExceptions, "withLastExceptionLock called incorrectly.")
      -    bindExceptions = false
      -
      -    try     beQuietDuring(body)
      -    catch   logAndDiscard("withLastExceptionLock", alt)
      -    finally bindExceptions = true
      -  }
      -
      -  def executionWrapper = _executionWrapper
      -  def setExecutionWrapper(code: String) = _executionWrapper = code
      -  def clearExecutionWrapper() = _executionWrapper = ""
      -
      -  /** interpreter settings */
      -  lazy val isettings = new SparkISettings(this)
      -
      -  /** Instantiate a compiler.  Overridable. */
      -  protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = {
      -    settings.outputDirs setSingleOutput replOutput.dir
      -    settings.exposeEmptyPackage.value = true
      -    new Global(settings, reporter) with ReplGlobal { override def toString: String = "" }
      -  }
      -
      -  /** Parent classloader.  Overridable. */
      -  protected def parentClassLoader: ClassLoader =
      -    settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() )
      -
      -  /* A single class loader is used for all commands interpreted by this Interpreter.
      -     It would also be possible to create a new class loader for each command
      -     to interpret.  The advantages of the current approach are:
      -
      -       - Expressions are only evaluated one time.  This is especially
      -         significant for I/O, e.g. "val x = Console.readLine"
      -
      -     The main disadvantage is:
      -
      -       - Objects, classes, and methods cannot be rebound.  Instead, definitions
      -         shadow the old ones, and old code objects refer to the old
      -         definitions.
      -  */
      -  def resetClassLoader() = {
      -    repldbg("Setting new classloader: was " + _classLoader)
      -    _classLoader = null
      -    ensureClassLoader()
      -  }
      -  final def ensureClassLoader() {
      -    if (_classLoader == null)
      -      _classLoader = makeClassLoader()
      -  }
      -  def classLoader: util.AbstractFileClassLoader = {
      -    ensureClassLoader()
      -    _classLoader
      -  }
      -
      -  def backticked(s: String): String = (
      -    (s split '.').toList map {
      -      case "_"                               => "_"
      -      case s if nme.keywords(newTermName(s)) => s"`$s`"
      -      case s                                 => s
      -    } mkString "."
      -    )
      -  def readRootPath(readPath: String) = getModuleIfDefined(readPath)
      -
      -  abstract class PhaseDependentOps {
      -    def shift[T](op: => T): T
      -
      -    def path(name: => Name): String = shift(path(symbolOfName(name)))
      -    def path(sym: Symbol): String = backticked(shift(sym.fullName))
      -    def sig(sym: Symbol): String  = shift(sym.defString)
      -  }
      -  object typerOp extends PhaseDependentOps {
      -    def shift[T](op: => T): T = exitingTyper(op)
      -  }
      -  object flatOp extends PhaseDependentOps {
      -    def shift[T](op: => T): T = exitingFlatten(op)
      -  }
      -
      -  def originalPath(name: String): String = originalPath(name: TermName)
      -  def originalPath(name: Name): String   = typerOp path name
      -  def originalPath(sym: Symbol): String  = typerOp path sym
      -  def flatPath(sym: Symbol): String      = flatOp shift sym.javaClassName
      -  def translatePath(path: String) = {
      -    val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path)
      -    sym.toOption map flatPath
      -  }
      -  def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath
      -
      -  private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) {
      -    /** Overridden here to try translating a simple name to the generated
      -      *  class name if the original attempt fails.  This method is used by
      -      *  getResourceAsStream as well as findClass.
      -      */
      -    override protected def findAbstractFile(name: String): AbstractFile =
      -      super.findAbstractFile(name) match {
      -        case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull
      -        case file => file
      -      }
      -  }
      -  private def makeClassLoader(): util.AbstractFileClassLoader =
      -    new TranslatingClassLoader(parentClassLoader match {
      -      case null   => ScalaClassLoader fromURLs compilerClasspath
      -      case p      => new ScalaClassLoader.URLClassLoader(compilerClasspath, p)
      -    })
      -
      -  // Set the current Java "context" class loader to this interpreter's class loader
      -  def setContextClassLoader() = classLoader.setAsContext()
      -
      -  def allDefinedNames: List[Name]  = exitingTyper(replScope.toList.map(_.name).sorted)
      -  def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted
      -
      -  /** Most recent tree handled which wasn't wholly synthetic. */
      -  private def mostRecentlyHandledTree: Option[Tree] = {
      -    prevRequests.reverse foreach { req =>
      -      req.handlers.reverse foreach {
      -        case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
      -        case _ => ()
      -      }
      -    }
      -    None
      -  }
      -
      -  private def updateReplScope(sym: Symbol, isDefined: Boolean) {
      -    def log(what: String) {
      -      val mark = if (sym.isType) "t " else "v "
      -      val name = exitingTyper(sym.nameString)
      -      val info = cleanTypeAfterTyper(sym)
      -      val defn = sym defStringSeenAs info
      -
      -      scopelog(f"[$mark$what%6s] $name%-25s $defn%s")
      -    }
      -    if (ObjectClass isSubClass sym.owner) return
      -    // unlink previous
      -    replScope lookupAll sym.name foreach { sym =>
      -      log("unlink")
      -      replScope unlink sym
      -    }
      -    val what = if (isDefined) "define" else "import"
      -    log(what)
      -    replScope enter sym
      -  }
      -
      -  def recordRequest(req: Request) {
      -    if (req == null)
      -      return
      -
      -    prevRequests += req
      -
      -    // warning about serially defining companions.  It'd be easy
      -    // enough to just redefine them together but that may not always
      -    // be what people want so I'm waiting until I can do it better.
      -    exitingTyper {
      -      req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym =>
      -        val oldSym = replScope lookup newSym.name.companionName
      -        if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) {
      -          replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")
      -          replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
      -        }
      -      }
      -    }
      -    exitingTyper {
      -      req.imports foreach (sym => updateReplScope(sym, isDefined = false))
      -      req.defines foreach (sym => updateReplScope(sym, isDefined = true))
      -    }
      -  }
      -
      -  private[nsc] def replwarn(msg: => String) {
      -    if (!settings.nowarnings)
      -      printMessage(msg)
      -  }
      -
      -  def compileSourcesKeepingRun(sources: SourceFile*) = {
      -    val run = new Run()
      -    assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
      -    reporter.reset()
      -    run compileSources sources.toList
      -    (!reporter.hasErrors, run)
      -  }
      -
      -  /** Compile an nsc SourceFile.  Returns true if there are
      -    *  no compilation errors, or false otherwise.
      -    */
      -  def compileSources(sources: SourceFile*): Boolean =
      -    compileSourcesKeepingRun(sources: _*)._1
      -
      -  /** Compile a string.  Returns true if there are no
      -    *  compilation errors, or false otherwise.
      -    */
      -  def compileString(code: String): Boolean =
      -    compileSources(new BatchSourceFile("
+
+    // scalastyle:on
+  }
+
+  private def planVisualization(metrics: Map[Long, Any], graph: SparkPlanGraph): Seq[Node] = {
+    val metadata = graph.nodes.flatMap { node =>
+      val nodeId = s"plan-meta-data-${node.id}"
+      {node.desc}
+    }
+
+    {planVisualizationResources}
+  }
+
+  private def jobURL(jobId: Long): String =
+    "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId)
+
+  private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
+    Details
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
new file mode 100644
index 000000000000..5779c71f64e9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.ui
+
+import scala.collection.mutable
+
+import com.google.common.annotations.VisibleForTesting
+
+import org.apache.spark.{JobExecutionStatus, Logging}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue}
+
+private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging {
+
+  private val retainedExecutions =
+    sqlContext.sparkContext.conf.getInt("spark.sql.ui.retainedExecutions", 1000)
+
+  private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]()
+
+  // Old data in the following fields must be removed in "trimExecutionsIfNecessary".
+  // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data.
+  private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]()
+
+  /**
+   * Maintain the relation between job id and execution id so that we can get the execution id in
+   * the "onJobEnd" method.
+   */
+  private val _jobIdToExecutionId = mutable.HashMap[Long, Long]()
+
+  private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]()
+
+  private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]()
+
+  private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]()
+
+  def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized {
+    _executionIdToData.toMap
+  }
+
+  def jobIdToExecutionId: Map[Long, Long] = synchronized {
+    _jobIdToExecutionId.toMap
+  }
+
+  def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized {
+    _stageIdToStageMetrics.toMap
+  }
+
+  private def trimExecutionsIfNecessary(
+      executions: mutable.ListBuffer[SQLExecutionUIData]): Unit = {
+    if (executions.size > retainedExecutions) {
+      val toRemove = math.max(retainedExecutions / 10, 1)
+      executions.take(toRemove).foreach { execution =>
+        for (executionUIData <- _executionIdToData.remove(execution.executionId)) {
+          for (jobId <- executionUIData.jobs.keys) {
+            _jobIdToExecutionId.remove(jobId)
+          }
+          for (stageId <- executionUIData.stages) {
+            _stageIdToStageMetrics.remove(stageId)
+          }
+        }
+      }
+      executions.trimStart(toRemove)
+    }
+  }
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    val executionIdString = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY)
+    if (executionIdString == null) {
+      // This is not a job created by SQL
+      return
+    }
+    val executionId = executionIdString.toLong
+    val jobId = jobStart.jobId
+    val stageIds = jobStart.stageIds
+
+    synchronized {
+      activeExecutions.get(executionId).foreach { executionUIData =>
+        executionUIData.jobs(jobId) = JobExecutionStatus.RUNNING
+        executionUIData.stages ++= stageIds
+        stageIds.foreach(stageId =>
+          _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId = 0))
+        _jobIdToExecutionId(jobId) = executionId
+      }
+    }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
+    val jobId = jobEnd.jobId
+    for (executionId <- _jobIdToExecutionId.get(jobId);
+         executionUIData <- _executionIdToData.get(executionId)) {
+      jobEnd.jobResult match {
+        case JobSucceeded => executionUIData.jobs(jobId) = JobExecutionStatus.SUCCEEDED
+        case JobFailed(_) => executionUIData.jobs(jobId) = JobExecutionStatus.FAILED
+      }
+      if (executionUIData.completionTime.nonEmpty && !executionUIData.hasRunningJobs) {
+        // We are the last job of this execution, so mark the execution as finished. Note that
+        // `onExecutionEnd` also does this, but currently that can be called before `onJobEnd`
+        // since these are called on different threads.
+        markExecutionFinished(executionId)
+      }
+    }
+  }
+
+  override def onExecutorMetricsUpdate(
+      executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized {
+    for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) {
+      updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false)
+    }
+  }
+
+  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
+    val stageId = stageSubmitted.stageInfo.stageId
+    val stageAttemptId = stageSubmitted.stageInfo.attemptId
+    // Always override metrics for old stage attempt
+    _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId)
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+    updateTaskAccumulatorValues(
+      taskEnd.taskInfo.taskId,
+      taskEnd.stageId,
+      taskEnd.stageAttemptId,
+      taskEnd.taskMetrics,
+      finishTask = true)
+  }
+
+  /**
+   * Update the accumulator values of a task with the latest metrics for this task. This is called
+   * every time we receive an executor heartbeat or when a task finishes.
+   */
+  private def updateTaskAccumulatorValues(
+      taskId: Long,
+      stageId: Int,
+      stageAttemptID: Int,
+      metrics: TaskMetrics,
+      finishTask: Boolean): Unit = {
+    if (metrics == null) {
+      return
+    }
+
+    _stageIdToStageMetrics.get(stageId) match {
+      case Some(stageMetrics) =>
+        if (stageAttemptID < stageMetrics.stageAttemptId) {
+          // A task of an old stage attempt. Because a new stage is submitted, we can ignore it.
+        } else if (stageAttemptID > stageMetrics.stageAttemptId) {
+          logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) than " +
+            s"what we have seen (${stageMetrics.stageAttemptId})")
+        } else {
+          // TODO We don't know the attemptId. Currently, what we can do is override the
+          // accumulator updates. However, if two attempts of the same task are running, e.g.
+          // because of speculation, the accumulator updates will be overwritten by different
+          // task attempts and the results will be weird.
+          stageMetrics.taskIdToMetricUpdates.get(taskId) match {
+            case Some(taskMetrics) =>
+              if (finishTask) {
+                taskMetrics.finished = true
+                taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates()
+              } else if (!taskMetrics.finished) {
+                taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates()
+              } else {
+                // If a task is finished, we should not override with accumulator updates from
+                // heartbeat reports
+              }
+            case None =>
+              // TODO Now just set attemptId to 0. Should fix here when we can get the attempt
+              // id from SparkListenerExecutorMetricsUpdate
+              stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics(
+                attemptId = 0, finished = finishTask, metrics.accumulatorUpdates())
+          }
+        }
+      case None =>
+        // This execution and its stage have been dropped
+    }
+  }
+
+  def onExecutionStart(
+      executionId: Long,
+      description: String,
+      details: String,
+      physicalPlanDescription: String,
+      physicalPlanGraph: SparkPlanGraph,
+      time: Long): Unit = {
+    val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node =>
+      node.metrics.map(metric => metric.accumulatorId -> metric)
+    }
+
+    val executionUIData = new SQLExecutionUIData(executionId, description, details,
+      physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time)
+    synchronized {
+      activeExecutions(executionId) = executionUIData
+      _executionIdToData(executionId) = executionUIData
+    }
+  }
+
+  def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized {
+    _executionIdToData.get(executionId).foreach { executionUIData =>
+      executionUIData.completionTime = Some(time)
+      if (!executionUIData.hasRunningJobs) {
+        // onExecutionEnd happens after all "onJobEnd"s,
+        // so we should update the execution lists.
+        markExecutionFinished(executionId)
+      } else {
+        // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s.
+        // Then we don't know if the execution is successful, so let the last onJobEnd update
+        // the execution lists.
+      }
+    }
+  }
+
+  private def markExecutionFinished(executionId: Long): Unit = {
+    activeExecutions.remove(executionId).foreach { executionUIData =>
+      if (executionUIData.isFailed) {
+        failedExecutions += executionUIData
+        trimExecutionsIfNecessary(failedExecutions)
+      } else {
+        completedExecutions += executionUIData
+        trimExecutionsIfNecessary(completedExecutions)
+      }
+    }
+  }
+
+  def getRunningExecutions: Seq[SQLExecutionUIData] = synchronized {
+    activeExecutions.values.toSeq
+  }
+
+  def getFailedExecutions: Seq[SQLExecutionUIData] = synchronized {
+    failedExecutions
+  }
+
+  def getCompletedExecutions: Seq[SQLExecutionUIData] = synchronized {
+    completedExecutions
+  }
+
+  def getExecution(executionId: Long): Option[SQLExecutionUIData] = synchronized {
+    _executionIdToData.get(executionId)
+  }
+
+  /**
+   * Get all accumulator updates from all tasks which belong to this execution and merge them.
+   */
+  def getExecutionMetrics(executionId: Long): Map[Long, Any] = synchronized {
+    _executionIdToData.get(executionId) match {
+      case Some(executionUIData) =>
+        val accumulatorUpdates = {
+          for (stageId <- executionUIData.stages;
+               stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable;
+               taskMetrics <- stageMetrics.taskIdToMetricUpdates.values;
+               accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield {
+            accumulatorUpdate
+          }
+        }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
+        mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
+          executionUIData.accumulatorMetrics(accumulatorId).metricParam).
+ mapValues(_.asInstanceOf[SQLMetricValue[_]].value) + case None => + // This execution has been dropped + Map.empty + } + } + + private def mergeAccumulatorUpdates( + accumulatorUpdates: Seq[(Long, Any)], + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = { + accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => + val param = paramFunc(accumulatorId) + (accumulatorId, + values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace)) + } + } + +} + +/** + * Represent all necessary data for an execution that will be used in Web UI. + */ +private[ui] class SQLExecutionUIData( + val executionId: Long, + val description: String, + val details: String, + val physicalPlanDescription: String, + val physicalPlanGraph: SparkPlanGraph, + val accumulatorMetrics: Map[Long, SQLPlanMetric], + val submissionTime: Long, + var completionTime: Option[Long] = None, + val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty, + val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer()) { + + /** + * Return whether there are running jobs in this execution. + */ + def hasRunningJobs: Boolean = jobs.values.exists(_ == JobExecutionStatus.RUNNING) + + /** + * Return whether there are any failed jobs in this execution. + */ + def isFailed: Boolean = jobs.values.exists(_ == JobExecutionStatus.FAILED) + + def runningJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.RUNNING }.keys.toSeq + + def succeededJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.SUCCEEDED }.keys.toSeq + + def failedJobs: Seq[Long] = + jobs.filter { case (_, status) => status == JobExecutionStatus.FAILED }.keys.toSeq +} + +/** + * Represent a metric in a SQLPlan. + * + * Because we cannot revert our changes for an "Accumulator", we need to maintain accumulator + * updates for each task. So that if a task is retried, we can simply override the old updates with + * the new updates of the new attempt task. Since we cannot add them to accumulator, we need to use + * "AccumulatorParam" to get the aggregation value. + */ +private[ui] case class SQLPlanMetric( + name: String, + accumulatorId: Long, + metricParam: SQLMetricParam[SQLMetricValue[Any], Any]) + +/** + * Store all accumulatorUpdates for all tasks in a Spark stage. + */ +private[ui] class SQLStageMetrics( + val stageAttemptId: Long, + val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty) + +/** + * Store all accumulatorUpdates for a Spark task. + */ +private[ui] class SQLTaskMetrics( + val attemptId: Long, // TODO not used yet + var finished: Boolean, + var accumulatorUpdates: Map[Long, Any]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala new file mode 100644 index 000000000000..0b0867f67eb6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
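For reference, `getExecutionMetrics` and `mergeAccumulatorUpdates` above reduce per-task accumulator updates by grouping on accumulator id and folding with the metric param's `zero`/`addInPlace`. A self-contained sketch of that shape, using a hypothetical `LongParam` stand-in rather than the actual `SQLMetricParam` types:

```scala
object MergeSketch {
  trait Param[T] { def zero: T; def addInPlace(a: T, b: T): T }
  object LongParam extends Param[Long] {
    val zero = 0L
    def addInPlace(a: Long, b: Long): Long = a + b
  }

  // Same shape as mergeAccumulatorUpdates: group updates by accumulator id,
  // then fold each group starting from the param's zero value.
  def merge(updates: Seq[(Long, Long)], paramFor: Long => Param[Long]): Map[Long, Long] =
    updates.groupBy(_._1).map { case (id, values) =>
      val param = paramFor(id)
      id -> values.map(_._2).foldLeft(param.zero)(param.addInPlace)
    }

  def main(args: Array[String]): Unit = {
    // Accumulator 1 was updated by two tasks, accumulator 2 by one.
    // Prints Map(1 -> 15, 2 -> 7) (map ordering may vary).
    println(merge(Seq((1L, 10L), (1L, 5L), (2L, 7L)), _ => LongParam))
  }
}
```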
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.ui.{SparkUI, SparkUITab} + +private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) + extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + + val parent = sparkUI + val listener = sqlContext.listener + + attachPage(new AllExecutionsPage(this)) + attachPage(new ExecutionPage(this)) + parent.attachTab(this) + + parent.addStaticHandler(SQLTab.STATIC_RESOURCE_DIR, "/static/sql") +} + +private[sql] object SQLTab { + + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" + + private val nextTabId = new AtomicInteger(0) + + private def nextTabName: String = { + val nextId = nextTabId.getAndIncrement() + if (nextId == 0) "SQL" else s"SQL$nextId" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala new file mode 100644 index 000000000000..ae3d752dde34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable + +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} + +/** + * A graph used for storing information of an executionPlan of DataFrame. + * + * Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the + * SparkPlan tree, and each edge represents a parent-child relationship between two nodes. + */ +private[ui] case class SparkPlanGraph( + nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) { + + def makeDotFile(metrics: Map[Long, Any]): String = { + val dotFile = new StringBuilder + dotFile.append("digraph G {\n") + nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n")) + edges.foreach(edge => dotFile.append(edge.makeDotEdge + "\n")) + dotFile.append("}") + dotFile.toString() + } +} + +private[sql] object SparkPlanGraph { + + /** + * Build a SparkPlanGraph from the root of a SparkPlan tree. 
+ */ + def apply(plan: SparkPlan): SparkPlanGraph = { + val nodeIdGenerator = new AtomicLong(0) + val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() + val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() + buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + new SparkPlanGraph(nodes, edges) + } + + private def buildSparkPlanGraphNode( + plan: SparkPlan, + nodeIdGenerator: AtomicLong, + nodes: mutable.ArrayBuffer[SparkPlanGraphNode], + edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + } + val node = SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodes += node + val childrenNodes = plan.children.map( + child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) + for (child <- childrenNodes) { + edges += SparkPlanGraphEdge(child.id, node.id) + } + node + } +} + +/** + * Represent a node in the SparkPlan tree, along with its metrics. + * + * @param id generated by "SparkPlanGraph". There is no duplicate id in a graph + * @param name the name of this SparkPlan node + * @param metrics metrics that this SparkPlan node will track + */ +private[ui] case class SparkPlanGraphNode( + id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { + + def makeDotNode(metricsValue: Map[Long, Any]): String = { + val values = { + for (metric <- metrics; + value <- metricsValue.get(metric.accumulatorId)) yield { + metric.name + ": " + value + } + } + val label = if (values.isEmpty) { + name + } else { + // If there are metrics, display all metrics in a separate line. We should use an escaped + // "\n" here to follow the dot syntax. + // + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + name + "\\n \\n" + values.mkString("\\n") + } + s""" $id [label="$label"];""" + } +} + +/** + * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child + * node id. + */ +private[ui] case class SparkPlanGraphEdge(fromId: Long, toId: Long) { + + def makeDotEdge: String = s""" $fromId->$toId;\n""" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala new file mode 100644 index 000000000000..258afadc7695 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
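To make the DOT conventions in `makeDotNode`/`makeDotEdge` above concrete, here is a standalone sketch with simplified stand-in case classes (not the private `SparkPlanGraph` types); the node names and the `numRows` metric are invented for illustration:

```scala
object DotSketch {
  case class Node(id: Long, label: String) { def dot: String = s"""  $id [label="$label"];""" }
  case class Edge(fromId: Long, toId: Long) { def dot: String = s"  $fromId->$toId;" }

  def main(args: Array[String]): Unit = {
    // "\\n \\n" mirrors the escaped-newline convention makeDotNode uses to separate
    // the node name from its metric lines.
    val nodes = Seq(Node(0, "Project"), Node(1, "Scan\\n \\nnumRows: 42"))
    val edges = Seq(Edge(1, 0)) // edges point from child to parent, as in buildSparkPlanGraphNode
    val lines = Seq("digraph G {") ++ nodes.map(_.dot) ++ edges.map(_.dot) ++ Seq("}")
    println(lines.mkString("\n"))
  }
}
```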
+ */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * The base class for implementing user-defined aggregate functions (UDAF). + */ +@Experimental +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the input arguments. + */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def dataType: DataType + + /** + * Returns true iff this function is deterministic, i.e. given the same input, + * always return the same output. + */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer. + * + * The contract should be that applying the merge function on two initial buffers should just + * return the initial buffer itself, i.e. + * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** + * Updates the given aggregation buffer `buffer` with new input data from `input`. + * + * This is called once per input row. + */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** + * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. + * + * This is called when we merge two partially aggregated data together. + */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any + + /** + * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = false) + Column(aggregateExpression) + } + + /** + * Creates a [[Column]] for this UDAF using the distinct values of the given + * [[Column]]s as input arguments. 
+ */ + @scala.annotation.varargs + def distinct(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = true) + Column(aggregateExpression) + } +} + +/** + * :: Experimental :: + * A [[Row]] representing an mutable aggregation buffer. + * + * This is not meant to be extended outside of Spark. + */ +@Experimental +abstract class MutableAggregationBuffer extends Row { + + /** Update the ith value of this buffer. */ + def update(i: Int, value: Any): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7e7a099a8318..60d9c509104d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -33,12 +35,14 @@ import org.apache.spark.util.Utils * * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions + * @groupname datetime_funcs Date time functions * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions * @groupname misc_funcs Misc functions * @groupname window_funcs Window functions * @groupname string_funcs String functions + * @groupname collection_funcs Collection functions * @groupname Ungrouped Support functions for DataFrames. * @since 1.3.0 */ @@ -119,36 +123,54 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Aggregate function: returns the sum of all values in the expression. + * Aggregate function: returns the approximate number of distinct items in a group. * * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = Sum(e.expr) + def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) /** - * Aggregate function: returns the sum of all values in the given column. + * Aggregate function: returns the approximate number of distinct items in a group. * * @group agg_funcs * @since 1.3.0 */ - def sum(columnName: String): Column = sum(Column(columnName)) + def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName)) /** - * Aggregate function: returns the sum of distinct values in the expression. + * Aggregate function: returns the approximate number of distinct items in a group. * * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) /** - * Aggregate function: returns the sum of distinct values in the expression. + * Aggregate function: returns the approximate number of distinct items in a group. 
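A minimal sketch of a UDAF written against the `UserDefinedAggregateFunction` API introduced above: a plain sum over a `LongType` column. The class name, column names, and the usage line are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class LongSum extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("value", LongType)
  def bufferSchema: StructType = new StructType().add("sum", LongType)
  def dataType: DataType = LongType
  def deterministic: Boolean = true

  // Zero value of the aggregation buffer.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0L)

  // Called once per input row; ignore nulls.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer.update(0, buffer.getLong(0) + input.getLong(0))

  // Combine two partial aggregates.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))

  def evaluate(buffer: Row): Any = buffer.getLong(0)
}

// Usage sketch, assuming a DataFrame `df` with a LongType column "x":
//   val longSum = new LongSum
//   df.select(longSum(df("x")))
```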
* * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) + def approxCountDistinct(columnName: String, rsd: Double): Column = { + approxCountDistinct(Column(columnName), rsd) + } + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def avg(e: Column): Column = Average(e.expr) + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def avg(columnName: String): Column = avg(Column(columnName)) /** * Aggregate function: returns the number of items in a group. @@ -191,141 +213,186 @@ object functions { countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) /** - * Aggregate function: returns the approximate number of distinct items in a group. + * Aggregate function: returns the first value in a group. * * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + def first(e: Column): Column = First(e.expr) /** - * Aggregate function: returns the approximate number of distinct items in a group. + * Aggregate function: returns the first value of a column in a group. * * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName)) + def first(columnName: String): Column = first(Column(columnName)) /** - * Aggregate function: returns the approximate number of distinct items in a group. + * Aggregate function: returns the last value in a group. * * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + def last(e: Column): Column = Last(e.expr) /** - * Aggregate function: returns the approximate number of distinct items in a group. + * Aggregate function: returns the last value of the column in a group. * * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(columnName: String, rsd: Double): Column = { - approxCountDistinct(Column(columnName), rsd) - } + def last(columnName: String): Column = last(Column(columnName)) /** - * Aggregate function: returns the average of the values in a group. + * Aggregate function: returns the maximum value of the expression in a group. * * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = Average(e.expr) + def max(e: Column): Column = Max(e.expr) /** - * Aggregate function: returns the average of the values in a group. + * Aggregate function: returns the maximum value of the column in a group. * * @group agg_funcs * @since 1.3.0 */ - def avg(columnName: String): Column = avg(Column(columnName)) + def max(columnName: String): Column = max(Column(columnName)) /** - * Aggregate function: returns the first value in a group. + * Aggregate function: returns the average of the values in a group. + * Alias for avg. * * @group agg_funcs - * @since 1.3.0 + * @since 1.4.0 */ - def first(e: Column): Column = First(e.expr) + def mean(e: Column): Column = avg(e) /** - * Aggregate function: returns the first value of a column in a group. + * Aggregate function: returns the average of the values in a group. + * Alias for avg. * * @group agg_funcs - * @since 1.3.0 + * @since 1.4.0 */ - def first(columnName: String): Column = first(Column(columnName)) + def mean(columnName: String): Column = avg(columnName) /** - * Aggregate function: returns the last value in a group. 
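A short usage sketch of the aggregate functions grouped above; the helper name, `df`, and its column names are placeholders:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// `df` is assumed to have columns dept (string), salary (double), name (string), id (long).
def summarize(df: DataFrame): DataFrame =
  df.groupBy("dept").agg(
    avg("salary").as("avgSalary"),
    countDistinct("name").as("people"),
    approxCountDistinct("id", 0.05).as("approxIds"),
    max("salary").as("topSalary"))
```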
+ * Aggregate function: returns the minimum value of the expression in a group. * * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = Last(e.expr) + def min(e: Column): Column = Min(e.expr) /** - * Aggregate function: returns the last value of the column in a group. + * Aggregate function: returns the minimum value of the column in a group. * * @group agg_funcs * @since 1.3.0 */ - def last(columnName: String): Column = last(Column(columnName)) + def min(columnName: String): Column = min(Column(columnName)) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the unbiased sample standard deviation + * of the expression in a group. * * @group agg_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def mean(e: Column): Column = avg(e) + def stddev(e: Column): Column = Stddev(e.expr) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the population standard deviation of + * the expression in a group. * * @group agg_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def mean(columnName: String): Column = avg(columnName) + def stddev_pop(e: Column): Column = StddevPop(e.expr) /** - * Aggregate function: returns the minimum value of the expression in a group. + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(e: Column): Column = StddevSamp(e.expr) + + /** + * Aggregate function: returns the sum of all values in the expression. * * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = Min(e.expr) + def sum(e: Column): Column = Sum(e.expr) /** - * Aggregate function: returns the minimum value of the column in a group. + * Aggregate function: returns the sum of all values in the given column. * * @group agg_funcs * @since 1.3.0 */ - def min(columnName: String): Column = min(Column(columnName)) + def sum(columnName: String): Column = sum(Column(columnName)) /** - * Aggregate function: returns the maximum value of the expression in a group. + * Aggregate function: returns the sum of distinct values in the expression. * * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = Max(e.expr) + def sumDistinct(e: Column): Column = SumDistinct(e.expr) /** - * Aggregate function: returns the maximum value of the column in a group. + * Aggregate function: returns the sum of distinct values in the expression. * * @group agg_funcs * @since 1.3.0 */ - def max(columnName: String): Column = max(Column(columnName)) + def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Window function: returns the cumulative distribution of values within a window partition, + * i.e. the fraction of rows that are below the current row. + * + * {{{ + * N = total number of rows in the partition + * cumeDist(x) = number of values before (and including) x / N + * }}} + * + * + * This is equivalent to the CUME_DIST function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def cumeDist(): Column = { + UnresolvedWindowFunction("cume_dist", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition, without any gaps. 
+ * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the DENSE_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def denseRank(): Column = { + UnresolvedWindowFunction("dense_rank", Nil) + } + /** * Window function: returns the value that is `offset` rows before the current row, and * `null` if there is less than `offset` rows before the current row. For example, @@ -453,32 +520,20 @@ object functions { } /** - * Window function: returns a sequential number starting at 1 within a window partition. - * - * This is equivalent to the ROW_NUMBER function in SQL. - * - * @group window_funcs - * @since 1.4.0 - */ - def rowNumber(): Column = { - UnresolvedWindowFunction("row_number", Nil) - } - - /** - * Window function: returns the rank of rows within a window partition, without any gaps. + * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. * - * The difference between rank and denseRank is that denseRank leaves no gaps in ranking - * sequence when there are ties. That is, if you were ranking a competition using denseRank - * and had three people tie for second place, you would say that all three were in second - * place and that the next person came in third. + * This is computed by: + * {{{ + * (rank of row in its partition - 1) / (number of rows in the partition - 1) + * }}} * - * This is equivalent to the DENSE_RANK function in SQL. + * This is equivalent to the PERCENT_RANK function in SQL. * * @group window_funcs * @since 1.4.0 */ - def denseRank(): Column = { - UnresolvedWindowFunction("dense_rank", Nil) + def percentRank(): Column = { + UnresolvedWindowFunction("percent_rank", Nil) } /** @@ -499,39 +554,15 @@ object functions { } /** - * Window function: returns the cumulative distribution of values within a window partition, - * i.e. the fraction of rows that are below the current row. - * - * {{{ - * N = total number of rows in the partition - * cumeDist(x) = number of values before (and including) x / N - * }}} - * - * - * This is equivalent to the CUME_DIST function in SQL. - * - * @group window_funcs - * @since 1.4.0 - */ - def cumeDist(): Column = { - UnresolvedWindowFunction("cume_dist", Nil) - } - - /** - * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. - * - * This is computed by: - * {{{ - * (rank of row in its partition - 1) / (number of rows in the partition - 1) - * }}} + * Window function: returns a sequential number starting at 1 within a window partition. * - * This is equivalent to the PERCENT_RANK function in SQL. + * This is equivalent to the ROW_NUMBER function in SQL. * * @group window_funcs * @since 1.4.0 */ - def percentRank(): Column = { - UnresolvedWindowFunction("percent_rank", Nil) + def rowNumber(): Column = { + UnresolvedWindowFunction("row_number", Nil) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -566,29 +597,47 @@ object functions { } /** - * Returns the first column that is not null. + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. 
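A usage sketch for the window functions documented above, assuming the `WindowSpec` builder in `org.apache.spark.sql.expressions.Window` from the same release line; the helper name, `df`, and its columns are placeholders:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// `df` is assumed to have columns dept (string) and salary (double).
def rankWithinDept(df: DataFrame): DataFrame = {
  val byDept = Window.partitionBy("dept").orderBy(df("salary").desc)
  df.select(
    df("dept"), df("salary"),
    rowNumber().over(byDept).as("rowInDept"),
    denseRank().over(byDept).as("denseRank"),
    percentRank().over(byDept).as("pctRank"))
}
```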
* {{{ - * df.select(coalesce(df("a"), df("b"))) + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") * }}} * * @group normal_funcs + * @since 1.5.0 + */ + def broadcast(df: DataFrame): DataFrame = { + DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan)) + } + + /** + * Returns the first column that is not null, or null if all inputs are null. + * + * For example, `coalesce(a, b, c)` will return a if a is not null, + * or b if a is null and b is not null, or c if both a and b are null but c is not null. + * + * @group normal_funcs * @since 1.3.0 */ @scala.annotation.varargs def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) /** - * Creates a new row for each element in the given array or map column. + * Creates a string column for the file name of the current Spark task. + * + * @group normal_funcs */ - def explode(e: Column): Column = Explode(e.expr) + def inputFileName(): Column = InputFileName() /** - * Converts a string exprsesion to lower case. + * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.3.0 + * @since 1.5.0 */ - def lower(e: Column): Column = Lower(e.expr) + def isNaN(e: Column): Column = IsNaN(e.expr) /** * A column expression that generates monotonically increasing 64-bit integers. @@ -605,7 +654,17 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() + + /** + * Returns col1 if it is not NaN, or col2 if col1 is NaN. + * + * Both inputs should be floating point columns (DoubleType or FloatType). + * + * @group normal_funcs + * @since 1.5.0 + */ + def nanvl(col1: Column, col2: Column): Column = NaNvl(col1.expr, col2.expr) /** * Unary minus, i.e. negate the expression. @@ -638,31 +697,6 @@ object functions { */ def not(e: Column): Column = !e - /** - * Evaluates a list of conditions and returns one of multiple possible result expressions. - * If otherwise is not defined at the end, null is returned for unmatched conditions. - * - * {{{ - * // Example: encoding gender string column into integer. - * - * // Scala: - * people.select(when(people("gender") === "male", 0) - * .when(people("gender") === "female", 1) - * .otherwise(2)) - * - * // Java: - * people.select(when(col("gender").equalTo("male"), 0) - * .when(col("gender").equalTo("female"), 1) - * .otherwise(2)) - * }}} - * - * @group normal_funcs - * @since 1.4.0 - */ - def when(condition: Column, value: Any): Column = { - CaseWhen(Seq(condition.expr, lit(value).expr)) - } - /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. * @@ -703,7 +737,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = execution.expressions.SparkPartitionID + def sparkPartitionId(): Column = SparkPartitionID() /** * Computes the square root of the specified float value. @@ -722,17 +756,18 @@ object functions { def sqrt(colName: String): Column = sqrt(Column(colName)) /** - * Creates a new struct column. The input column must be a column in a [[DataFrame]], or - * a derived column expression that is named (i.e. aliased). + * Creates a new struct column. + * If the input column is a column in a [[DataFrame]], or a derived column expression + * that is named (i.e. aliased), its name would be remained as the StructField's name, + * otherwise, the newly generated StructField's name would be auto generated as col${index + 1}, + * i.e. col1, col2, col3, ... 
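A sketch combining several of the non-aggregate helpers above (`broadcast`, `coalesce`, `isNaN`, `nanvl`, `monotonicallyIncreasingId`); the `people`/`cities` DataFrames and the column names are placeholders:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// `people` and `cities` are assumed DataFrames that share a "cityId" column.
def enrich(people: DataFrame, cities: DataFrame): DataFrame = {
  val joined = people.join(broadcast(cities), "cityId") // broadcast the small side
  joined.select(
    coalesce(joined("nickname"), joined("name")).as("displayName"),
    nanvl(joined("score"), lit(0.0)).as("score"),
    isNaN(joined("score")).as("scoreWasNaN"),
    monotonicallyIncreasingId().as("rowId"))
}
```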
* * @group normal_funcs * @since 1.4.0 */ @scala.annotation.varargs def struct(cols: Column*): Column = { - require(cols.forall(_.expr.isInstanceOf[NamedExpression]), - s"struct input columns must all be named or aliased ($cols)") - CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + CreateStruct(cols.map(_.expr)) } /** @@ -746,12 +781,29 @@ object functions { } /** - * Converts a string expression to upper case. + * Evaluates a list of conditions and returns one of multiple possible result expressions. + * If otherwise is not defined at the end, null is returned for unmatched conditions. + * + * {{{ + * // Example: encoding gender string column into integer. + * + * // Scala: + * people.select(when(people("gender") === "male", 0) + * .when(people("gender") === "female", 1) + * .otherwise(2)) + * + * // Java: + * people.select(when(col("gender").equalTo("male"), 0) + * .when(col("gender").equalTo("female"), 1) + * .otherwise(2)) + * }}} * * @group normal_funcs - * @since 1.3.0 + * @since 1.4.0 */ - def upper(e: Column): Column = Upper(e.expr) + def when(condition: Column, value: Any): Column = { + CaseWhen(Seq(condition.expr, lit(value).expr)) + } /** * Computes bitwise NOT. @@ -761,6 +813,18 @@ object functions { */ def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + /** + * Parses the expression string into the column that it represents, similar to + * DataFrame.selectExpr + * {{{ + * // get the number of words of each length + * df.groupBy(expr("length(word)")).count() + * }}} + * + * @group normal_funcs + */ + def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -940,6 +1004,15 @@ object functions { */ def ceil(columnName: String): Column = ceil(Column(columnName)) + /** + * Convert a number in a string column from one base to another. + * + * @group math_funcs + * @since 1.5.0 + */ + def conv(num: Column, fromBase: Int, toBase: Int): Column = + Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + /** * Computes the cosine of the given value. * @@ -972,15 +1045,6 @@ object functions { */ def cosh(columnName: String): Column = cosh(Column(columnName)) - /** - * Returns the double value that is closer than any other to e, the base of the natural - * logarithms. - * - * @group math_funcs - * @since 1.5.0 - */ - def e(): Column = EulerNumber() - /** * Computes the exponential of the given value. * @@ -1013,6 +1077,14 @@ object functions { */ def expm1(columnName: String): Column = expm1(Column(columnName)) + /** + * Computes the factorial of the given value. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(e: Column): Column = Factorial(e.expr) + /** * Computes the floor of the given value. * @@ -1029,6 +1101,48 @@ object functions { */ def floor(columnName: String): Column = floor(Column(columnName)) + /** + * Returns the greatest value of the list of values, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def greatest(exprs: Column*): Column = { + require(exprs.length > 1, "greatest requires at least 2 arguments.") + Greatest(exprs.map(_.expr)) + } + + /** + * Returns the greatest value of the list of column names, skipping null values. 
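A sketch of the relaxed `struct` (unnamed inputs get auto-generated `colN` field names) together with `expr` and `when`/`otherwise`; the helper name, `df`, and its columns are placeholders:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// `df` is assumed to have numeric columns a and b, a string column name, and an int column age.
def reshape(df: DataFrame): DataFrame =
  df.select(
    struct(df("a"), df("b") + 1).as("pair"),   // fields: "a" (named input) and auto-named "col2"
    expr("length(name)").as("nameLen"),        // parsed like DataFrame.selectExpr
    when(df("age") >= 18, "adult").otherwise("minor").as("bucket"))
```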
+ * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def greatest(columnName: String, columnNames: String*): Column = { + greatest((columnName +: columnNames).map(Column.apply): _*) + } + + /** + * Computes hex value of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = Hex(column.expr) + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(column: Column): Column = Unhex(column.expr) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * @@ -1094,6 +1208,31 @@ object functions { */ def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + /** + * Returns the least value of the list of values, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def least(exprs: Column*): Column = { + require(exprs.length > 1, "least requires at least 2 arguments.") + Least(exprs.map(_.expr)) + } + + /** + * Returns the least value of the list of column names, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. + * + * @group normal_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def least(columnName: String, columnNames: String*): Column = { + least((columnName +: columnNames).map(Column.apply): _*) + } + /** * Computes the natural logarithm of the given value. * @@ -1158,15 +1297,6 @@ object functions { */ def log1p(columnName: String): Column = log1p(Column(columnName)) - /** - * Returns the double value that is closer than any other to pi, the ratio of the circumference - * of a circle to its diameter. - * - * @group math_funcs - * @since 1.5.0 - */ - def pi(): Column = Pi() - /** * Computes the logarithm of the given column in base 2. * @@ -1247,6 +1377,14 @@ object functions { */ def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + /** * Returns the double value that is closest in value to the argument and * is equal to a mathematical integer. @@ -1265,6 +1403,51 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Round the value of `e` to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits right. 
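A sketch of a few of the math helpers above (`round`, `pmod`, `shiftLeft`, `conv`, `greatest`); the helper name, `df`, and its columns are placeholders:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// `df` is assumed to have columns ratio (double), n (int), hexStr (string), and a/b/c (numeric).
def mathDemo(df: DataFrame): DataFrame =
  df.select(
    round(df("ratio"), 2).as("ratio2"),
    pmod(df("n"), lit(7)).as("mod7"),
    shiftLeft(df("n"), 2).as("times4"),
    conv(df("hexStr"), 16, 10).as("asDecimal"),
    greatest(df("a"), df("b"), df("c")).as("best"))
```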
If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(e: Column, numBits: Int): Column = + ShiftRightUnsigned(e.expr, lit(numBits).expr) + /** * Computes the signum of the given value. * @@ -1382,7 +1565,8 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. + * Calculates the MD5 digest of a binary column and returns the value + * as a 32 character hex string. * * @group misc_funcs * @since 1.5.0 @@ -1390,30 +1574,632 @@ object functions { def md5(e: Column): Column = Md5(e.expr) /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. + * Calculates the SHA-1 digest of a binary column and returns the value + * as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-2 family of hash functions of a binary column and + * returns the value as a hex string. + * + * @param e column to compute SHA-2 on. + * @param numBits one of 224, 256, 384, or 512. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(e: Column, numBits: Int): Column = { + require(Seq(0, 224, 256, 384, 512).contains(numBits), + s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") + Sha2(e.expr, lit(numBits).expr) + } + + /** + * Calculates the cyclic redundancy check value (CRC32) of a binary column and + * returns the value as a bigint. * * @group misc_funcs * @since 1.5.0 */ - def md5(columnName: String): Column = md5(Column(columnName)) + def crc32(e: Column): Column = Crc32(e.expr) ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value + * Computes the numeric value of the first character of the string column, and returns the + * result as a int column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ascii(e: Column): Column = Ascii(e.expr) + + /** + * Computes the BASE64 encoding of a binary column and returns it as a string column. + * This is the reverse of unbase64. + * + * @group string_funcs + * @since 1.5.0 + */ + def base64(e: Column): Column = Base64(e.expr) + + /** + * Concatenates multiple input string columns together into a single string column. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) + + /** + * Concatenates multiple input string columns together into a single string column, + * using the given separator. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat_ws(sep: String, exprs: Column*): Column = { + ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) + } + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + + /** + * Computes the first argument into a binary from a string using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + + /** + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string column. + * + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + + /** + * Formats the arguments in printf-style and returns the result as a string column. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def format_string(format: String, arguments: Column*): Column = { + FormatString((lit(format) +: arguments).map(_.expr): _*) + } + + /** + * Returns a new string column by converting the first letter of each word to uppercase. + * Words are delimited by whitespace. + * + * For example, "hello world" will become "Hello World". + * + * @group string_funcs + * @since 1.5.0 + */ + def initcap(e: Column): Column = InitCap(e.expr) + + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) + + /** + * Computes the length of a given string or binary column. + * + * @group string_funcs + * @since 1.5.0 + */ + def length(e: Column): Column = Length(e.expr) + + /** + * Converts a string column to lower case. + * + * @group string_funcs + * @since 1.3.0 + */ + def lower(e: Column): Column = Lower(e.expr) + + /** + * Computes the Levenshtein distance of the two given string columns. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + + /** + * Locate the position of the first occurrence of substr. + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: Column): Column = { + new StringLocate(lit(substr).expr, str.expr) + } + + /** + * Locate the position of the first occurrence of substr in a string column, after position pos. + * + * NOTE: The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: Column, pos: Int): Column = { + StringLocate(lit(substr).expr, str.expr, lit(pos).expr) + } + + /** + * Left-pad the string column with + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Int, pad: String): Column = { + StringLPad(str.expr, lit(len).expr, lit(pad).expr) + } + + /** + * Trim the spaces from left end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(e: Column): Column = StringTrimLeft(e.expr) + + /** + * Extract a specific(idx) group identified by a java regex, from the specified string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) + } + + /** + * Replace all substrings of the specified string value that match regexp with rep. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) + } + + /** + * Decodes a BASE64 encoded string column and returns it as a binary column. + * This is the reverse of base64. + * + * @group string_funcs + * @since 1.5.0 + */ + def unbase64(e: Column): Column = UnBase64(e.expr) + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Int, pad: String): Column = { + StringRPad(str.expr, lit(len).expr, lit(pad).expr) + } + + /** + * Repeats a string column n times, and returns it as a new string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, n: Int): Column = { + StringRepeat(str.expr, lit(n).expr) + } + + /** + * Reverses the string column and returns it as a new string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: Column): Column = { + StringReverse(str.expr) + } + + /** + * Trim the spaces from right end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(e: Column): Column = StringTrimRight(e.expr) + + /** + * * Return the soundex code for the specified expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def soundex(e: Column): Column = SoundEx(e.expr) + + /** + * Splits str around pattern (pattern is a regular expression). + * NOTE: pattern is a string represent the regular expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def split(str: Column, pattern: String): Column = { + StringSplit(str.expr, lit(pattern).expr) + } + + /** + * Substring starts at `pos` and is of length `len` when str is String type or + * returns the slice of byte array that starts at `pos` in byte and is of length `len` + * when str is Binary type + * + * @group string_funcs + * @since 1.5.0 + */ + def substring(str: Column, pos: Int, len: Int): Column = + Substring(str.expr, lit(pos).expr, lit(len).expr) + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. 
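A sketch of the string helpers above; the helper name, `logs`, and its columns are placeholders:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// `logs` is assumed to have string columns first, last, code, line and a numeric column amount.
def cleanUp(logs: DataFrame): DataFrame =
  logs.select(
    concat_ws(" ", logs("first"), logs("last")).as("fullName"),
    lpad(logs("code"), 8, "0").as("paddedCode"),
    regexp_extract(logs("line"), "id=(\\d+)", 1).as("id"),      // first capture group
    regexp_replace(logs("line"), "\\s+", " ").as("squashed"),
    format_number(logs("amount"), 2).as("pretty"))
```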
+ * + * @group string_funcs + */ + def substring_index(str: Column, delim: String, count: Int): Column = + SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + + /** + * Translate any character in the src by a character in replaceString. + * The characters in replaceString is corresponding to the characters in matchingString. + * The translate will happen when any character in the string matching with the character + * in the matchingString. + * + * @group string_funcs + * @since 1.5.0 + */ + def translate(src: Column, matchingString: String, replaceString: String): Column = + StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + + /** + * Trim the spaces from both ends for the specified string column. + * * @group string_funcs * @since 1.5.0 */ - def strlen(e: Column): Column = StringLength(e.expr) + def trim(e: Column): Column = StringTrim(e.expr) /** - * Computes the length of a given string column + * Converts a string column to upper case. + * * @group string_funcs + * @since 1.3.0 + */ + def upper(e: Column): Column = Upper(e.expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // DateTime functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns the date that is numMonths after startDate. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def add_months(startDate: Column, numMonths: Int): Column = + AddMonths(startDate.expr, Literal(numMonths)) + + /** + * Returns the current date as a date column. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_date(): Column = CurrentDate() + + /** + * Returns the current timestamp as a timestamp column. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_timestamp(): Column = CurrentTimestamp() + + /** + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. + * + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def date_format(dateExpr: Column, format: String): Column = + DateFormatClass(dateExpr.expr, Literal(format)) + + /** + * Returns the date that is `days` days after `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + + /** + * Returns the date that is `days` days before `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + + /** + * Returns the number of days from `start` to `end`. + * @group datetime_funcs + * @since 1.5.0 + */ + def datediff(end: Column, start: Column): Column = DateDiff(end.expr, start.expr) + + /** + * Extracts the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def year(e: Column): Column = Year(e.expr) + + /** + * Extracts the quarter as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def strlen(columnName: String): Column = strlen(Column(columnName)) + def quarter(e: Column): Column = Quarter(e.expr) + + /** + * Extracts the month as an integer from a given date/timestamp/string. 
+ * @group datetime_funcs + * @since 1.5.0 + */ + def month(e: Column): Column = Month(e.expr) + + /** + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def dayofmonth(e: Column): Column = DayOfMonth(e.expr) + + /** + * Extracts the day of the year as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def dayofyear(e: Column): Column = DayOfYear(e.expr) + + /** + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def hour(e: Column): Column = Hour(e.expr) + + /** + * Given a date column, returns the last day of the month which the given date belongs to. + * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the + * month in July 2015. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def last_day(e: Column): Column = LastDay(e.expr) + + /** + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def minute(e: Column): Column = Minute(e.expr) + + /** + * Returns the number of months between dates `date1` and `date2`. + * @group datetime_funcs + * @since 1.5.0 + */ + def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + + /** + * Given a date column, returns the first date which is later than the value of the date column + * that is on the specified day of the week. + * + * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first + * Sunday after 2015-07-27. + * + * Day of the week parameter is case insensitive, and accepts: + * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". + * + * @group datetime_funcs + * @since 1.5.0 + */ + def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + + /** + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def second(e: Column): Column = Second(e.expr) + + /** + * Extracts the week number as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 1.5.0 + */ + def weekofyear(e: Column): Column = WeekOfYear(e.expr) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the default + * format (yyyy-MM-dd HH:mm:ss). + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + + /** + * Gets current Unix timestamp in seconds. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), + * using the default timezone and the default locale; returns null on failure.
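A hedged sketch of the new date/time helpers in use; the DataFrame `events` and its timestamp column `ts` are assumed for illustration only.

```scala
import org.apache.spark.sql.functions._

// Hypothetical DataFrame `events` with a timestamp column "ts".
val enriched = events.select(
  date_format(col("ts"), "yyyy-MM-dd").as("day"),               // formatted string
  next_day(col("ts"), "Sunday").as("next_sunday"),              // first Sunday after ts
  months_between(current_date(), col("ts")).as("age_in_months"),
  unix_timestamp(col("ts")).as("epoch_seconds"))
```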
+ * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts a time string with the given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to a Unix timestamp (in seconds); returns null on failure. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + + /** + * Converts the column into DateType. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def to_date(e: Column): Column = ToDate(e.expr) + + /** + * Returns date truncated to the unit specified by the format. + * + * @param format: 'year', 'yyyy', 'yy' for truncate by year, + * or 'month', 'mon', 'mm' for truncate by month + * + * @group datetime_funcs + * @since 1.5.0 + */ + def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + + /** + * Assumes given timestamp is UTC and converts to given timezone. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_utc_timestamp(ts: Column, tz: String): Column = + FromUTCTimestamp(ts.expr, Literal(tz).expr) + + /** + * Assumes given timestamp is in given timezone and converts to UTC. + * @group datetime_funcs + * @since 1.5.0 + */ + def to_utc_timestamp(ts: Column, tz: String): Column = ToUTCTimestamp(ts.expr, Literal(tz).expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Collection functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns true if the array contains the value. + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(column: Column, value: Any): Column = + ArrayContains(column.expr, Literal(value)) + + /** + * Creates a new row for each element in the given array or map column. + * + * @group collection_funcs + * @since 1.3.0 + */ + def explode(e: Column): Column = Explode(e.expr) + + /** + * Returns length of array or map. + * + * @group collection_funcs + * @since 1.5.0 + */ + def size(e: Column): Column = Size(e.expr) + + /** + * Sorts the input array for the given column in ascending order, + * according to the natural ordering of the array elements. + * + * @group collection_funcs + * @since 1.5.0 + */ + def sort_array(e: Column): Column = sort_array(e, asc = true) + + /** + * Sorts the input array for the given column in ascending / descending order, + * according to the natural ordering of the array elements. + * + * @group collection_funcs + * @since 1.5.0 + */ + def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1424,6 +2210,7 @@ object functions { (0 to 10).map { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" /** * Defines a user-defined function of ${x} arguments as user-defined function (UDF).
@@ -1433,14 +2220,15 @@ object functions { * @since 1.3.0 */ def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try($inputTypes).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) }""") } (0 to 10).map { x => val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires @@ -1448,9 +2236,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) + ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } } @@ -1463,7 +2253,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1474,7 +2265,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1485,7 +2277,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1496,7 +2289,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1507,7 +2301,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1518,7 +2313,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: 
TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1529,7 +2325,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1540,7 +2337,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1551,7 +2349,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1562,7 +2361,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: 
ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1573,7 +2373,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1584,9 +2385,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) + ScalaUDF(f, returnType, Seq()) } /** @@ -1595,9 +2398,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr)) } /** @@ -1606,9 +2411,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** @@ -1617,9 +2424,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** @@ -1628,9 +2437,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** @@ -1639,9 +2450,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As 
of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** @@ -1650,9 +2463,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** @@ -1661,9 +2476,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** @@ -1672,9 +2489,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** @@ -1683,9 +2502,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** @@ -1694,9 +2515,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on @@ -1709,15 +2532,44 @@ object 
functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) + * }}} + * + * @group udf_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def callUDF(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) + } + + /** + * Call a user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) + * }}} + * + * @group udf_funcs + * @since 1.4.0 + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF */ + @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + // Note: we avoid using closures here because on file systems that are case-insensitive, the + // compiled class file for the closure here will conflict with the one in callUDF (upper case). + val exprs = new Array[Expression](cols.size) + var i = 0 + while (i < cols.size) { + exprs(i) = cols(i).expr + i += 1 + } + UnresolvedFunction(udfName, exprs, isDistinct = false) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8849fc2f1f0e..68ebaaca6c53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -88,6 +88,17 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + } /** @@ -125,6 +136,7 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) + registerDialect(DB2Dialect) /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -197,6 +209,11 @@ case object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case _ => None } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + } /** @@ -221,4 +238,25 @@ case object MySQLDialect extends JdbcDialect { override def quoteIdentifier(colName: String): String = { s"`$colName`" } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } +} + +/** + * :: DeveloperApi :: + * Default DB2 dialect, mapping string/boolean on write to valid DB2 types. + * By default, string and boolean get mapped to the invalid DB2 types TEXT and BIT(1).
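To show how the new `getTableExistsQuery` hook is meant to be used, here is a hedged sketch of a user-defined dialect; the dialect name, URL prefix, and probe query are made up for the example and are not part of this patch.

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Illustrative only: a custom dialect for a fictional "mydb" database that
// overrides the new table-existence probe.
case object MyDbDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Assumed to be cheaper on this engine than the default "SELECT * FROM $table WHERE 1=0".
  override def getTableExistsQuery(table: String): String =
    s"SELECT 1 FROM $table FETCH FIRST 1 ROWS ONLY"
}

// Register it so JdbcDialects.get picks it up for matching URLs.
JdbcDialects.registerDialect(MyDbDialect)
```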
+ */ +@DeveloperApi +case object DB2Dialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) + case BooleanType => Some(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala deleted file mode 100644 index cc918c237192..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import scala.util.Try - -/** - * Util functions for JDBC tables. - */ -private[sql] object JdbcUtils { - - /** - * Establishes a JDBC connection. - */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - DriverManager.getConnection(url, connectionProperties) - } - - /** - * Returns true if the table already exists in the JDBC database. - */ - def tableExists(conn: Connection, table: String): Boolean = { - // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess - } - - /** - * Drops a table from the JDBC database. - */ - def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index dd8aaf647489..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} -import java.util.Properties - -import scala.collection.mutable - -import org.apache.spark.Logging -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -package object jdbc { - private[sql] object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. - * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - def savePartition( - url: String, - table: String, - iterator: Iterator[Row], - rddSchema: StructType, - nullTypes: Array[Int], - properties: Properties): Iterator[Byte] = { - val conn = DriverManager.getConnection(url, properties) - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. - val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case DecimalType.Unlimited => stmt.setBigDecimal(i + 1, - row.getAs[java.math.BigDecimal](i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. 
- conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - - /** - * Compute the schema string for this RDD. - */ - def schemaString(df: DataFrame, url: String): String = { - val sb = new StringBuilder() - val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { - val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case DecimalType.Unlimited => "DECIMAL(40,20)" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. - */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties = new Properties()) { - val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case DecimalType.Unlimited => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) - } - - val rddSchema = df.schema - df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes, properties) - } - } - - } - - private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { - override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) - - override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() - - override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { - wrapped.getPropertyInfo(url, info) - } - - override def getMinorVersion: Int = wrapped.getMinorVersion - - def getParentLogger: java.util.logging.Logger = - throw new SQLFeatureNotSupportedException( - s"${this.getClass().getName}.getParentLogger is not yet implemented.") - - override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) - - override def getMajorVersion: Int = wrapped.getMajorVersion - } - - /** - * java.sql.DriverManager is always loaded by bootstrap classloader, - * so it can't load JDBC drivers accessible by Spark ClassLoader. - * - * To solve the problem, drivers from user-supplied jars are wrapped - * into thin wrapper. 
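The `saveTable`/`savePartition` logic removed above sits underneath the public DataFrame JDBC writer; a minimal user-level sketch follows, where the URL, table name, credentials, and the DataFrame `df` are placeholders, not part of this patch.

```scala
import java.util.Properties

// Placeholder connection details; the point is only the df.write.jdbc call.
val props = new Properties()
props.setProperty("user", "app")
props.setProperty("password", "secret")

df.write.jdbc("jdbc:postgresql://localhost/testdb", "people", props)
```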
- */ - private [sql] object DriverRegistry extends Logging { - - private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - def register(className: String): Unit = { - val cls = Utils.getContextOrSparkClassLoader.loadClass(className) - if (cls.getClassLoader == null) { - logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") - } else if (wrapperMap.get(className).isDefined) { - logTrace(s"Wrapper for $className already exists") - } else { - synchronized { - if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) - DriverManager.registerDriver(wrapper) - wrapperMap(className) = wrapper - logTrace(s"Wrapper for $className registered") - } - } - } - } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } - } - -} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala deleted file mode 100644 index 69bf13e1e5a6..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import java.io.IOException - -import org.apache.hadoop.fs.Path - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} - - -private[sql] class DefaultSource - extends RelationProvider - with SchemaRelationProvider - with CreatableRelationProvider { - - private def checkPath(parameters: Map[String, String]): String = { - parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) - } - - /** Returns a new base relation with the parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val path = checkPath(parameters) - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - - new JSONRelation(path, samplingRatio, None, sqlContext) - } - - /** Returns a new base relation with the given schema and parameters. 
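For context, the `path` and `samplingRatio` options handled by this (now removed) JSON relation are the ones exposed through the public reader API; a minimal sketch, with a hypothetical file path and an assumed `sqlContext` in scope:

```scala
// `sqlContext` is assumed to be available (e.g. in spark-shell).
val people = sqlContext.read
  .format("json")
  .option("samplingRatio", "0.5")  // sample half the input when inferring the schema
  .load("/tmp/people.json")        // hypothetical path

people.printSchema()
```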
*/ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType): BaseRelation = { - val path = checkPath(parameters) - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - - new JSONRelation(path, samplingRatio, Some(schema), sqlContext) - } - - override def createRelation( - sqlContext: SQLContext, - mode: SaveMode, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { - val path = checkPath(parameters) - val filesystemPath = new Path(path) - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val doSave = if (fs.exists(filesystemPath)) { - mode match { - case SaveMode.Append => - sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") - case SaveMode.Overwrite => { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } - true - } - case SaveMode.ErrorIfExists => - sys.error(s"path $path already exists.") - case SaveMode.Ignore => false - } - } else { - true - } - if (doSave) { - // Only save data when the save mode is not ignore. - data.toJSON.saveAsTextFile(path) - } - - createRelation(sqlContext, parameters, data.schema) - } -} - -private[sql] class JSONRelation( - // baseRDD is not immutable with respect to INSERT OVERWRITE - // and so it must be recreated at least as often as the - // underlying inputs are modified. To be safe, a function is - // used instead of a regular RDD value to ensure a fresh RDD is - // recreated for each and every operation. 
- baseRDD: () => RDD[String], - val path: Option[String], - val samplingRatio: Double, - userSpecifiedSchema: Option[StructType])( - @transient val sqlContext: SQLContext) - extends BaseRelation - with TableScan - with InsertableRelation - with CatalystScan { - - def this( - path: String, - samplingRatio: Double, - userSpecifiedSchema: Option[StructType], - sqlContext: SQLContext) = - this( - () => sqlContext.sparkContext.textFile(path), - Some(path), - samplingRatio, - userSpecifiedSchema)(sqlContext) - - private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI - - override val needConversion: Boolean = false - - override lazy val schema = userSpecifiedSchema.getOrElse { - if (useJacksonStreamingAPI) { - InferSchema( - baseRDD(), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord) - } else { - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema( - baseRDD(), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord)) - } - } - - override def buildScan(): RDD[Row] = { - if (useJacksonStreamingAPI) { - JacksonParser( - baseRDD(), - schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } else { - JsonRDD.jsonStringToRow( - baseRDD(), - schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } - } - - override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { - if (useJacksonStreamingAPI) { - JacksonParser( - baseRDD(), - StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } else { - JsonRDD.jsonStringToRow( - baseRDD(), - StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } - } - - override def insert(data: DataFrame, overwrite: Boolean): Unit = { - val filesystemPath = path match { - case Some(p) => new Path(p) - case None => - throw new IOException(s"Cannot INSERT into table with no path defined") - } - - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - - if (overwrite) { - if (fs.exists(filesystemPath)) { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } - } - // Write the data. - data.toJSON.saveAsTextFile(filesystemPath.toString) - // Right now, we assume that the schema is not changed. We will not update the schema. - // schema = data.schema - } else { - // TODO: Support INSERT INTO - sys.error("JSON table only support INSERT OVERWRITE for now.") - } - } - - override def hashCode(): Int = 41 * (41 + path.hashCode) + schema.hashCode() - - override def equals(other: Any): Boolean = other match { - case that: JSONRelation => - (this.path == that.path) && this.schema.sameType(that.schema) - case _ => false - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala deleted file mode 100644 index 44594c5080ff..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ /dev/null @@ -1,448 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.json - -import scala.collection.Map -import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} - -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.databind.ObjectMapper - -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -private[sql] object JsonRDD extends Logging { - - private[sql] def jsonStringToRow( - json: RDD[String], - schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { - parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) - } - - private[sql] def inferSchema( - json: RDD[String], - samplingRatio: Double = 1.0, - columnNameOfCorruptRecords: String): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = - if (schemaData.isEmpty()) { - Set.empty[(String, DataType)] - } else { - parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) - } - createSchema(allKeys) - } - - private def createSchema(allKeys: Set[(String, DataType)]): StructType = { - // Resolve type conflicts - val resolved = allKeys.groupBy { - case (key, dataType) => key - }.map { - // Now, keys and types are organized in the format of - // key -> Set(type1, type2, ...). - case (key, typeSet) => { - val fieldName = key.substring(1, key.length - 1).split("`.`").toSeq - val dataType = typeSet.map { - case (_, dataType) => dataType - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - (fieldName, dataType) - } - } - - def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { - val (topLevel, structLike) = values.partition(_.size == 1) - - val topLevelFields = topLevel.filter { - name => resolved.get(prefix ++ name).get match { - case ArrayType(elementType, _) => { - def hasInnerStruct(t: DataType): Boolean = t match { - case s: StructType => true - case ArrayType(t1, _) => hasInnerStruct(t1) - case o => false - } - - // Check if this array has inner struct. 
- !hasInnerStruct(elementType) - } - case struct: StructType => false - case _ => true - } - }.map { - a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) - } - val topLevelFieldNameSet = topLevelFields.map(_.name) - - val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter { - case (name, _) => !topLevelFieldNameSet.contains(name) - }.map { - case (name, fields) => { - val nestedFields = fields.map(_.tail) - val structType = makeStruct(nestedFields, prefix :+ name) - val dataType = resolved.get(prefix :+ name).get - dataType match { - case array: ArrayType => - // The pattern of this array is ArrayType(...(ArrayType(StructType))). - // Since the inner struct of array is a placeholder (StructType(Nil)), - // we need to replace this placeholder with the actual StructType (structType). - def getActualArrayType( - innerStruct: StructType, - currentArray: ArrayType): ArrayType = currentArray match { - case ArrayType(s: StructType, containsNull) => - ArrayType(innerStruct, containsNull) - case ArrayType(a: ArrayType, containsNull) => - ArrayType(getActualArrayType(innerStruct, a), containsNull) - } - Some(StructField(name, getActualArrayType(structType, array), nullable = true)) - case struct: StructType => Some(StructField(name, structType, nullable = true)) - // dataType is StringType means that we have resolved type conflicts involving - // primitive types and complex types. So, the type of name has been relaxed to - // StringType. Also, this field should have already been put in topLevelFields. - case StringType => None - } - } - }.flatMap(field => field).toSeq - - StructType((topLevelFields ++ structFields).sortBy(_.name)) - } - - makeStruct(resolved.keySet.toSeq, Nil) - } - - private[sql] def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable, _) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case ArrayType(struct: StructType, containsNull) => - ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - /** - * Returns the most general data type for two given data types. - */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { - case Some(commonType) => commonType - case None => - // t1 or t2 is a StructType, ArrayType, or an unexpected type. - (t1, t2) match { - case (other: DataType, NullType) => other - case (NullType, other: DataType) => other - case (StructType(fields1), StructType(fields2)) => { - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => { - val dataType = fieldTypes.map(field => field.dataType).reduce( - (type1: DataType, type2: DataType) => compatibleType(type1, type2)) - StructField(name, dataType, true) - } - } - StructType(newFields.toSeq.sortBy(_.name)) - } - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) - // TODO: We should use JsonObjectStringType to mark that values of field will be - // strings and every string is a Json object. 
- case (_, _) => StringType - } - } - } - - private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { - // For Integer values, use LongType by default. - val useLongType: PartialFunction[Any, DataType] = { - case value: IntegerType.InternalType => LongType - } - - useLongType orElse ScalaReflection.typeOfObject orElse { - // Since we do not have a data type backed by BigInteger, - // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType.Unlimited - // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType.Unlimited - // Unexpected data type. - case _ => StringType - } - } - - /** - * Returns the element type of an JSON array. We go through all elements of this array - * to detect any possible type conflict. We use [[compatibleType]] to resolve - * type conflicts. - */ - private def typeOfArray(l: Seq[Any]): ArrayType = { - val elements = l.flatMap(v => Option(v)) - if (elements.isEmpty) { - // If this JSON array is empty, we use NullType as a placeholder. - // If this array is not empty in other JSON objects, we can resolve - // the type after we have passed through all JSON objects. - ArrayType(NullType, containsNull = true) - } else { - val elementType = elements.map { - e => e match { - case map: Map[_, _] => StructType(Nil) - // We have an array of arrays. If those element arrays do not have the same - // element types, we will return ArrayType[StringType]. - case seq: Seq[_] => typeOfArray(seq) - case value => typeOfPrimitiveValue(value) - } - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - ArrayType(elementType, containsNull = true) - } - } - - /** - * Figures out all key names and data types of values from a parsed JSON object - * (in the format of Map[Stirng, Any]). When the value of a key is an JSON object, we - * only use a placeholder (StructType(Nil)) to mark that it should be a struct - * instead of getting all fields of this struct because a field does not appear - * in this JSON object can appear in other JSON objects. - */ - private def allKeysWithValueTypes(m: Map[String, Any]): Set[(String, DataType)] = { - val keyValuePairs = m.map { - // Quote the key with backticks to handle cases which have dots - // in the field name. - case (key, value) => (s"`$key`", value) - }.toSet - keyValuePairs.flatMap { - case (key: String, struct: Map[_, _]) => { - // The value associated with the key is an JSON object. - allKeysWithValueTypes(struct.asInstanceOf[Map[String, Any]]).map { - case (k, dataType) => (s"$key.$k", dataType) - } ++ Set((key, StructType(Nil))) - } - case (key: String, array: Seq[_]) => { - // The value associated with the key is an array. - // Handle inner structs of an array. - def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { - case ArrayType(e: StructType, _) => { - // The elements of this arrays are structs. 
- v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { - element => allKeysWithValueTypes(element) - }.map { - case (k, t) => (s"$key.$k", t) - } - } - case ArrayType(t1, _) => - v.asInstanceOf[Seq[Any]].flatMap(Option(_)).flatMap { - element => buildKeyPathForInnerStructs(element, t1) - } - case other => Nil - } - val elementType = typeOfArray(array) - buildKeyPathForInnerStructs(array, elementType) :+ (key, elementType) - } - // we couldn't tell what the type is if the value is null or empty string - case (key: String, value) if value == "" || value == null => (key, NullType) :: Nil - case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil - } - } - - /** - * Converts a Java Map/List to a Scala Map/Seq. - * We do not use Jackson's scala module at here because - * DefaultScalaModule in jackson-module-scala will make - * the parsing very slow. - */ - private def scalafy(obj: Any): Any = obj match { - case map: java.util.Map[_, _] => - // .map(identity) is used as a workaround of non-serializable Map - // generated by .mapValues. - // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 - JMapWrapper(map).mapValues(scalafy).map(identity) - case list: java.util.List[_] => - JListWrapper(list).map(scalafy) - case atom => atom - } - - private def parseJson( - json: RDD[String], - columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { - // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], - // ObjectMapper will not return BigDecimal when - // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled - // (see NumberDeserializer.deserialize for the logic). - // But, we do not want to enable this feature because it will use BigDecimal - // for every float number, which will be slow. - // So, right now, we will have Infinity for those BigDecimal number. - // TODO: Support BigDecimal. - json.mapPartitions(iter => { - // When there is a key appearing multiple times (a duplicate key), - // the ObjectMapper will take the last value associated with this duplicate key. - // For example: for {"key": 1, "key":2}, we will get "key"->2. - val mapper = new ObjectMapper() - iter.flatMap { record => - try { - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - case _ => - sys.error( - s"Failed to parse record $record. 
Please make sure that each line of the file " + - "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") - } - - parsed - } catch { - case e: JsonProcessingException => - Map(columnNameOfCorruptRecords -> UTF8String.fromString(record)) :: Nil - } - } - }) - } - - private def toLong(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong - case value: java.lang.Long => value.asInstanceOf[Long] - } - } - - private def toDouble(value: Any): Double = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toDouble - case value: java.lang.Long => value.asInstanceOf[Long].toDouble - case value: java.lang.Double => value.asInstanceOf[Double] - } - } - - private def toDecimal(value: Any): Decimal = { - value match { - case value: java.lang.Integer => Decimal(value) - case value: java.lang.Long => Decimal(value) - case value: java.math.BigInteger => Decimal(new java.math.BigDecimal(value)) - case value: java.lang.Double => Decimal(value) - case value: java.math.BigDecimal => Decimal(value) - } - } - - private def toJsonArrayString(seq: Seq[Any]): String = { - val builder = new StringBuilder - builder.append("[") - var count = 0 - seq.foreach { - element => - if (count > 0) builder.append(",") - count += 1 - builder.append(toString(element)) - } - builder.append("]") - - builder.toString() - } - - private def toJsonObjectString(map: Map[String, Any]): String = { - val builder = new StringBuilder - builder.append("{") - var count = 0 - map.foreach { - case (key, value) => - if (count > 0) builder.append(",") - count += 1 - val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value) - builder.append(s"""\"${key}\":${stringValue}""") - } - builder.append("}") - - builder.toString() - } - - private def toString(value: Any): String = { - value match { - case value: Map[_, _] => toJsonObjectString(value.asInstanceOf[Map[String, Any]]) - case value: Seq[_] => toJsonArrayString(value) - case value => Option(value).map(_.toString).orNull - } - } - - private def toDate(value: Any): Int = { - value match { - // only support string as date - case value: java.lang.String => - DateUtils.millisToDays(DateUtils.stringToTime(value).getTime) - case value: java.sql.Date => DateUtils.fromJavaDate(value) - } - } - - private def toTimestamp(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L - case value: java.lang.Long => value * 10000L - case value: java.lang.String => DateUtils.stringToTime(value).getTime * 10000L - } - } - - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = { - if (value == null) { - null - } else { - desiredType match { - case StringType => UTF8String.fromString(toString(value)) - case _ if value == null || value == "" => null // guard the non string type - case IntegerType => value.asInstanceOf[IntegerType.InternalType] - case LongType => toLong(value) - case DoubleType => toDouble(value) - case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.InternalType] - case NullType => null - case ArrayType(elementType, _) => - value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) - case MapType(StringType, valueType, _) => - val map = value.asInstanceOf[Map[String, Any]] - map.map { - case (k, v) => - (UTF8String.fromString(k), enforceCorrectType(v, valueType)) - }.map(identity) - case struct: StructType => asRow(value.asInstanceOf[Map[String, 
Any]], struct) - case DateType => toDate(value) - case TimestampType => toTimestamp(value) - } - } - } - - private def asRow(json: Map[String, Any], schema: StructType): InternalRow = { - // TODO: Reuse the row instead of creating a new one for every record. - val row = new GenericMutableRow(schema.fields.length) - schema.fields.zipWithIndex.foreach { - case (StructField(name, dataType, _, _), i) => - row.update(i, json.get(name).flatMap(v => Option(v)).map( - enforceCorrectType(_, dataType)).orNull) - } - - row - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 4e94fd07a877..a9c600b139b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -46,6 +46,6 @@ package object sql { * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. * @deprecated As of 1.3.0, replaced by `DataFrame`. */ - @deprecated("1.3.0", "use DataFrame") + @deprecated("use DataFrame", "1.3.0") type SchemaRDD = DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala deleted file mode 100644 index 4da5e96b82e3..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ /dev/null @@ -1,929 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.sql.Timestamp -import java.util.{TimeZone, Calendar} - -import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} - -import jodd.datetime.JDateTime -import org.apache.parquet.column.Dictionary -import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} -import org.apache.parquet.schema.MessageType - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.parquet.CatalystConverter.FieldType -import org.apache.spark.sql.types._ -import org.apache.spark.sql.parquet.timestamp.NanoTime -import org.apache.spark.unsafe.types.UTF8String - -/** - * Collection of converters of Parquet types (group and primitive types) that - * model arrays and maps. The conversions are partly based on the AvroParquet - * converters that are part of Parquet in order to be able to process these - * types. - * - * There are several types of converters: - *
- *   • [[org.apache.spark.sql.parquet.CatalystPrimitiveConverter]] for primitive
- *     (numeric, boolean and String) types
- *   • [[org.apache.spark.sql.parquet.CatalystNativeArrayConverter]] for arrays
- *     of native JVM element types; note: currently null values are not supported!
- *   • [[org.apache.spark.sql.parquet.CatalystArrayConverter]] for arrays of
- *     arbitrary element types (including nested element types); note: currently
- *     null values are not supported!
- *   • [[org.apache.spark.sql.parquet.CatalystStructConverter]] for structs
- *   • [[org.apache.spark.sql.parquet.CatalystMapConverter]] for maps; note:
- *     currently null values are not supported!
- *   • [[org.apache.spark.sql.parquet.CatalystPrimitiveRowConverter]] for rows
- *     of only primitive element types
- *   • [[org.apache.spark.sql.parquet.CatalystGroupConverter]] for other nested
- *     records, including the top-level row record
      - */ - -private[sql] object CatalystConverter { - // The type internally used for fields - type FieldType = StructField - - // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). - // Note that "array" for the array elements is chosen by ParquetAvro. - // Using a different value will result in Parquet silently dropping columns. - val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" - val ARRAY_ELEMENTS_SCHEMA_NAME = "array" - // SPARK-4520: Thrift generated parquet files have different array element - // schema names than avro. Thrift parquet uses array_schema_name + "_tuple" - // as opposed to "array" used by default. For more information, check - // TestThriftSchemaConverter.java in parquet.thrift. - val THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX = "_tuple" - val MAP_KEY_SCHEMA_NAME = "key" - val MAP_VALUE_SCHEMA_NAME = "value" - val MAP_SCHEMA_NAME = "map" - - // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType[T] = Seq[T] - type StructScalaType[T] = InternalRow - type MapScalaType[K, V] = Map[K, V] - - protected[parquet] def createConverter( - field: FieldType, - fieldIndex: Int, - parent: CatalystConverter): Converter = { - val fieldType: DataType = field.dataType - fieldType match { - case udt: UserDefinedType[_] => { - createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) - } - // For native JVM types we use a converter with native arrays - case ArrayType(elementType: AtomicType, false) => { - new CatalystNativeArrayConverter(elementType, fieldIndex, parent) - } - // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType, false) => { - new CatalystArrayConverter(elementType, fieldIndex, parent) - } - case ArrayType(elementType: DataType, true) => { - new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent) - } - case StructType(fields: Array[StructField]) => { - new CatalystStructConverter(fields, fieldIndex, parent) - } - case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { - new CatalystMapConverter( - Array( - new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), - fieldIndex, - parent) - } - // Strings, Shorts and Bytes do not have a corresponding type in Parquet - // so we need to treat them separately - case StringType => - new CatalystPrimitiveStringConverter(parent, fieldIndex) - case ShortType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.InternalType]) - } - } - case ByteType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.InternalType]) - } - } - case DateType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateDate(fieldIndex, value.asInstanceOf[DateType.InternalType]) - } - } - case d: DecimalType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateDecimal(fieldIndex, value, d) - } - } - case TimestampType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateTimestamp(fieldIndex, value) - } - } - // All other primitive types use the default converter - case ctype: DataType if 
ParquetTypesConverter.isPrimitiveType(ctype) => { - // note: need the type tag here! - new CatalystPrimitiveConverter(parent, fieldIndex) - } - case _ => throw new RuntimeException( - s"unable to convert datatype ${field.dataType.toString} in CatalystConverter") - } - } - - protected[parquet] def createRootConverter( - parquetSchema: MessageType, - attributes: Seq[Attribute]): CatalystConverter = { - // For non-nested types we use the optimized Row converter - if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) { - new CatalystPrimitiveRowConverter(attributes.toArray) - } else { - new CatalystGroupConverter(attributes.toArray) - } - } -} - -private[parquet] abstract class CatalystConverter extends GroupConverter { - /** - * The number of fields this group has - */ - protected[parquet] val size: Int - - /** - * The index of this converter in the parent - */ - protected[parquet] val index: Int - - /** - * The parent converter - */ - protected[parquet] val parent: CatalystConverter - - /** - * Called by child converters to update their value in its parent (this). - * Note that if possible the more specific update methods below should be used - * to avoid auto-boxing of native JVM types. - * - * @param fieldIndex - * @param value - */ - protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit - - protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.getBytes) - - protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - updateField(fieldIndex, UTF8String.fromBytes(value)) - - protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, readTimestamp(value)) - - protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = - updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) - - protected[parquet] def isRootConverter: Boolean = parent == null - - protected[parquet] def clearBuffer(): Unit - - /** - * Should only be called in the root (group) converter! - * - * @return - */ - def getCurrentRecord: InternalRow = throw new UnsupportedOperationException - - /** - * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in - * a long (i.e. precision <= 18) - * - * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. 
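As an illustrative aside (not part of this patch), the unscaled-value decoding documented above for the deleted readDecimal can be sketched in isolation: the Parquet Binary is read as a big-endian two's-complement integer and then sign-extended, assuming the decimal fits in a Long (precision <= 18). The object name and sample bytes below are made up for the example.

object DecimalDecodingSketch {
  // Accumulate big-endian bytes into a Long, then sign-extend from the top bit.
  def unscaledLong(bytes: Array[Byte]): Long = {
    require(bytes.length <= 8, "value must fit in a Long for this sketch")
    var unscaled = 0L
    var i = 0
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xFF)
      i += 1
    }
    val numBits = 8 * bytes.length
    (unscaled << (64 - numBits)) >> (64 - numBits)
  }

  def main(args: Array[String]): Unit = {
    // 0xFF 0x38 is -200 in two's complement; at scale 2 that is the decimal -2.00.
    println(unscaledLong(Array(0xFF.toByte, 0x38.toByte))) // prints -200
  }
}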
- */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { - val precision = ctype.precisionInfo.get.precision - val scale = ctype.precisionInfo.get.scale - val bytes = value.getBytes - require(bytes.length <= 16, "Decimal field too large to read") - var unscaled = 0L - var i = 0 - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xFF) - i += 1 - } - // Make sure unscaled has the right sign, by sign-extending the first bit - val numBits = 8 * bytes.length - unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) - dest.set(unscaled, precision, scale) - } - - /** - * Read a Timestamp value from a Parquet Int96Value - */ - protected[parquet] def readTimestamp(value: Binary): Long = { - DateUtils.fromJavaTimestamp(CatalystTimestampConverter.convertToTimestamp(value)) - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. - * - * @param schema The corresponding Catalyst schema in the form of a list of attributes. - */ -private[parquet] class CatalystGroupConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var current: ArrayBuffer[Any], - protected[parquet] var buffer: ArrayBuffer[InternalRow]) - extends CatalystConverter { - - def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = - this( - schema, - index, - parent, - current = null, - buffer = new ArrayBuffer[InternalRow]( - CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - /** - * This constructor is used for the root converter only! - */ - def this(attributes: Array[Attribute]) = - this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override def getCurrentRecord: InternalRow = { - assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") - // TODO: use iterators if possible - // Note: this will ever only be called in the root converter when the record has been - // fully processed. Therefore it will be difficult to use mutable rows instead, since - // any non-root converter never would be sure when it would be safe to re-use the buffer. - new GenericRow(current.toArray) - } - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current.update(fieldIndex, value) - } - - override protected[parquet] def clearBuffer(): Unit = buffer.clear() - - override def start(): Unit = { - current = ArrayBuffer.fill(size)(null) - converters.foreach { converter => - if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer() - } - } - } - - override def end(): Unit = { - if (!isRootConverter) { - assert(current != null) // there should be no empty groups - buffer.append(new GenericRow(current.toArray)) - parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) - } - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. 
Note that his - * converter is optimized for rows of primitive types (non-nested records). - */ -private[parquet] class CatalystPrimitiveRowConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] var current: MutableRow) - extends CatalystConverter { - - // This constructor is used for the root converter only - def this(attributes: Array[Attribute]) = - this( - attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), - new SpecificMutableRow(attributes.map(_.dataType))) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override val index = 0 - - override val parent = null - - // Should be only called in root group converter! - override def getCurrentRecord: InternalRow = current - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - throw new UnsupportedOperationException // child converters should use the - // specific update methods below - } - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - var i = 0 - while (i < size) { - current.setNullAt(i) - i = i + 1 - } - } - - override def end(): Unit = {} - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - current.setBoolean(fieldIndex, value) - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - current.setInt(fieldIndex, value) - - override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - current.setInt(fieldIndex, value) - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - current.setLong(fieldIndex, value) - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - current.setShort(fieldIndex, value) - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - current.setByte(fieldIndex, value) - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - current.setDouble(fieldIndex, value) - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - current.setFloat(fieldIndex, value) - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, value.getBytes) - - override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - current.update(fieldIndex, UTF8String.fromBytes(value)) - - override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - current.setLong(fieldIndex, readTimestamp(value)) - - override protected[parquet] def updateDecimal( - fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { - var decimal = current(fieldIndex).asInstanceOf[Decimal] - if (decimal == null) { - decimal = new Decimal - current(fieldIndex) = decimal - } - readDecimal(decimal, value, ctype) - } -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. 
- */ -private[parquet] class CatalystPrimitiveConverter( - parent: CatalystConverter, - fieldIndex: Int) extends PrimitiveConverter { - override def addBinary(value: Binary): Unit = - parent.updateBinary(fieldIndex, value) - - override def addBoolean(value: Boolean): Unit = - parent.updateBoolean(fieldIndex, value) - - override def addDouble(value: Double): Unit = - parent.updateDouble(fieldIndex, value) - - override def addFloat(value: Float): Unit = - parent.updateFloat(fieldIndex, value) - - override def addInt(value: Int): Unit = - parent.updateInt(fieldIndex, value) - - override def addLong(value: Long): Unit = - parent.updateLong(fieldIndex, value) -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. - * Supports dictionaries to reduce Binary to String conversion overhead. - * - * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) - extends CatalystPrimitiveConverter(parent, fieldIndex) { - - private[this] var dict: Array[Array[Byte]] = null - - override def hasDictionarySupport: Boolean = true - - override def setDictionary(dictionary: Dictionary): Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } - - override def addValueFromDictionary(dictionaryId: Int): Unit = - parent.updateString(fieldIndex, dict(dictionaryId)) - - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.getBytes) -} - -private[parquet] object CatalystArrayConverter { - val INITIAL_ARRAY_SIZE = 20 -} - -private[parquet] object CatalystTimestampConverter { - // TODO most part of this comes from Hive-0.14 - // Hive code might have some issues, so we need to keep an eye on it. - // Also we use NanoTime and Int96Values from parquet-examples. - // We utilize jodd to convert between NanoTime and Timestamp - val parquetTsCalendar = new ThreadLocal[Calendar] - def getCalendar: Calendar = { - // this is a cache for the calendar instance. 
- if (parquetTsCalendar.get == null) { - parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT"))) - } - parquetTsCalendar.get - } - val NANOS_PER_SECOND: Long = 1000000000 - val SECONDS_PER_MINUTE: Long = 60 - val MINUTES_PER_HOUR: Long = 60 - val NANOS_PER_MILLI: Long = 1000000 - - def convertToTimestamp(value: Binary): Timestamp = { - val nt = NanoTime.fromBinary(value) - val timeOfDayNanos = nt.getTimeOfDayNanos - val julianDay = nt.getJulianDay - val jDateTime = new JDateTime(julianDay.toDouble) - val calendar = getCalendar - calendar.set(Calendar.YEAR, jDateTime.getYear) - calendar.set(Calendar.MONTH, jDateTime.getMonth - 1) - calendar.set(Calendar.DAY_OF_MONTH, jDateTime.getDay) - - // written in command style - var remainder = timeOfDayNanos - calendar.set( - Calendar.HOUR_OF_DAY, - (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)).toInt) - remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR) - calendar.set( - Calendar.MINUTE, (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE)).toInt) - remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE) - calendar.set(Calendar.SECOND, (remainder / NANOS_PER_SECOND).toInt) - val nanos = remainder % NANOS_PER_SECOND - val ts = new Timestamp(calendar.getTimeInMillis) - ts.setNanos(nanos.toInt) - ts - } - - def convertFromTimestamp(ts: Timestamp): Binary = { - val calendar = getCalendar - calendar.setTime(ts) - val jDateTime = new JDateTime(calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH)) - // Hive-0.14 didn't set hour before get day number, while the day number should - // has something to do with hour, since julian day number grows at 12h GMT - // here we just follow what hive does. - val julianDay = jDateTime.getJulianDayNumber - - val hour = calendar.get(Calendar.HOUR_OF_DAY) - val minute = calendar.get(Calendar.MINUTE) - val second = calendar.get(Calendar.SECOND) - val nanos = ts.getNanos - // Hive-0.14 would use hours directly, that might be wrong, since the day starts - // from 12h in Julian. here we just follow what hive does. - val nanosOfDay = nanos + second * NANOS_PER_SECOND + - minute * NANOS_PER_SECOND * SECONDS_PER_MINUTE + - hour * NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR - NanoTime(julianDay, nanosOfDay).toBinary - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. 
- * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - // fieldIndex is ignored (assumed to be zero but not checked) - if (value == null) { - throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") - } - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = { - if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer() - } - } - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. 
- * - * @param elementType The type of the array elements (native) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param capacity The (initial) capacity of the buffer - */ -private[parquet] class CatalystNativeArrayConverter( - val elementType: AtomicType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) - extends CatalystConverter { - - type NativeType = elementType.InternalType - - private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) - - private var elements: Int = 0 - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = { - checkGrowBuffer() - buffer(elements) = value.getBytes.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { - checkGrowBuffer() - buffer(elements) = UTF8String.fromBytes(value).asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def clearBuffer(): Unit = { - elements = 0 - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField( - index, - buffer.slice(0, elements).toSeq) - clearBuffer() - } - - private def checkGrowBuffer(): Unit = { - if (elements >= capacity) { - val newCapacity = 2 * capacity - val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity) - Array.copy(buffer, 0, tmp, 0, capacity) - buffer = tmp - capacity = newCapacity - } - } -} - -/** - * A 
`parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array contains null (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayContainsNullConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = new CatalystConverter { - - private var current: Any = null - - val converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - override def end(): Unit = parent.updateField(index, current) - - override def start(): Unit = { - current = null - } - - override protected[parquet] val size: Int = 1 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent = CatalystArrayContainsNullConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current = value - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * This converter is for multi-element groups of primitive or complex types - * that have repetition level optional or required (so struct fields). - * - * @param schema The corresponding Catalyst schema in the form of a list of - * attributes. - * @param index - * @param parent - */ -private[parquet] class CatalystStructConverter( - override protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystGroupConverter(schema, index, parent) { - - override protected[parquet] def clearBuffer(): Unit = {} - - // TODO: think about reusing the buffer - override def end(): Unit = { - assert(!isRootConverter) - // here we need to make sure to use StructScalaType - // Note: we need to actually make a copy of the array since we - // may be in a nested field - parent.updateField(index, new GenericRow(current.toArray)) - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts two-element groups that - * match the characteristics of a map (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.MapType]]. 
- * - * @param schema - * @param index - * @param parent - */ -private[parquet] class CatalystMapConverter( - protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystConverter { - - private val map = new HashMap[Any, Any]() - - private val keyValueConverter = new CatalystConverter { - private var currentKey: Any = null - private var currentValue: Any = null - val keyConverter = CatalystConverter.createConverter(schema(0), 0, this) - val valueConverter = CatalystConverter.createConverter(schema(1), 1, this) - - override def getConverter(fieldIndex: Int): Converter = { - if (fieldIndex == 0) keyConverter else valueConverter - } - - override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue - - override def start(): Unit = { - currentKey = null - currentValue = null - } - - override protected[parquet] val size: Int = 2 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - fieldIndex match { - case 0 => - currentKey = value - case 1 => - currentValue = value - case _ => - new RuntimePermission(s"trying to update Map with fieldIndex $fieldIndex") - } - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override protected[parquet] val size: Int = 1 - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - map.clear() - } - - override def end(): Unit = { - // here we need to make sure to use MapScalaType - parent.updateField(index, map.toMap) - } - - override def getConverter(fieldIndex: Int): Converter = keyValueConverter - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala deleted file mode 100644 index 704cf56f3826..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.io.IOException -import java.util.logging.{Level, Logger => JLogger} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.permission.FsAction -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat, ParquetRecordReader} -import org.apache.parquet.schema.MessageType -import org.apache.parquet.{Log => ParquetLog} - -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLContext} - -/** - * Relation that consists of data stored in a Parquet columnar format. - * - * Users should interact with parquet files though a [[DataFrame]], created by a [[SQLContext]] - * instead of using this class directly. - * - * {{{ - * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") - * }}} - * - * @param path The path to the Parquet file. - */ -private[sql] case class ParquetRelation( - path: String, - @transient conf: Option[Configuration], - @transient sqlContext: SQLContext, - partitioningAttributes: Seq[Attribute] = Nil) - extends LeafNode with MultiInstanceRelation { - - self: Product => - - /** Schema derived from ParquetFile */ - def parquetSchema: MessageType = - ParquetTypesConverter - .readMetaData(new Path(path), conf) - .getFileMetaData - .getSchema - - /** Attributes */ - override val output = - partitioningAttributes ++ - ParquetTypesConverter.readSchemaFromFile( - new Path(path.split(",").head), - conf, - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp) - lazy val attributeMap = AttributeMap(output.map(o => o -> o)) - - override def newInstance(): this.type = { - ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] - } - - // Equals must also take into account the output attributes so that we can distinguish between - // different instances of the same relation, - override def equals(other: Any): Boolean = other match { - case p: ParquetRelation => - p.path == path && p.output == output - case _ => false - } - - override def hashCode: Int = { - com.google.common.base.Objects.hashCode(path, output) - } - - // TODO: Use data from the footers. - override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) -} - -private[sql] object ParquetRelation { - - def enableLogForwarding() { - // Note: the org.apache.parquet.Log class has a static initializer that - // sets the java.util.logging Logger for "org.apache.parquet". This - // checks first to see if there's any handlers already set - // and if not it creates them. If this method executes prior - // to that class being loaded then: - // 1) there's no handlers installed so there's none to - // remove. But when it IS finally loaded the desired affect - // of removing them is circumvented. - // 2) The parquet.Log static initializer calls setUseParentHandlers(false) - // undoing the attempt to override the logging here. - // - // Therefore we need to force the class to be loaded. - // This should really be resolved by Parquet. 
- Class.forName(classOf[ParquetLog].getName) - - // Note: Logger.getLogger("parquet") has a default logger - // that appends to Console which needs to be cleared. - val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName) - parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) - parquetLogger.setUseParentHandlers(true) - - // Disables a WARN log message in ParquetOutputCommitter. We first ensure that - // ParquetOutputCommitter is loaded and the static LOG field gets initialized. - // See https://issues.apache.org/jira/browse/SPARK-5968 for details - Class.forName(classOf[ParquetOutputCommitter].getName) - JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) - - // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. - // See https://issues.apache.org/jira/browse/PARQUET-220 for details - Class.forName(classOf[ParquetRecordReader[_]].getName) - JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) - } - - // The element type for the RDDs that this relation maps to. - type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow - - // The compression type - type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName - - // The parquet compression short names - val shortParquetCompressionCodecNames = Map( - "NONE" -> CompressionCodecName.UNCOMPRESSED, - "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, - "SNAPPY" -> CompressionCodecName.SNAPPY, - "GZIP" -> CompressionCodecName.GZIP, - "LZO" -> CompressionCodecName.LZO) - - /** - * Creates a new ParquetRelation and underlying Parquetfile for the given LogicalPlan. Note that - * this is used inside [[org.apache.spark.sql.execution.SparkStrategies SparkStrategies]] to - * create a resolved relation as a data sink for writing to a Parquetfile. The relation is empty - * but is initialized with ParquetMetadata and can be inserted into. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param child The child node that will be used for extracting the schema. - * @param conf A configuration to be used. - * @return An empty ParquetRelation with inferred metadata. - */ - def create(pathString: String, - child: LogicalPlan, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - if (!child.resolved) { - throw new UnresolvedException[LogicalPlan]( - child, - "Attempt to create Parquet table from unresolved child (when schema is not available)") - } - createEmpty(pathString, child.output, false, conf, sqlContext) - } - - /** - * Creates an empty ParquetRelation and underlying Parquetfile that only - * consists of the Metadata for the given schema. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param attributes The schema of the relation. - * @param conf A configuration to be used. - * @return An empty ParquetRelation. - */ - def createEmpty(pathString: String, - attributes: Seq[Attribute], - allowExisting: Boolean, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - val path = checkPath(pathString, allowExisting, conf) - conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED) - .name()) - ParquetRelation.enableLogForwarding() - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. 
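As an illustrative aside (not part of this patch), the "always nullable" hack applied just below can be sketched by hand. The real code calls the Spark-internal asNullable; the standalone helper here (a hypothetical name) only relaxes top-level fields, which is enough to show the idea, and assumes spark-sql 1.x on the classpath.

import org.apache.spark.sql.types._

object NullableSchemaSketch {
  // Hypothetical helper: copy every top-level field with nullable = true.
  def relaxTopLevel(schema: StructType): StructType =
    StructType(schema.fields.map(_.copy(nullable = true)))

  def main(args: Array[String]): Unit = {
    val strict = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType, nullable = true)))
    println(relaxTopLevel(strict).fields.forall(_.nullable)) // true
  }
}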
- val schema = StructType.fromAttributes(attributes).asNullable - val newAttributes = schema.toAttributes - ParquetTypesConverter.writeMetaData(newAttributes, path, conf) - new ParquetRelation(path.toString, Some(conf), sqlContext) { - override val output = newAttributes - } - } - - private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = { - if (pathStr == null) { - throw new IllegalArgumentException("Unable to create ParquetRelation: path is null") - } - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to create ParquetRelation: incorrectly formatted path $pathStr") - } - val path = origPath.makeQualified(fs) - if (!allowExisting && fs.exists(path)) { - sys.error(s"File $pathStr already exists.") - } - - if (fs.exists(path) && - !fs.getFileStatus(path) - .getPermission - .getUserAction - .implies(FsAction.READ_WRITE)) { - throw new IOException( - s"Unable to create ParquetRelation: path $path not read-writable") - } - path - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala deleted file mode 100644 index b30fc171c0af..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ /dev/null @@ -1,508 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.io.IOException -import java.lang.{Long => JLong} -import java.text.{NumberFormat, SimpleDateFormat} -import java.util.concurrent.{Callable, TimeUnit} -import java.util.{Date, List => JList} - -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.Try - -import com.google.common.cache.CacheBuilder -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} -import org.apache.parquet.hadoop.metadata.GlobalMetaData -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.io.ParquetDecodingException -import org.apache.parquet.schema.MessageType - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _} -import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.util.SerializableConfiguration - -/** - * :: DeveloperApi :: - * Parquet table scan operator. Imports the file that backs the given - * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[InternalRow]``. - */ -private[sql] case class ParquetTableScan( - attributes: Seq[Attribute], - relation: ParquetRelation, - columnPruningPred: Seq[Expression]) - extends LeafNode { - - // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes - // by exprId. note: output cannot be transient, see - // https://issues.apache.org/jira/browse/SPARK-1367 - val output = attributes.map(relation.attributeMap) - - // A mapping of ordinals partitionRow -> finalOutput. 
- val requestedPartitionOrdinals = { - val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex) - - attributes.zipWithIndex.flatMap { - case (attribute, finalOrdinal) => - partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal) - } - }.toArray - - protected override def doExecute(): RDD[InternalRow] = { - import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat - - val sc = sqlContext.sparkContext - val job = new Job(sc.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - - val conf: Configuration = ContextUtil.getConfiguration(job) - - relation.path.split(",").foreach { curPath => - val qualifiedPath = { - val path = new Path(curPath) - path.getFileSystem(conf).makeQualified(path) - } - NewFileInputFormat.addInputPath(job, qualifiedPath) - } - - // Store both requested and original schema in `Configuration` - conf.set( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(output)) - conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(relation.output)) - - // Store record filtering predicate in `Configuration` - // Note 1: the input format ignores all predicates that cannot be expressed - // as simple column predicate filters in Parquet. Here we just record - // the whole pruning predicate. - ParquetFilters - .createRecordFilter(columnPruningPred) - .map(_.asInstanceOf[FilterPredicateCompat].getFilterPredicate) - // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering - .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.setBoolean( - SQLConf.PARQUET_CACHE_METADATA.key, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, true)) - - // Use task side metadata in parquet - conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - - val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( - sc, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[InternalRow], - conf) - - if (requestedPartitionOrdinals.nonEmpty) { - // This check is based on CatalystConverter.createRootConverter. - val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) - - // Uses temporary variable to avoid the whole `ParquetTableScan` object being captured into - // the `mapPartitionsWithInputSplit` closure below. - val outputSize = output.size - - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - // Convert the partitioning attributes into the correct types - val partitionRowValues = - relation.partitioningAttributes - .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) - - if (primitiveRow) { - new Iterator[InternalRow] { - def hasNext: Boolean = iter.hasNext - def next(): InternalRow = { - // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. - val row = iter.next()._2.asInstanceOf[SpecificMutableRow] - - // Parquet will leave partitioning columns empty, so we fill them in here. 
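As an illustrative aside (not part of this patch), the partition values being spliced into the row at this point come from the partValue regex applied to the split path earlier in doExecute; a standalone sketch, with a made-up path:

object PartitionPathSketch {
  private val partValue = "([^=]+)=([^=]+)".r

  def main(args: Array[String]): Unit = {
    // Hypothetical split path; only the key=value segments contribute partition columns.
    val path = "hdfs://nn/warehouse/events/year=2015/month=7/part-r-00001.parquet"
    val partValues = path.split("/").flatMap {
      case partValue(key, value) => Some(key -> value)
      case _ => None
    }.toMap
    println(partValues) // Map(year -> 2015, month -> 7)
  }
}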
- var i = 0 - while (i < requestedPartitionOrdinals.size) { - row(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 - } - row - } - } - } else { - // Create a mutable row since we need to fill in values from partition columns. - val mutableRow = new GenericMutableRow(outputSize) - new Iterator[InternalRow] { - def hasNext: Boolean = iter.hasNext - def next(): InternalRow = { - // We are using CatalystGroupConverter and it returns a GenericRow. - // Since GenericRow is not mutable, we just cast it to a Row. - val row = iter.next()._2.asInstanceOf[InternalRow] - - var i = 0 - while (i < row.size) { - mutableRow(i) = row(i) - i += 1 - } - // Parquet will leave partitioning columns empty, so we fill them in here. - i = 0 - while (i < requestedPartitionOrdinals.size) { - mutableRow(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 - } - mutableRow - } - } - } - } - } else { - baseRDD.map(_._2) - } - } - - /** - * Applies a (candidate) projection. - * - * @param prunedAttributes The list of attributes to be used in the projection. - * @return Pruned TableScan. - */ - def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { - val success = validateProjection(prunedAttributes) - if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred) - } else { - sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") - } - } - - /** - * Evaluates a candidate projection by checking whether the candidate is a subtype - * of the original type. - * - * @param projection The candidate projection. - * @return True if the projection is valid, false otherwise. - */ - private def validateProjection(projection: Seq[Attribute]): Boolean = { - val original: MessageType = relation.parquetSchema - val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) - Try(original.checkContains(candidate)).isSuccess - } -} - -/** - * :: DeveloperApi :: - * Operator that acts as a sink for queries on RDDs and can be used to - * store the output inside a directory of Parquet files. This operator - * is similar to Hive's INSERT INTO TABLE operation in the sense that - * one can choose to either overwrite or append to a directory. Note - * that consecutive insertions to the same table must have compatible - * (source) schemas. - * - * WARNING: EXPERIMENTAL! InsertIntoParquetTable with overwrite=false may - * cause data corruption in the case that multiple users try to append to - * the same table simultaneously. Inserting into a table that was - * previously generated by other means (e.g., by creating an HDFS - * directory and importing Parquet files generated by other tools) may - * cause unpredicted behaviour and therefore results in a RuntimeException - * (only detected via filename pattern so will not catch all cases). - */ -@DeveloperApi -private[sql] case class InsertIntoParquetTable( - relation: ParquetRelation, - child: SparkPlan, - overwrite: Boolean = false) - extends UnaryNode with SparkHadoopMapReduceUtil { - - /** - * Inserts all rows into the Parquet file. - */ - protected override def doExecute(): RDD[InternalRow] = { - // TODO: currently we do not check whether the "schema"s are compatible - // That means if one first creates a table and then INSERTs data with - // and incompatible schema the execution will fail. It would be nice - // to catch this early one, maybe having the planner validate the schema - // before calling execute(). 
- - val childRdd = child.execute() - assert(childRdd != null) - - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - - val writeSupport = - if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] - } else { - classOf[org.apache.spark.sql.parquet.RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - - val conf = ContextUtil.getConfiguration(job) - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. - val schema = StructType.fromAttributes(relation.output).asNullable - RowWriteSupport.setSchema(schema.toAttributes, conf) - - val fspath = new Path(relation.path) - val fs = fspath.getFileSystem(conf) - - if (overwrite) { - try { - fs.delete(fspath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${fspath.toString} prior" - + s" to InsertIntoParquetTable:\n${e.toString}") - } - } - saveAsHadoopFile(childRdd, relation.path.toString, conf) - - // We return the child RDD to allow chaining (alternatively, one could return nothing). - childRdd - } - - override def output: Seq[Attribute] = child.output - - /** - * Stores the given Row RDD as a Hadoop file. - * - * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]] - * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses - * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing - * directory and need to determine which was the largest written file index before starting to - * write. - * - * @param rdd The [[org.apache.spark.rdd.RDD]] to writer - * @param path The directory to write to. - * @param conf A [[org.apache.hadoop.conf.Configuration]]. 
- */ - private def saveAsHadoopFile( - rdd: RDD[InternalRow], - path: String, - conf: Configuration) { - val job = new Job(conf) - val keyType = classOf[Void] - job.setOutputKeyClass(keyType) - job.setOutputValueClass(classOf[InternalRow]) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableConfiguration(job.getConfiguration) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = sqlContext.sparkContext.newRddId() - - val taskIdOffset = - if (overwrite) { - 1 - } else { - FileSystemHelper - .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 - } - - def writeShard(context: TaskContext, iter: Iterator[InternalRow]): Int = { - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - context.attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = new AppendingParquetOutputFormat(taskIdOffset) - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext) - try { - while (iter.hasNext) { - val row = iter.next() - writer.write(null, row) - } - } finally { - writer.close(hadoopContext) - } - SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) - 1 - } - val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - sqlContext.sparkContext.runJob(rdd, writeShard _) - jobCommitter.commitJob(jobTaskContext) - } -} - -/** - * TODO: this will be able to append to directories it created itself, not necessarily - * to imported ones. - */ -private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends org.apache.parquet.hadoop.ParquetOutputFormat[InternalRow] { - // override to accept existing directories as valid output directory - override def checkOutputSpecs(job: JobContext): Unit = {} - var committer: OutputCommitter = null - - // override to choose output filename so not overwrite existing ones - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val taskId: TaskID = getTaskAttemptID(context).getTaskID - val partition: Int = taskId.getId - val filename = "part-r-" + numfmt.format(partition + offset) + ".parquet" - val committer: FileOutputCommitter = - getOutputCommitter(context).asInstanceOf[FileOutputCommitter] - new Path(committer.getWorkPath, filename) - } - - // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2. - // The signatures of the method TaskAttemptContext.getTaskAttemptID for the both versions - // are the same, so the method calls are source-compatible but NOT binary-compatible because - // the opcode of method call for class is INVOKEVIRTUAL and for interface is INVOKEINTERFACE. 
- private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { - context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] - } - - // override to create output committer from configuration - override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - if (committer == null) { - val output = getOutputPath(context) - val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", - classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) - val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] - } - committer - } - - // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 - private def getOutputPath(context: TaskAttemptContext): Path = { - context.getConfiguration().get("mapred.output.dir") match { - case null => null - case name => new Path(name) - } - } -} - -/** - * We extend ParquetInputFormat in order to have more control over which - * RecordFilter we want to use. - */ -private[parquet] class FilteringParquetRowInputFormat - extends org.apache.parquet.hadoop.ParquetInputFormat[InternalRow] with Logging { - - private var fileStatuses = Map.empty[Path, FileStatus] - - override def createRecordReader( - inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { - - import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter - - val readSupport: ReadSupport[InternalRow] = new RowReadSupport() - - val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) - if (!filter.isInstanceOf[NoOpFilter]) { - new ParquetRecordReader[InternalRow]( - readSupport, - filter) - } else { - new ParquetRecordReader[InternalRow](readSupport) - } - } - -} - -private[parquet] object FilteringParquetRowInputFormat { - private val footerCache = CacheBuilder.newBuilder() - .maximumSize(20000) - .build[FileStatus, Footer]() - - private val blockLocationCache = CacheBuilder.newBuilder() - .maximumSize(20000) - .expireAfterWrite(15, TimeUnit.MINUTES) // Expire locations since HDFS files might move - .build[FileStatus, Array[BlockLocation]]() -} - -private[parquet] object FileSystemHelper { - def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"ParquetTableOperations: Path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (!fs.exists(path) || !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"ParquetTableOperations: path $path does not exist or is not a directory") - } - fs.globStatus(path) - .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } - .map(_.getPath) - } - - /** - * Finds the maximum taskid in the output file names at the given path. 
- */ - def findMaxTaskId(pathStr: String, conf: Configuration): Int = { - val files = FileSystemHelper.listFiles(pathStr, conf) - // filename pattern is part-r-.parquet - val nameP = new scala.util.matching.Regex("""part-.-(\d{1,}).*""", "taskid") - val hiddenFileP = new scala.util.matching.Regex("_.*") - files.map(_.getName).map { - case nameP(taskid) => taskid.toInt - case hiddenFileP() => 0 - case other: String => - sys.error("ERROR: attempting to append to set of Parquet files and found file" + - s"that does not match name pattern: $other") - case _ => 0 - }.reduceOption(_ max _).getOrElse(0) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala deleted file mode 100644 index ba2a35b74ef8..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ /dev/null @@ -1,543 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.IOException - -import scala.collection.JavaConversions._ -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.apache.parquet.format.converter.ParquetMetadataConverter -import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import org.apache.parquet.schema.Type.Repetition -import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} - -import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.types._ - - -/** A class representing Parquet info fields we care about, for passing back to Parquet */ -private[parquet] case class ParquetTypeInfo( - primitiveType: ParquetPrimitiveTypeName, - originalType: Option[ParquetOriginalType] = None, - decimalMetadata: Option[DecimalMetadata] = None, - length: Option[Int] = None) - -private[parquet] object ParquetTypesConverter extends Logging { - def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | StringType | BinaryType => true - case _: DataType => false - } - - def toPrimitiveDataType( - parquetType: ParquetPrimitiveType, - binaryAsString: Boolean, - int96AsTimestamp: 
Boolean): DataType = { - val originalType = parquetType.getOriginalType - val decimalInfo = parquetType.getDecimalMetadata - parquetType.getPrimitiveTypeName match { - case ParquetPrimitiveTypeName.BINARY - if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType - case ParquetPrimitiveTypeName.BINARY => BinaryType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 - if originalType == ParquetOriginalType.DATE => DateType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - throw new AnalysisException("Potential loss of precision: cannot convert INT96") - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => - // TODO: for now, our reader only supports decimals that fit in a Long - DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType") - } - } - - /** - * Converts a given Parquet `Type` into the corresponding - * [[org.apache.spark.sql.types.DataType]]. - * - * We apply the following conversion rules: - *
- *   • Primitive types are converted to the corresponding primitive type.
- *   • Group types that have a single field that is itself a group, which has repetition
- *     level `REPEATED`, are treated as follows:
- *       • If the nested group has name `values`, the surrounding group is converted
- *         into an [[ArrayType]] with the corresponding field type (primitive or
- *         complex) as element type.
- *       • If the nested group has name `map` and two fields (named `key` and `value`),
- *         the surrounding group is converted into a [[MapType]]
- *         with the corresponding key and value (value possibly complex) types.
- *         Note that we currently assume map values are not nullable.
- *       • Other group types are converted into a [[StructType]] with the corresponding
- *         field types.
      - * Note that fields are determined to be `nullable` if and only if their Parquet repetition - * level is not `REQUIRED`. - * - * @param parquetType The type to convert. - * @return The corresponding Catalyst type. - */ - def toDataType(parquetType: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): DataType = { - def correspondsToMap(groupType: ParquetGroupType): Boolean = { - if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { - false - } else { - // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - keyValueGroup.getRepetition == Repetition.REPEATED && - keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && - keyValueGroup.getFieldCount == 2 && - keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && - keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME - } - } - - def correspondsToArray(groupType: ParquetGroupType): Boolean = { - groupType.getFieldCount == 1 && - groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && - groupType.getFields.apply(0).getRepetition == Repetition.REPEATED - } - - if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) - } else { - val groupType = parquetType.asGroupType() - parquetType.getOriginalType match { - // if the schema was constructed programmatically there may be hints how to convert - // it inside the metadata via the OriginalType field - case ParquetOriginalType.LIST => { // TODO: check enums! - assert(groupType.getFieldCount == 1) - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } - case ParquetOriginalType.MAP => { - assert( - !groupType.getFields.apply(0).isPrimitive, - "Parquet Map type malformatted: expected nested group for map!") - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert( - keyValueGroup.getFieldCount == 2, - "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } - case _ => { - // Note: the order of these checks is important! 
- if (correspondsToMap(groupType)) { // MapType - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } else if (correspondsToArray(groupType)) { // ArrayType - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } else { // everything else: StructType - val fields = groupType - .getFields - .map(ptype => new StructField( - ptype.getName, - toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), - ptype.getRepetition != Repetition.REQUIRED)) - StructType(fields) - } - } - } - } - } - - /** - * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return - * the name of the corresponding Parquet primitive type or None if the given type - * is not primitive. - * - * @param ctype The type to convert - * @return The name of the corresponding Parquet type properties - */ - def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { - case StringType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) - case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) - case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) - case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) - case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) - case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case DateType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) - case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) - case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) - case DecimalType.Fixed(precision, scale) if precision <= 18 => - // TODO: for now, our writer only supports decimals that fit in a Long - Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, - Some(ParquetOriginalType.DECIMAL), - Some(new DecimalMetadata(precision, scale)), - Some(BYTES_FOR_PRECISION(precision)))) - case _ => None - } - - /** - * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. - */ - private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => - var length = 1 - while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { - length += 1 - } - length - } - - /** - * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into - * the corresponding Parquet `Type`. - * - * The conversion follows the rules below: - *
- *   • Primitive types are converted into Parquet's primitive types.
- *   • [[org.apache.spark.sql.types.StructType]]s are converted
- *     into Parquet's `GroupType` with the corresponding field types.
- *   • [[org.apache.spark.sql.types.ArrayType]]s are converted
- *     into a 2-level nested group, where the outer group has the inner
- *     group as sole field. The inner group has name `values` and
- *     repetition level `REPEATED` and has the element type of
- *     the array as schema. We use Parquet's `ConversionPatterns` for this
- *     purpose.
- *   • [[org.apache.spark.sql.types.MapType]]s are converted
- *     into a nested (2-level) Parquet `GroupType` with two fields: a key
- *     type and a value type. The nested group has repetition level
- *     `REPEATED` and name `map`. We use Parquet's `ConversionPatterns`
- *     for this purpose.
- *
- * Parquet's repetition level is generally set according to the following rule:
- *   • If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or
- *     `MapType`, then the repetition level is set to `REPEATED`.
- *   • Otherwise, if the attribute whose type is converted is `nullable`, the Parquet
- *     type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
      - * - *@param ctype The type to convert - * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] - * whose type is converted - * @param nullable When true indicates that the attribute is nullable - * @param inArray When true indicates that this is a nested attribute inside an array. - * @return The corresponding Parquet type. - */ - def fromDataType( - ctype: DataType, - name: String, - nullable: Boolean = true, - inArray: Boolean = false, - toThriftSchemaNames: Boolean = false): ParquetType = { - val repetition = - if (inArray) { - Repetition.REPEATED - } else { - if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED - } - val arraySchemaName = if (toThriftSchemaNames) { - name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX - } else { - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - } - val typeInfo = fromPrimitiveDataType(ctype) - typeInfo.map { - case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => - val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) - for (len <- length) { - builder.length(len) - } - for (metadata <- decimalMetadata) { - builder.precision(metadata.getPrecision).scale(metadata.getScale) - } - builder.named(name) - }.getOrElse { - ctype match { - case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames) - } - case ArrayType(elementType, false) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = false, - inArray = true, - toThriftSchemaNames) - ConversionPatterns.listType(repetition, name, parquetElementType) - } - case ArrayType(elementType, true) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = true, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.listType( - repetition, - name, - new ParquetGroupType( - Repetition.REPEATED, - CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, - parquetElementType)) - } - case StructType(structFields) => { - val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, - inArray = false, toThriftSchemaNames) - } - new ParquetGroupType(repetition, name, fields.toSeq) - } - case MapType(keyType, valueType, valueContainsNull) => { - val parquetKeyType = - fromDataType( - keyType, - CatalystConverter.MAP_KEY_SCHEMA_NAME, - nullable = false, - inArray = false, - toThriftSchemaNames) - val parquetValueType = - fromDataType( - valueType, - CatalystConverter.MAP_VALUE_SCHEMA_NAME, - nullable = valueContainsNull, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.mapType( - repetition, - name, - parquetKeyType, - parquetValueType) - } - case _ => throw new AnalysisException(s"Unsupported datatype $ctype") - } - } - } - - def convertToAttributes(parquetSchema: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - parquetSchema - .asGroupType() - .getFields - .map( - field => - new AttributeReference( - field.getName, - toDataType(field, isBinaryAsString, isInt96AsTimestamp), - field.getRepetition != Repetition.REQUIRED)()) - } - - def convertFromAttributes(attributes: Seq[Attribute], - toThriftSchemaNames: Boolean = false): MessageType = { - checkSpecialCharacters(attributes) - val fields = attributes.map( - attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable, - toThriftSchemaNames = toThriftSchemaNames)) - new MessageType("root", fields) - } - - def 
convertFromString(string: String): Seq[Attribute] = { - Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { - case s: StructType => s.toAttributes - case other => throw new AnalysisException(s"Can convert $string to row") - } - } - - private def checkSpecialCharacters(schema: Seq[Attribute]) = { - // ,;{}()\n\t= and space character are special characters in Parquet schema - schema.map(_.name).foreach { name => - if (name.matches(".*[ ,;{}()\n\t=].*")) { - throw new AnalysisException( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". - |Please use alias to rename it. - """.stripMargin.split("\n").mkString(" ")) - } - } - } - - def convertToString(schema: Seq[Attribute]): String = { - checkSpecialCharacters(schema) - StructType.fromAttributes(schema).json - } - - def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") - } - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (fs.exists(path) && !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException(s"Expected to write to directory $path but found file") - } - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fs.exists(metadataPath)) { - try { - fs.delete(metadataPath, true) - } catch { - case e: IOException => - throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") - } - } - val extraMetadata = new java.util.HashMap[String, String]() - extraMetadata.put( - RowReadSupport.SPARK_METADATA_KEY, - ParquetTypesConverter.convertToString(attributes)) - // TODO: add extra data, e.g., table name, date, etc.? - - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) - val metaData: FileMetaData = new FileMetaData( - parquetSchema, - extraMetadata, - "Spark") - - ParquetRelation.enableLogForwarding() - ParquetFileWriter.writeMetadataFile( - conf, - path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) - } - - /** - * Try to read Parquet metadata at the given Path. We first see if there is a summary file - * in the parent directory. If so, this is used. Else we read the actual footer at the given - * location. - * @param origPath The path at which we expect one (or more) Parquet files. - * @param configuration The Hadoop configuration to use. - * @return The `ParquetMetadata` containing among other things the schema. - */ - def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") - } - val job = new Job() - val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) - val fs: FileSystem = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") - } - val path = origPath.makeQualified(fs) - - val children = - fs - .globStatus(path) - .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } - .filterNot { status => - val name = status.getPath.getName - (name(0) == '.' 
|| name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE - } - - ParquetRelation.enableLogForwarding() - - // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row - // groups. Since Parquet schema is replicated among all row groups, we only need to touch a - // single row group to read schema related metadata. Notice that we are making assumptions that - // all data in a single Parquet file have the same schema, which is normally true. - children - // Try any non-"_metadata" file first... - .find(_.getPath.getName != ParquetFileWriter.PARQUET_METADATA_FILE) - // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is - // empty, thus normally the "_metadata" file is expected to be fairly small). - .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) - .map(ParquetFileReader.readFooter(conf, _, ParquetMetadataConverter.NO_FILTER)) - .getOrElse( - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) - } - - /** - * Reads in Parquet Metadata from the given path and tries to extract the schema - * (Catalyst attributes) from the application-specific key-value map. If this - * is empty it falls back to converting from the Parquet file schema which - * may lead to an upcast of types (e.g., {byte, short} to int). - * - * @param origPath The path at which we expect one (or more) Parquet files. - * @param conf The Hadoop configuration to use. - * @return A list of attributes that make up the schema. - */ - def readSchemaFromFile( - origPath: Path, - conf: Option[Configuration], - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - val keyValueMetadata: java.util.Map[String, String] = - readMetaData(origPath, conf) - .getFileMetaData - .getKeyValueMetaData - if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { - convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) - } else { - val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema, - isBinaryAsString, - isInt96AsTimestamp) - log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") - attributes - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala deleted file mode 100644 index c9de45e0ddfb..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ /dev/null @@ -1,622 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.net.URI -import java.util.{List => JList} - -import scala.collection.JavaConversions._ -import scala.util.Try - -import com.google.common.base.Objects -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.util.ContextUtil - -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDD._ -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, SparkException, Partition => SparkPartition} - -private[sql] class DefaultSource extends HadoopFsRelationProvider { - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - schema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation2(paths, schema, None, partitionColumns, parameters)(sqlContext) - } -} - -// NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriter { - - private val recordWriter: RecordWriter[Void, InternalRow] = { - val conf = context.getConfiguration - val outputFormat = { - // When appending new Parquet files to an existing Parquet file directory, to avoid - // overwriting existing data files, we need to find out the max task ID encoded in these data - // file names. - // TODO Make this snippet a utility function for other data source developers - val maxExistingTaskId = { - // Note that `path` may point to a temporary location. Here we retrieve the real - // destination path from the configuration - val outputPath = new Path(conf.get("spark.sql.sources.output.path")) - val fs = outputPath.getFileSystem(conf) - - if (fs.exists(outputPath)) { - // Pattern used to match task ID in part file names, e.g.: - // - // part-r-00001.gz.parquet - // ^~~~~ - val partFilePattern = """part-.-(\d{1,}).*""".r - - fs.listStatus(outputPath).map(_.getPath.getName).map { - case partFilePattern(id) => id.toInt - case name if name.startsWith("_") => 0 - case name if name.startsWith(".") => 0 - case name => throw new AnalysisException( - s"Trying to write Parquet files to directory $outputPath, " + - s"but found items with illegal name '$name'.") - }.reduceOption(_ max _).getOrElse(0) - } else { - 0 - } - } - - new ParquetOutputFormat[InternalRow]() { - // Here we override `getDefaultWorkFile` for two reasons: - // - // 1. To allow appending. We need to generate output file name based on the max available - // task ID computed above. - // - // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses - // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all - // partitions in the case of dynamic partitioning. 
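As a side note, the max-existing-task-ID scan described in the comments above can be sketched on its own. This is an illustrative standalone helper (the function name is mine), not the removed code itself; it pulls the numeric suffix out of existing part-file names and treats summary and hidden files as zero.

```scala
// Illustrative sketch of the "find the max task ID already written" step.
def maxExistingTaskId(fileNames: Seq[String]): Int = {
  // Matches e.g. "part-r-00001.gz.parquet" and captures "00001".
  val partFilePattern = """part-.-(\d{1,}).*""".r
  fileNames.map {
    case partFilePattern(id) => id.toInt
    case name if name.startsWith("_") || name.startsWith(".") => 0
    case other => sys.error(s"Unexpected file name in output directory: $other")
  }.reduceOption(_ max _).getOrElse(0)
}

// maxExistingTaskId(Seq("_SUCCESS", "part-r-00001.gz.parquet")) == 1
```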
- override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val split = context.getTaskAttemptID.getTaskID.getId + maxExistingTaskId + 1 - new Path(path, f"part-r-$split%05d$extension") - } - } - } - - outputFormat.getRecordWriter(context) - } - - override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) - - override def close(): Unit = recordWriter.close(context) -} - -private[sql] class ParquetRelation2( - override val paths: Array[String], - private val maybeDataSchema: Option[StructType], - // This is for metastore conversion. - private val maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - parameters: Map[String, String])( - val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) - with Logging { - - private[sql] def this( - paths: Array[String], - maybeDataSchema: Option[StructType], - maybePartitionSpec: Option[PartitionSpec], - parameters: Map[String, String])( - sqlContext: SQLContext) = { - this( - paths, - maybeDataSchema, - maybePartitionSpec, - maybePartitionSpec.map(_.partitionColumns), - parameters)(sqlContext) - } - - // Should we merge schemas from all Parquet part-files? - private val shouldMergeSchemas = - parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean - - private val maybeMetastoreSchema = parameters - .get(ParquetRelation2.METASTORE_SCHEMA) - .map(DataType.fromJson(_).asInstanceOf[StructType]) - - private lazy val metadataCache: MetadataCache = { - val meta = new MetadataCache - meta.refresh() - meta - } - - override def equals(other: Any): Boolean = other match { - case that: ParquetRelation2 => - val schemaEquality = if (shouldMergeSchemas) { - this.shouldMergeSchemas == that.shouldMergeSchemas - } else { - this.dataSchema == that.dataSchema && - this.schema == that.schema - } - - this.paths.toSet == that.paths.toSet && - schemaEquality && - this.maybeDataSchema == that.maybeDataSchema && - this.partitionColumns == that.partitionColumns - - case _ => false - } - - override def hashCode(): Int = { - if (shouldMergeSchemas) { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - maybeDataSchema, - partitionColumns) - } else { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - dataSchema, - schema, - maybeDataSchema, - partitionColumns) - } - } - - override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) - - override private[sql] def refresh(): Unit = { - super.refresh() - metadataCache.refresh() - } - - // Parquet data source always uses Catalyst internal representations. 
- override val needConversion: Boolean = false - - override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) - - val committerClass = - conf.getClass( - "spark.sql.parquet.output.committer.class", - classOf[ParquetOutputCommitter], - classOf[ParquetOutputCommitter]) - - if (conf.get("spark.sql.parquet.output.committer.class") == null) { - logInfo("Using default output committer for Parquet: " + - classOf[ParquetOutputCommitter].getCanonicalName) - } else { - logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) - } - - conf.setClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, - committerClass, - classOf[ParquetOutputCommitter]) - - // TODO There's no need to use two kinds of WriteSupport - // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and - // complex types. - val writeSupportClass = - if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) - RowWriteSupport.setSchema(dataSchema.toAttributes, conf) - - // Sets compression scheme - conf.set( - ParquetOutputFormat.COMPRESSION, - ParquetRelation - .shortParquetCompressionCodecNames - .getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, - CompressionCodecName.UNCOMPRESSED).name()) - - new OutputWriterFactory { - override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) - } - } - } - - override def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) - val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown - // Create the function to set variable Parquet confs at both driver and executor side. - val initLocalJobFuncOpt = - ParquetRelation2.initializeLocalJobFunc( - requiredColumns, - filters, - dataSchema, - useMetadataCache, - parquetFilterPushDown) _ - // Create the function to set input paths at the driver side. - val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ - - val footers = inputFiles.map(f => metadataCache.footers(f.getPath)) - - Utils.withDummyCallSite(sqlContext.sparkContext) { - // TODO Stop using `FilteringParquetRowInputFormat` and overriding `getPartition`. - // After upgrading to Parquet 1.6.0, we should be able to stop caching `FileStatus` objects - // and footers. Especially when a global arbitrative schema (either from metastore or data - // source DDL) is available. - new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, - broadcastedConf = broadcastedConf, - initDriverSideJobFuncOpt = Some(setInputPaths), - initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = classOf[FilteringParquetRowInputFormat], - keyClass = classOf[Void], - valueClass = classOf[InternalRow]) { - - val cacheMetadata = useMetadataCache - - @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as '/' - // (which does happen in some S3N credentials), we need to use the string returned by the - // URI of the path to create a new Path. 
- val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) - new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) - }.toSeq - - @transient val cachedFooters = footers.map { f => - // In order to encode the authority of a Path containing special characters such as /, - // we need to use the string returned by the URI of the path to create a new Path. - new Footer(escapePathUserInfo(f.getFile), f.getParquetMetadata) - }.toSeq - - private def escapePathUserInfo(path: Path): Path = { - val uri = path.toUri - new Path(new URI( - uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, - uri.getQuery, uri.getFragment)) - } - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatuses - override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters - } - } else { - new FilteringParquetRowInputFormat - } - - val jobContext = newJobContext(getConf(isDriverSide = true), jobId) - val rawSplits = inputFormat.getSplits(jobContext) - - Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - } - }.values.map(_.asInstanceOf[Row]) - } - } - - private class MetadataCache { - // `FileStatus` objects of all "_metadata" files. - private var metadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all "_common_metadata" files. - private var commonMetadataStatuses: Array[FileStatus] = _ - - // Parquet footer cache. - var footers: Map[Path, Footer] = _ - - // `FileStatus` objects of all data files (Parquet part-files). - var dataStatuses: Array[FileStatus] = _ - - // Schema of the actual Parquet files, without partition columns discovered from partition - // directory paths. - var dataSchema: StructType = null - - // Schema of the whole table, including partition columns. - var schema: StructType = _ - - /** - * Refreshes `FileStatus`es, footers, partition spec, and table schema. - */ - def refresh(): Unit = { - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. - val leaves = cachedLeafStatuses().filter { f => - isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - }.toArray - - dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) - metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) - commonMetadataStatuses = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - - footers = { - val conf = SparkHadoopUtil.get.conf - val taskSideMetaData = conf.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - val rawFooters = if (shouldMergeSchemas) { - ParquetFileReader.readAllFootersInParallel( - conf, seqAsJavaList(leaves), taskSideMetaData) - } else { - ParquetFileReader.readAllFootersInParallelUsingSummaryFiles( - conf, seqAsJavaList(leaves), taskSideMetaData) - } - - rawFooters.map(footer => footer.getFile -> footer).toMap - } - - // If we already get the schema, don't need to re-compute it since the schema merging is - // time-consuming. 
- if (dataSchema == null) { - dataSchema = { - val dataSchema0 = maybeDataSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(throw new AnalysisException( - s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + - paths.mkString("\n\t"))) - - // If this Parquet relation is converted from a Hive Metastore table, must reconcile case - // case insensitivity issue and possible schema mismatch (probably caused by schema - // evolution). - maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) - .getOrElse(dataSchema0) - } - } - } - - private def isSummaryFile(file: Path): Boolean = { - file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || - file.getName == ParquetFileWriter.PARQUET_METADATA_FILE - } - - private def readSchema(): Option[StructType] = { - // Sees which file(s) we need to touch in order to figure out the schema. - // - // Always tries the summary files first if users don't require a merged schema. In this case, - // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row - // groups information, and could be much smaller for large Parquet files with lots of row - // groups. - // - // NOTE: Metadata stored in the summary files are merged from all part-files. However, for - // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know - // how to merge them correctly if some key is associated with different values in different - // part-files. When this happens, Parquet simply gives up generating the summary file. This - // implies that if a summary file presents, then: - // - // 1. Either all part-files have exactly the same Spark SQL schema, or - // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus - // their schemas may differ from each other). - // - // Here we tend to be pessimistic and take the second case into account. Basically this means - // we can't trust the summary files if users require a merged schema, and must touch all part- - // files to do the merge. - val filesToTouch = - if (shouldMergeSchemas) { - // Also includes summary files, 'cause there might be empty partition directories. - (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq - } else { - // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet - // don't have this. - commonMetadataStatuses.headOption - // Falls back to "_metadata" - .orElse(metadataStatuses.headOption) - // Summary file(s) not found, the Parquet file is either corrupted, or different part- - // files contain conflicting user defined metadata (two or more values are associated - // with a same key in different files). In either case, we fall back to any of the - // first part-file, and just assume all schemas are consistent. - .orElse(dataStatuses.headOption) - .toSeq - } - - assert( - filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, - "No schema defined, " + - s"and no Parquet data file or summary file found under ${paths.mkString(", ")}.") - - ParquetRelation2.readSchema(filesToTouch.map(f => footers.apply(f.getPath)), sqlContext) - } - } -} - -private[sql] object ParquetRelation2 extends Logging { - // Whether we should merge schemas collected from all Parquet part-files. - private[sql] val MERGE_SCHEMA = "mergeSchema" - - // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used - // internally. 
- private[sql] val METASTORE_SCHEMA = "metastoreSchema" - - /** This closure sets various Parquet configurations at both driver side and executor side. */ - private[parquet] def initializeLocalJobFunc( - requiredColumns: Array[String], - filters: Array[Filter], - dataSchema: StructType, - useMetadataCache: Boolean, - parquetFilterPushDown: Boolean)(job: Job): Unit = { - val conf = job.getConfiguration - conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName()) - - // Try to push down filters when filter push-down is enabled. - if (parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(dataSchema, _)) - .reduceOption(FilterApi.and) - .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) - } - - conf.set(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { - val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) - ParquetTypesConverter.convertToString(requestedSchema.toAttributes) - }) - - conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(dataSchema.toAttributes)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) - } - - /** This closure sets input paths at the driver side. */ - private[parquet] def initializeDriverSideJobFunc( - inputFiles: Array[FileStatus])(job: Job): Unit = { - // We side the input paths at the driver side. - if (inputFiles.nonEmpty) { - FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) - } - } - - private[parquet] def readSchema( - footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { - footers.map { footer => - val metadata = footer.getParquetMetadata.getFileMetaData - val parquetSchema = metadata.getSchema - val maybeSparkSchema = metadata - .getKeyValueMetaData - .toMap - .get(RowReadSupport.SPARK_METADATA_KEY) - .flatMap { serializedSchema => - // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to - // whatever is available. - Try(DataType.fromJson(serializedSchema)) - .recover { case _: Throwable => - logInfo( - s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + - "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(serializedSchema) - } - .recover { case cause: Throwable => - logWarning( - s"""Failed to parse serialized Spark schema in Parquet key-value metadata: - |\t$serializedSchema - """.stripMargin, - cause) - } - .map(_.asInstanceOf[StructType]) - .toOption - } - - maybeSparkSchema.getOrElse { - // Falls back to Parquet schema if Spark SQL schema is absent. - StructType.fromAttributes( - // TODO Really no need to use `Attribute` here, we only need to know the data type. - ParquetTypesConverter.convertToAttributes( - parquetSchema, - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp)) - } - }.reduceOption { (left, right) => - try left.merge(right) catch { case e: Throwable => - throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) - } - } - } - - /** - * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore - * schema and Parquet schema. 
- * - * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the - * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't - * distinguish binary and string). This method generates a correct schema by merging Metastore - * schema data types and Parquet schema field names. - */ - private[parquet] def mergeMetastoreParquetSchema( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - def schemaConflictMessage: String = - s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: - |${metastoreSchema.prettyJson} - | - |Parquet schema: - |${parquetSchema.prettyJson} - """.stripMargin - - val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) - - assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) - - val ordinalMap = metastoreSchema.zipWithIndex.map { - case (field, index) => field.name.toLowerCase -> index - }.toMap - - val reorderedParquetSchema = mergedParquetSchema.sortBy(f => - ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) - - StructType(metastoreSchema.zip(reorderedParquetSchema).map { - // Uses Parquet field names but retains Metastore data types. - case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => - mSchema.copy(name = pSchema.name) - case _ => - throw new SparkException(schemaConflictMessage) - }) - } - - /** - * Returns the original schema from the Parquet file with any missing nullable fields from the - * Hive Metastore schema merged in. - * - * When constructing a DataFrame from a collection of structured data, the resulting object has - * a schema corresponding to the union of the fields present in each element of the collection. - * Spark SQL simply assigns a null value to any field that isn't present for a particular row. - * In some cases, it is possible that a given table partition stored as a Parquet file doesn't - * contain a particular nullable field in its schema despite that field being present in the - * table schema obtained from the Hive Metastore. This method returns a schema representing the - * Parquet file schema along with any additional nullable fields from the Metastore schema - * merged in. - */ - private[parquet] def mergeMissingNullableFields( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap - val missingFields = metastoreSchema - .map(_.name.toLowerCase) - .diff(parquetSchema.map(_.name.toLowerCase)) - .map(fieldMap(_)) - .filter(_.nullable) - StructType(parquetSchema ++ missingFields) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala deleted file mode 100644 index 4d5ed211ad0c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet.timestamp - -import java.nio.{ByteBuffer, ByteOrder} - -import org.apache.parquet.Preconditions -import org.apache.parquet.io.api.{Binary, RecordConsumer} - -private[parquet] class NanoTime extends Serializable { - private var julianDay = 0 - private var timeOfDayNanos = 0L - - def set(julianDay: Int, timeOfDayNanos: Long): this.type = { - this.julianDay = julianDay - this.timeOfDayNanos = timeOfDayNanos - this - } - - def getJulianDay: Int = julianDay - - def getTimeOfDayNanos: Long = timeOfDayNanos - - def toBinary: Binary = { - val buf = ByteBuffer.allocate(12) - buf.order(ByteOrder.LITTLE_ENDIAN) - buf.putLong(timeOfDayNanos) - buf.putInt(julianDay) - buf.flip() - Binary.fromByteBuffer(buf) - } - - def writeValue(recordConsumer: RecordConsumer): Unit = { - recordConsumer.addBinary(toBinary) - } - - override def toString: String = - "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}" -} - -private[sql] object NanoTime { - def fromBinary(bytes: Binary): NanoTime = { - Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes") - val buf = bytes.toByteBuffer - buf.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - new NanoTime().set(julianDay, timeOfDayNanos) - } - - def apply(julianDay: Int, timeOfDayNanos: Long): NanoTime = { - new NanoTime().set(julianDay, timeOfDayNanos) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala deleted file mode 100644 index c16bd9ae52c8..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ /dev/null @@ -1,491 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.sources - -import java.util.Date - -import scala.collection.mutable - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, FileOutputCommitter => MapReduceFileOutputCommitter} -import org.apache.parquet.hadoop.util.ContextUtil - -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} -import org.apache.spark.util.SerializableConfiguration - -private[sql] case class InsertIntoDataSource( - logicalRelation: LogicalRelation, - query: LogicalPlan, - overwrite: Boolean) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[InternalRow] = { - val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] - val data = DataFrame(sqlContext, query) - // Apply the schema of the existing table to the new data. - val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) - - // Invalidate the cache. - sqlContext.cacheManager.invalidateCache(logicalRelation) - - Seq.empty[InternalRow] - } -} - -private[sql] case class InsertIntoHadoopFsRelation( - @transient relation: HadoopFsRelation, - @transient query: LogicalPlan, - mode: SaveMode) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[InternalRow] = { - require( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration - val outputPath = new Path(relation.paths.head) - val fs = outputPath.getFileSystem(hadoopConf) - val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - - val doInsertion = (mode, fs.exists(qualifiedOutputPath)) match { - case (SaveMode.ErrorIfExists, true) => - sys.error(s"path $qualifiedOutputPath already exists.") - case (SaveMode.Overwrite, true) => - fs.delete(qualifiedOutputPath, true) - true - case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => - true - case (SaveMode.Ignore, exists) => - !exists - } - - if (doInsertion) { - val job = new Job(hadoopConf) - job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - - // We create a DataFrame by applying the schema of relation to the data to make sure. - // We are writing data based on the expected schema, - val df = { - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). We - // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can - // safely apply the schema of r.schema to the data. 
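For readers who prefer the public API, the column reordering described in the comment above can be sketched with `DataFrame.select`; `df` and `relationSchema` are assumed names, and this is only an approximation of the Catalyst `Project` used just below.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType

// Project the incoming columns into the relation's column order so that partition
// columns end up in the positions the relation's schema expects.
def reorderToRelationSchema(df: DataFrame, relationSchema: StructType): DataFrame =
  df.select(relationSchema.fieldNames.map(col): _*)
```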
- val project = Project( - relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) - - sqlContext.internalCreateDataFrame( - DataFrame(sqlContext, project).queryExecution.toRdd, relation.schema) - } - - val partitionColumns = relation.partitionColumns.fieldNames - if (partitionColumns.isEmpty) { - insert(new DefaultWriterContainer(relation, job), df) - } else { - val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME) - insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) - } - } - - Seq.empty[InternalRow] - } - - private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = { - // Uses local vals for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - val converter = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } else { - r: InternalRow => r.asInstanceOf[Row] - } - while (iterator.hasNext) { - val row = converter(iterator.next()) - writerContainer.outputWriterForRow(row).write(row) - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - private def insertWithDynamicPartitions( - sqlContext: SQLContext, - writerContainer: BaseWriterContainer, - df: DataFrame, - partitionColumns: Array[String]): Unit = { - // Uses a local val for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - require( - df.schema == relation.schema, - s"""DataFrame must have the same schema as the relation to which is inserted. - |DataFrame schema: ${df.schema} - |Relation schema: ${relation.schema} - """.stripMargin) - - val partitionColumnsInSpec = relation.partitionColumns.fieldNames - require( - partitionColumnsInSpec.sameElements(partitionColumns), - s"""Partition columns mismatch. - |Expected: ${partitionColumnsInSpec.mkString(", ")} - |Actual: ${partitionColumns.mkString(", ")} - """.stripMargin) - - val output = df.queryExecution.executedPlan.output - val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) - val codegenEnabled = df.sqlContext.conf.codegenEnabled - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - val partitionProj = newProjection(codegenEnabled, partitionOutput, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) - - val dataConverter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } else { - r: InternalRow => r.asInstanceOf[Row] - } - val partitionSchema = StructType.fromAttributes(partitionOutput) - val partConverter: InternalRow => Row = - CatalystTypeConverters.createToScalaConverter(partitionSchema) - .asInstanceOf[InternalRow => Row] - - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = partConverter(partitionProj(row)) - val dataPart = dataConverter(dataProj(row)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - // This is copied from SparkPlan, probably should move this to a more general place. - private def newProjection( - codegenEnabled: Boolean, - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled && expressions.forall(_.isThreadSafe)) { - GenerateProjection.generate(expressions, inputSchema) - } else { - new InterpretedProjection(expressions, inputSchema) - } - } -} - -private[sql] abstract class BaseWriterContainer( - @transient val relation: HadoopFsRelation, - @transient job: Job) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { - - protected val serializableConf = new SerializableConfiguration(ContextUtil.getConfiguration(job)) - - // This is only used on driver side. - @transient private val jobContext: JobContext = job - - // The following fields are initialized and used on both driver and executor side. - @transient protected var outputCommitter: OutputCommitter = _ - @transient private var jobId: JobID = _ - @transient private var taskId: TaskID = _ - @transient private var taskAttemptId: TaskAttemptID = _ - @transient protected var taskAttemptContext: TaskAttemptContext = _ - - protected val outputPath: String = { - assert( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - relation.paths.head - } - - protected val dataSchema = relation.dataSchema - - protected var outputWriterFactory: OutputWriterFactory = _ - - private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ - - def driverSideSetup(): Unit = { - setupIDs(0, 0, 0) - setupConf() - - // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor - // clones the Configuration object passed in. 
If we initialize the TaskAttemptContext first, - // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. - // - // Also, the `prepareJobForWrite` call must happen before initializing output format and output - // committer, since their initialization involve the job configuration, which can be potentially - // decorated in `prepareJobForWrite`. - outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - - outputFormatClass = job.getOutputFormatClass - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupJob(jobContext) - } - - def executorSideSetup(taskContext: TaskContext): Unit = { - setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) - setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupTask(taskAttemptContext) - initWriters() - } - - protected def getWorkPath: String = { - outputCommitter match { - // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. - case f: MapReduceFileOutputCommitter => f.getWorkPath.toString - case _ => outputPath - } - } - - private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val committerClass = context.getConfiguration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - } else { - // The specified output committer is just a OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() - } - }.getOrElse { - // If output committer class is not set, we will use the one associated with the - // file output format. 
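Stripped of logging, the committer selection described above reduces to a two-way reflective constructor choice. A condensed sketch; the method name is illustrative:

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

// Condensed form of BaseWriterContainer.newOutputCommitter above: a FileOutputCommitter
// subclass gets the (Path, TaskAttemptContext) constructor, anything else the no-arg one.
def instantiateCommitter(
    committerClass: Class[_ <: OutputCommitter],
    outputPath: String,
    context: TaskAttemptContext): OutputCommitter = {
  if (classOf[FileOutputCommitter].isAssignableFrom(committerClass)) {
    committerClass
      .getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
      .newInstance(new Path(outputPath), context)
  } else {
    committerClass.getDeclaredConstructor().newInstance()
  }
}
```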
- val outputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - logInfo(s"Using output committer class ${outputCommitter.getClass.getCanonicalName}") - outputCommitter - } - } - - private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { - this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - } - - private def setupConf(): Unit = { - serializableConf.value.set("mapred.job.id", jobId.toString) - serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - serializableConf.value.set("mapred.task.id", taskAttemptId.toString) - serializableConf.value.setBoolean("mapred.task.is.map", true) - serializableConf.value.setInt("mapred.task.partition", 0) - } - - // Called on executor side when writing rows - def outputWriterForRow(row: Row): OutputWriter - - protected def initWriters(): Unit - - def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) - } - - def abortTask(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortTask(taskAttemptContext) - } - logError(s"Task attempt $taskAttemptId aborted.") - } - - def commitJob(): Unit = { - outputCommitter.commitJob(jobContext) - logInfo(s"Job $jobId committed.") - } - - def abortJob(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) - } - logError(s"Job $jobId aborted.") - } -} - -private[sql] class DefaultWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job) - extends BaseWriterContainer(relation, job) { - - @transient private var writer: OutputWriter = _ - - override protected def initWriters(): Unit = { - taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) - writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) - } - - override def outputWriterForRow(row: Row): OutputWriter = writer - - override def commitTask(): Unit = { - try { - assert(writer != null, "OutputWriter instance should have been initialized") - writer.close() - super.commitTask() - } catch { - case cause: Throwable => - super.abortTask() - throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - if (writer != null) { - writer.close() - } - } finally { - super.abortTask() - } - } -} - -private[sql] class DynamicPartitionWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, - partitionColumns: Array[String], - defaultPartitionName: String) - extends BaseWriterContainer(relation, job) { - - // All output writers are created on executor side. 
- @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _ - - override protected def initWriters(): Unit = { - outputWriters = mutable.Map.empty[String, OutputWriter] - } - - override def outputWriterForRow(row: Row): OutputWriter = { - val partitionPath = partitionColumns.zip(row.toSeq).map { case (col, rawValue) => - val string = if (rawValue == null) null else String.valueOf(rawValue) - val valueString = if (string == null || string.isEmpty) { - defaultPartitionName - } else { - PartitioningUtils.escapePathName(string) - } - s"/$col=$valueString" - }.mkString.stripPrefix(Path.SEPARATOR) - - outputWriters.getOrElseUpdate(partitionPath, { - val path = new Path(getWorkPath, partitionPath) - taskAttemptContext.getConfiguration.set( - "spark.sql.sources.output.path", - new Path(outputPath, partitionPath).toString) - outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) - }) - } - - override def commitTask(): Unit = { - try { - outputWriters.values.foreach(_.close()) - outputWriters.clear() - super.commitTask() - } catch { case cause: Throwable => - super.abortTask() - throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - outputWriters.values.foreach(_.close()) - outputWriters.clear() - } finally { - super.abortTask() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala deleted file mode 100644 index b7095c8ead79..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ /dev/null @@ -1,485 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import scala.language.{existentials, implicitConversions} -import scala.util.matching.Regex - -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InternalRow} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} -import org.apache.spark.util.Utils - -/** - * A parser for foreign DDL commands. 
- */ -private[sql] class DDLParser( - parseQuery: String => LogicalPlan) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parse(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. 
- val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, - provider, - temp.isDefined, - Array.empty[String], - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableName, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(maybeDatabaseName.getOrElse("default"), tableName) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} - -private[sql] object ResolvedDataSource { - - private val builtinSources = Map( - "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", - "json" -> "org.apache.spark.sql.json.DefaultSource", - "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", - "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" - ) - - /** Given a provider name, look up the data source class definition. 
*/ - def lookupDataSource(provider: String): Class[_] = { - val loader = Utils.getContextOrSparkClassLoader - - if (builtinSources.contains(provider)) { - return loader.loadClass(builtinSources(provider)) - } - - try { - loader.loadClass(provider) - } catch { - case cnf: java.lang.ClassNotFoundException => - try { - loader.loadClass(provider + ".DefaultSource") - } catch { - case cnf: java.lang.ClassNotFoundException => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - sys.error("The ORC data source must be used with Hive support enabled.") - } else { - sys.error(s"Failed to load class for data source: $provider") - } - } - } - } - - /** Create a [[ResolvedDataSource]] for reading data in. */ - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - def className: String = clazz.getCanonicalName - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => clazz.newInstance() match { - case dataSource: SchemaRelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: HadoopFsRelationProvider => - val maybePartitionsSchema = if (partitionColumns.isEmpty) { - None - } else { - Some(partitionColumnsSchema(schema, partitionColumns)) - } - - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray - } - - val dataSchema = - StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable - - dataSource.createRelation( - sqlContext, - paths, - Some(dataSchema), - maybePartitionsSchema, - caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - throw new AnalysisException(s"$className does not allow user-specified schemas.") - case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") - } - - case None => clazz.newInstance() match { - case dataSource: RelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: HadoopFsRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray - } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - throw new AnalysisException( - s"A schema needs to be specified when using $className.") - case _ => - throw new AnalysisException( - s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") - } - } - new ResolvedDataSource(clazz, relation) - } - - private def partitionColumnsSchema( - schema: StructType, - partitionColumns: Array[String]): StructType = { - StructType(partitionColumns.map { col => - schema.find(_.name == col).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $schema") - } - }).asNullable - } - - /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. 
*/ - def apply( - sqlContext: SQLContext, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - val relation = clazz.newInstance() match { - case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: HadoopFsRelationProvider => - // Don't glob path for the write path. The contracts here are: - // 1. Only one output path can be specified on the write path; - // 2. Output path must be a legal HDFS style file system path; - // 3. It's OK that the output path doesn't exist yet; - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns)), - caseInsensitiveOptions) - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. - sqlContext.executePlan( - InsertIntoHadoopFsRelation( - r, - data.logicalPlan, - mode)).toRdd - r - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - new ResolvedDataSource(clazz, relation) - } -} - -private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) - -/** - * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. - * @param table The table to be described. - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. - * It is effective only when the table is a Hive table. - */ -private[sql] case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends LogicalPlan with Command { - - override def children: Seq[LogicalPlan] = Seq.empty - override val output: Seq[Attribute] = Seq( - // Column names are based on Hive. - AttributeReference("col_name", StringType, nullable = false, - new MetadataBuilder().putString("comment", "name of the column").build())(), - AttributeReference("data_type", StringType, nullable = false, - new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = false, - new MetadataBuilder().putString("comment", "comment of the column").build())()) -} - -/** - * Used to represent the operation of create table using a data source. - * @param allowExisting If it is true, we will do nothing when the table already exists. - * If it is false, an exception will be thrown - */ -private[sql] case class CreateTableUsing( - tableName: String, - userSpecifiedSchema: Option[StructType], - provider: String, - temporary: Boolean, - options: Map[String, String], - allowExisting: Boolean, - managedIfNoPath: Boolean) extends LogicalPlan with Command { - - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty -} - -/** - * A node used to support CTAS statements and saveAsTable for the data source API. 
- * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer - * can analyze the logical plan that will be used to populate the table. - * So, [[PreWriteCheck]] can detect cases that are not allowed. - */ -private[sql] case class CreateTableUsingAsSelect( - tableName: String, - provider: String, - temporary: Boolean, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = Seq.empty[Attribute] - // TODO: Override resolved after we support databaseName. - // override lazy val resolved = databaseName != None && childrenResolved -} - -private[sql] case class CreateTempTableUsing( - tableName: String, - userSpecifiedSchema: Option[StructType], - provider: String, - options: Map[String, String]) extends RunnableCommand { - - def run(sqlContext: SQLContext): Seq[InternalRow] = { - val resolved = ResolvedDataSource( - sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty - } -} - -private[sql] case class CreateTempTableUsingAsSelect( - tableName: String, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - query: LogicalPlan) extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[InternalRow] = { - val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - - Seq.empty - } -} - -private[sql] case class RefreshTable(databaseName: String, tableName: String) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[InternalRow] = { - // Refresh the given table's metadata first. - sqlContext.catalog.refreshTable(databaseName, tableName) - - // If this table is cached as a InMemoryColumnarRelation, drop the original - // cached version and make the new version cached lazily. - val logicalPlan = sqlContext.catalog.lookupRelation(Seq(databaseName, tableName)) - // Use lookupCachedData directly since RefreshTable also takes databaseName. - val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty - if (isCached) { - // Create a data frame to represent the table. - // TODO: Use uncacheTable once it supports database name. - val df = DataFrame(sqlContext, logicalPlan) - // Uncache the logicalPlan. - sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) - // Cache it again. - sqlContext.cacheManager.cacheQuery(df, Some(tableName)) - } - - Seq.empty[InternalRow] - } -} - -/** - * Builds a map in which keys are case insensitive - */ -protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] - with Serializable { - - val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) - - override def get(k: String): Option[String] = baseMap.get(k.toLowerCase) - - override def + [B1 >: String](kv: (String, B1)): Map[String, B1] = - baseMap + kv.copy(_1 = kv._1.toLowerCase) - - override def iterator: Iterator[(String, String)] = baseMap.iterator - - override def -(key: String): Map[String, String] = baseMap - key.toLowerCase -} - -/** - * The exception thrown from the DDL parser. 
- */
-protected[sql] class DDLException(message: String) extends Exception(message)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 24e86ca415c5..3780cbbcc963 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -17,6 +17,10 @@ package org.apache.spark.sql.sources
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines all the filters that we can push down to the data sources.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 /**
  * A filter predicate for data sources.
  *
@@ -32,6 +36,15 @@ abstract class Filter
+/**
+ * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]]
+ * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false`
+ * (rather than NULL) if one of the inputs is NULL and the other is not NULL.
+ *
+ * @since 1.5.0
+ */
+case class EqualNullSafe(attribute: String, value: Any) extends Filter
+
 /**
  * A filter that evaluates to `true` iff the attribute evaluates to a value
  * greater than `value`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 7005c7079af9..7b030b7d73bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -24,16 +24,45 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.execution.RDDConversions
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
+import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql._
 import org.apache.spark.util.SerializableConfiguration
+/**
+ * ::DeveloperApi::
+ * Data sources should implement this trait so that they can register an alias to their data source.
+ * This allows users to give the alias as the format type instead of the fully qualified
+ * class name.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ *
+ * @since 1.5.0
+ */
+@DeveloperApi
+trait DataSourceRegister {
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def shortName(): String = "parquet"
+   * }}}
+   *
+   * @since 1.5.0
+   */
+  def shortName(): String
+}
+
 /**
  * ::DeveloperApi::
  * Implemented by objects that produce relations for a specific kind of data source.
When @@ -74,7 +103,7 @@ trait RelationProvider { * A new instance of this class with be instantiated each time a DDL call is made. * * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that - * users need to provide a schema when using a SchemaRelationProvider. + * users need to provide a schema when using a [[SchemaRelationProvider]]. * A relation provider can inherits both [[RelationProvider]] and [[SchemaRelationProvider]] * if it can support both schema inference and user-specified schemas. * @@ -110,7 +139,7 @@ trait SchemaRelationProvider { * * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is * that users need to provide a schema and a (possibly empty) list of partition columns when - * using a SchemaRelationProvider. A relation provider can inherits both [[RelationProvider]], + * using a [[HadoopFsRelationProvider]]. A relation provider can inherits both [[RelationProvider]], * and [[HadoopFsRelationProvider]] if it can support schema inference, user-specified * schemas, and accessing partitioned relations. * @@ -339,6 +368,17 @@ abstract class OutputWriter { * @since 1.4.0 */ def close(): Unit + + private var converter: InternalRow => Row = _ + + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } + + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } } /** @@ -366,7 +406,9 @@ abstract class OutputWriter { */ @Experimental abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation { + extends BaseRelation with FileRelation with Logging { + + override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") def this() = this(None) @@ -381,36 +423,40 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] - def refresh(): Unit = { - // We don't filter files/directories whose name start with "_" except "_temporary" here, as - // specific data sources may take advantages over them (e.g. Parquet _metadata and - // _common_metadata files). "_temporary" directories are explicitly ignored since failed - // tasks/jobs may leave partial/corrupted data files there. 
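Tying this hunk together: a provider opts into the new alias mechanism by mixing `DataSourceRegister` (introduced above) into its `RelationProvider`. A hypothetical example; the `dummy` alias and the one-column relation are made up, and the service-loader wiring that actually exposes the alias is not shown here:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical data source: registers the alias "dummy" so that callers can refer to it
// by that short name instead of this class's fully qualified name.
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "dummy"

  override def createRelation(
      ctx: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    new BaseRelation with TableScan {
      override val sqlContext: SQLContext = ctx
      override val schema: StructType = StructType(StructField("value", StringType) :: Nil)
      override def buildScan(): RDD[Row] =
        ctx.sparkContext.parallelize(Seq(Row("hello"), Row("world")))
    }
  }
}
```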
- def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { - if (status.getPath.getName.toLowerCase == "_temporary") { - Set.empty + private def listLeafFiles(paths: Array[String]): Set[FileStatus] = { + if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { + HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) + } else { + val statuses = paths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + + logInfo(s"Listing $qualified on driver") + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + }.filterNot { status => + val name = status.getPath.getName + name.toLowerCase == "_temporary" || name.startsWith(".") + } + + val (dirs, files) = statuses.partition(_.isDir) + + if (dirs.isEmpty) { + files.toSet } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] - files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + files.toSet ++ listLeafFiles(dirs.map(_.getPath.toString)) } } + } - leafFiles.clear() + def refresh(): Unit = { + val files = listLeafFiles(paths) - val statuses = paths.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) - }.filterNot { status => - // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories - status.getPath.getName.startsWith(".") - } + leafFiles.clear() + leafDirToChildrenFiles.clear() - val files = statuses.filterNot(_.isDir) leafFiles ++= files.map(f => f.getPath -> f).toMap - leafDirToChildrenFiles ++= files.groupBy(_.getPath.getParent) + leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) } } @@ -440,8 +486,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(partitionColumnTypes).map { - case (value, dataType) => Literal.create(value, dataType) + val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => + Literal.create(values.get(i, dt), dt) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() @@ -470,6 +516,10 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ def paths: Array[String] + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + + override def sizeInBytes: Long = cachedLeafStatuses().map(_.getLen).sum + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. Note that they should always be nullable. 
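The new listing code above does several things at once: it skips `_temporary` and dot-prefixed entries, switches to a parallel Spark job once the number of input paths reaches `parallelPartitionDiscoveryThreshold`, and the cached leaf statuses then back the new `inputFiles` and `sizeInBytes` overrides. A small sketch of the filter and the derived statistics; the helper names are illustrative:

```scala
import org.apache.hadoop.fs.FileStatus

// Mirrors the name filter used when listing leaf files: "_temporary" directories and
// hidden (dot-prefixed) entries are ignored.
def isIgnorable(status: FileStatus): Boolean = {
  val name = status.getPath.getName
  name.toLowerCase == "_temporary" || name.startsWith(".")
}

// From the surviving leaf statuses, inputFiles and sizeInBytes fall out directly,
// as in the new HadoopFsRelation overrides above.
def inputFilesAndSize(leafStatuses: Seq[FileStatus]): (Array[String], Long) = {
  val visible = leafStatuses.filterNot(isIgnorable)
  (visible.map(_.getPath.toString).toArray, visible.map(_.getLen).sum)
}
```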
@@ -514,7 +564,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sources] final def buildScan( + final private[sql] def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], @@ -572,6 +622,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio * * @since 1.4.0 */ + // TODO Tries to eliminate the extra Catalyst-to-Scala conversion when `needConversion` is true + // + // PR #7626 separated `Row` and `InternalRow` completely. One of the consequences is that we can + // no longer treat an `InternalRow` containing Catalyst values as a `Row`. Thus we have to + // introduce another row value conversion for data sources whose `needConversion` is true. def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { // Yeah, to workaround serialization... val dataSchema = this.dataSchema @@ -583,22 +638,34 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - val rdd = buildScan(inputFiles) - val converted = + val rdd: RDD[Row] = buildScan(inputFiles) + val converted: RDD[InternalRow] = if (needConversion) { RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } + converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled && requiredOutput.forall(_.isThreadSafe)) { + val buildProjection = if (codegenEnabled) { GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) } - val mutableProjection = buildProjection() - rows.map(r => mutableProjection(r).asInstanceOf[Row]) - } + + val projectedRows = { + val mutableProjection = buildProjection() + rows.map(r => mutableProjection(r)) + } + + if (needConversion) { + val requiredSchema = StructType(requiredColumns.map(dataSchema(_))) + val toScala = CatalystTypeConverters.createToScalaConverter(requiredSchema) + projectedRows.map(toScala(_).asInstanceOf[Row]) + } else { + projectedRows + } + }.asInstanceOf[RDD[Row]] } /** @@ -665,3 +732,63 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ def prepareJobForWrite(job: Job): OutputWriterFactory } + +private[sql] object HadoopFsRelation extends Logging { + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. Files and directories whose name + // start with "." are also ignored. + def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = { + logInfo(s"Listing ${status.getPath}") + val name = status.getPath.getName.toLowerCase + if (name == "_temporary" || name.startsWith(".")) { + Array.empty + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } + } + + // `FileStatus` is Writable but not serializable. What make it worse, somehow it doesn't play + // well with `SerializableWritable`. So there seems to be no way to serialize a `FileStatus`. 
+ // Here we use `FakeFileStatus` to extract key components of a `FileStatus` to serialize it from + // executor side and reconstruct it on driver side. + case class FakeFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long) + + def listLeafFilesInParallel( + paths: Array[String], + hadoopConf: Configuration, + sparkContext: SparkContext): Set[FileStatus] = { + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val fakeStatuses = sparkContext.parallelize(paths).flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(serializableConfiguration.value) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + Try(listLeafFiles(fs, fs.getFileStatus(qualified))).getOrElse(Array.empty) + }.map { status => + FakeFileStatus( + status.getPath.toString, + status.getLen, + status.isDir, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime) + }.collect() + + fakeStatuses.map { f => + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) + }.toSet + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala deleted file mode 100644 index 9fa394525d65..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.language.implicitConversions - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -class LocalSQLContext - extends SQLContext( - new SparkContext( - "local[2]", - "TestSQLContext", - new SparkConf().set("spark.sql.testkey", "true"))) { - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[sql] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) - } - } - - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. 
- */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) - } - -} - -object TestSQLContext extends LocalSQLContext - diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md new file mode 100644 index 000000000000..421c2ea4f7ae --- /dev/null +++ b/sql/core/src/test/README.md @@ -0,0 +1,29 @@ +# Notes for Parquet compatibility tests + +The following directories and files are used for Parquet compatibility tests: + +``` +. +├── README.md # This file +├── avro +│   ├── *.avdl # Testing Avro IDL(s) +│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) +├── gen-java # !! NO TOUCH !! Generated Java code +├── scripts +│   ├── gen-avro.sh # Script used to generate Java code for Avro +│   └── gen-thrift.sh # Script used to generate Java code for Thrift +└── thrift + └── *.thrift # Testing Thrift schema(s) +``` + +To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. + +When updating the testing Thrift schema and Avro IDL, please run `gen-avro.sh` and `gen-thrift.sh` accordingly to update generated Java code. + +## Prerequisites + +Please ensure `avro-tools` and `thrift` are installed. You may install these two on Mac OS X via: + +```bash +$ brew install thrift avro-tools +``` diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl new file mode 100644 index 000000000000..c5eb5b5164cf --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This is a test protocol for testing parquet-avro compatibility. 
+@namespace("org.apache.spark.sql.execution.datasources.parquet.test.avro") +protocol CompatibilityTest { + enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS + } + + record ParquetEnum { + Suit suit; + } + + record Nested { + array nested_ints_column; + string nested_string_column; + } + + record AvroPrimitives { + boolean bool_column; + int int_column; + long long_column; + float float_column; + double double_column; + bytes binary_column; + string string_column; + } + + record AvroOptionalPrimitives { + union { null, boolean } maybe_bool_column; + union { null, int } maybe_int_column; + union { null, long } maybe_long_column; + union { null, float } maybe_float_column; + union { null, double } maybe_double_column; + union { null, bytes } maybe_binary_column; + union { null, string } maybe_string_column; + } + + record AvroNonNullableArrays { + array strings_column; + union { null, array } maybe_ints_column; + } + + record AvroArrayOfArray { + array> int_arrays_column; + } + + record AvroMapOfArray { + map> string_to_ints_column; + } + + record ParquetAvroCompat { + array strings_column; + map string_to_int_column; + map> complex_column; + } +} diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr new file mode 100644 index 000000000000..9ad315b74fb4 --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -0,0 +1,147 @@ +{ + "protocol" : "CompatibilityTest", + "namespace" : "org.apache.spark.sql.execution.datasources.parquet.test.avro", + "types" : [ { + "type" : "enum", + "name" : "Suit", + "symbols" : [ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] + }, { + "type" : "record", + "name" : "ParquetEnum", + "fields" : [ { + "name" : "suit", + "type" : "Suit" + } ] + }, { + "type" : "record", + "name" : "Nested", + "fields" : [ { + "name" : "nested_ints_column", + "type" : { + "type" : "array", + "items" : "int" + } + }, { + "name" : "nested_string_column", + "type" : "string" + } ] + }, { + "type" : "record", + "name" : "AvroPrimitives", + "fields" : [ { + "name" : "bool_column", + "type" : "boolean" + }, { + "name" : "int_column", + "type" : "int" + }, { + "name" : "long_column", + "type" : "long" + }, { + "name" : "float_column", + "type" : "float" + }, { + "name" : "double_column", + "type" : "double" + }, { + "name" : "binary_column", + "type" : "bytes" + }, { + "name" : "string_column", + "type" : "string" + } ] + }, { + "type" : "record", + "name" : "AvroOptionalPrimitives", + "fields" : [ { + "name" : "maybe_bool_column", + "type" : [ "null", "boolean" ] + }, { + "name" : "maybe_int_column", + "type" : [ "null", "int" ] + }, { + "name" : "maybe_long_column", + "type" : [ "null", "long" ] + }, { + "name" : "maybe_float_column", + "type" : [ "null", "float" ] + }, { + "name" : "maybe_double_column", + "type" : [ "null", "double" ] + }, { + "name" : "maybe_binary_column", + "type" : [ "null", "bytes" ] + }, { + "name" : "maybe_string_column", + "type" : [ "null", "string" ] + } ] + }, { + "type" : "record", + "name" : "AvroNonNullableArrays", + "fields" : [ { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } + }, { + "name" : "maybe_ints_column", + "type" : [ "null", { + "type" : "array", + "items" : "int" + } ] + } ] + }, { + "type" : "record", + "name" : "AvroArrayOfArray", + "fields" : [ { + "name" : "int_arrays_column", + "type" : { + "type" : "array", + "items" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "AvroMapOfArray", + "fields" : [ { + 
"name" : "string_to_ints_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "ParquetAvroCompat", + "fields" : [ { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } + }, { + "name" : "string_to_int_column", + "type" : { + "type" : "map", + "values" : "int" + } + }, { + "name" : "complex_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "Nested" + } + } + } ] + } ], + "messages" : { } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java new file mode 100644 index 000000000000..ee327827903e --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroArrayOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List> int_arrays_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroArrayOfArray() {} + + /** + * All-args constructor. + */ + public AvroArrayOfArray(java.util.List> int_arrays_column) { + this.int_arrays_column = int_arrays_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return int_arrays_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: int_arrays_column = (java.util.List>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'int_arrays_column' field. + */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** + * Sets the value of the 'int_arrays_column' field. + * @param value the value to set. 
+ */ + public void setIntArraysColumn(java.util.List> value) { + this.int_arrays_column = value; + } + + /** Creates a new AvroArrayOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing AvroArrayOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroArrayOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List> int_arrays_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroArrayOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'int_arrays_column' field */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** Sets the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder setIntArraysColumn(java.util.List> value) { + validate(fields()[0], value); + this.int_arrays_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'int_arrays_column' field has been set */ + public boolean hasIntArraysColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder clearIntArraysColumn() { + int_arrays_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroArrayOfArray build() { + try { + AvroArrayOfArray record = new AvroArrayOfArray(); + record.int_arrays_column = fieldSetFlags()[0] ? 
this.int_arrays_column : (java.util.List>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java new file mode 100644 index 000000000000..727f6a7bf733 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroMapOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.Map> string_to_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroMapOfArray() {} + + /** + * All-args constructor. + */ + public AvroMapOfArray(java.util.Map> string_to_ints_column) { + this.string_to_ints_column = string_to_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return string_to_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: string_to_ints_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'string_to_ints_column' field. + */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** + * Sets the value of the 'string_to_ints_column' field. + * @param value the value to set. 
+ */ + public void setStringToIntsColumn(java.util.Map<java.lang.String,java.util.List<java.lang.Integer>> value) { + this.string_to_ints_column = value; + } + + /** Creates a new AvroMapOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing AvroMapOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroMapOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase<AvroMapOfArray.Builder> + implements org.apache.avro.data.RecordBuilder<AvroMapOfArray> { + + private java.util.Map<java.lang.String,java.util.List<java.lang.Integer>> string_to_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroMapOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'string_to_ints_column' field */ + public java.util.Map<java.lang.String,java.util.List<java.lang.Integer>> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** Sets the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder setStringToIntsColumn(java.util.Map<java.lang.String,java.util.List<java.lang.Integer>> value) { + validate(fields()[0], value); + this.string_to_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'string_to_ints_column' field has been set */ + public boolean hasStringToIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder clearStringToIntsColumn() { + string_to_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroMapOfArray build() { + try { + AvroMapOfArray record = new AvroMapOfArray(); + record.string_to_ints_column = fieldSetFlags()[0] ?
this.string_to_ints_column : (java.util.Map>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java new file mode 100644 index 000000000000..934793f42f9c --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroNonNullableArrays extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.List maybe_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroNonNullableArrays() {} + + /** + * All-args constructor. + */ + public AvroNonNullableArrays(java.util.List strings_column, java.util.List maybe_ints_column) { + this.strings_column = strings_column; + this.maybe_ints_column = maybe_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return maybe_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: maybe_ints_column = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'maybe_ints_column' field. + */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** + * Sets the value of the 'maybe_ints_column' field. + * @param value the value to set. 
+ */ + public void setMaybeIntsColumn(java.util.List value) { + this.maybe_ints_column = value; + } + + /** Creates a new AvroNonNullableArrays RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing AvroNonNullableArrays instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** + * RecordBuilder for AvroNonNullableArrays instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.List maybe_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing AvroNonNullableArrays instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' 
field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_ints_column' field */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** Sets the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setMaybeIntsColumn(java.util.List value) { + validate(fields()[1], value); + this.maybe_ints_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_ints_column' field has been set */ + public boolean hasMaybeIntsColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearMaybeIntsColumn() { + maybe_ints_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public AvroNonNullableArrays build() { + try { + AvroNonNullableArrays record = new AvroNonNullableArrays(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.maybe_ints_column = fieldSetFlags()[1] ? this.maybe_ints_column : (java.util.List) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java new file mode 100644 index 000000000000..e4d1ead8dd15 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java @@ -0,0 +1,466 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroOptionalPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.lang.Boolean maybe_bool_column; + @Deprecated public java.lang.Integer maybe_int_column; + @Deprecated public java.lang.Long maybe_long_column; + @Deprecated public java.lang.Float maybe_float_column; + @Deprecated public java.lang.Double maybe_double_column; + @Deprecated public java.nio.ByteBuffer maybe_binary_column; + @Deprecated public java.lang.String 
maybe_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroOptionalPrimitives() {} + + /** + * All-args constructor. + */ + public AvroOptionalPrimitives(java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column) { + this.maybe_bool_column = maybe_bool_column; + this.maybe_int_column = maybe_int_column; + this.maybe_long_column = maybe_long_column; + this.maybe_float_column = maybe_float_column; + this.maybe_double_column = maybe_double_column; + this.maybe_binary_column = maybe_binary_column; + this.maybe_string_column = maybe_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return maybe_bool_column; + case 1: return maybe_int_column; + case 2: return maybe_long_column; + case 3: return maybe_float_column; + case 4: return maybe_double_column; + case 5: return maybe_binary_column; + case 6: return maybe_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: maybe_bool_column = (java.lang.Boolean)value$; break; + case 1: maybe_int_column = (java.lang.Integer)value$; break; + case 2: maybe_long_column = (java.lang.Long)value$; break; + case 3: maybe_float_column = (java.lang.Float)value$; break; + case 4: maybe_double_column = (java.lang.Double)value$; break; + case 5: maybe_binary_column = (java.nio.ByteBuffer)value$; break; + case 6: maybe_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'maybe_bool_column' field. + */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** + * Sets the value of the 'maybe_bool_column' field. + * @param value the value to set. + */ + public void setMaybeBoolColumn(java.lang.Boolean value) { + this.maybe_bool_column = value; + } + + /** + * Gets the value of the 'maybe_int_column' field. + */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** + * Sets the value of the 'maybe_int_column' field. + * @param value the value to set. + */ + public void setMaybeIntColumn(java.lang.Integer value) { + this.maybe_int_column = value; + } + + /** + * Gets the value of the 'maybe_long_column' field. + */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** + * Sets the value of the 'maybe_long_column' field. + * @param value the value to set. + */ + public void setMaybeLongColumn(java.lang.Long value) { + this.maybe_long_column = value; + } + + /** + * Gets the value of the 'maybe_float_column' field. + */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** + * Sets the value of the 'maybe_float_column' field. + * @param value the value to set. 
+ */ + public void setMaybeFloatColumn(java.lang.Float value) { + this.maybe_float_column = value; + } + + /** + * Gets the value of the 'maybe_double_column' field. + */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** + * Sets the value of the 'maybe_double_column' field. + * @param value the value to set. + */ + public void setMaybeDoubleColumn(java.lang.Double value) { + this.maybe_double_column = value; + } + + /** + * Gets the value of the 'maybe_binary_column' field. + */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** + * Sets the value of the 'maybe_binary_column' field. + * @param value the value to set. + */ + public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { + this.maybe_binary_column = value; + } + + /** + * Gets the value of the 'maybe_string_column' field. + */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** + * Sets the value of the 'maybe_string_column' field. + * @param value the value to set. + */ + public void setMaybeStringColumn(java.lang.String value) { + this.maybe_string_column = value; + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing AvroOptionalPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroOptionalPrimitives instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.lang.Boolean maybe_bool_column; + private java.lang.Integer maybe_int_column; + private java.lang.Long maybe_long_column; + private java.lang.Float maybe_float_column; + private java.lang.Double maybe_double_column; + private java.nio.ByteBuffer maybe_binary_column; + private java.lang.String maybe_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroOptionalPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } 
+ if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'maybe_bool_column' field */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** Sets the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBoolColumn(java.lang.Boolean value) { + validate(fields()[0], value); + this.maybe_bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'maybe_bool_column' field has been set */ + public boolean hasMaybeBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBoolColumn() { + maybe_bool_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_int_column' field */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** Sets the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeIntColumn(java.lang.Integer value) { + validate(fields()[1], value); + this.maybe_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_int_column' field has been set */ + public boolean hasMaybeIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeIntColumn() { + maybe_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'maybe_long_column' field */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** Sets the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeLongColumn(java.lang.Long value) { + validate(fields()[2], value); + this.maybe_long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'maybe_long_column' field has been set */ + public boolean hasMaybeLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeLongColumn() { + maybe_long_column = null; + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'maybe_float_column' field */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** Sets the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeFloatColumn(java.lang.Float value) { + validate(fields()[3], value); + this.maybe_float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'maybe_float_column' field has been set */ + public boolean hasMaybeFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder 
clearMaybeFloatColumn() { + maybe_float_column = null; + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'maybe_double_column' field */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** Sets the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeDoubleColumn(java.lang.Double value) { + validate(fields()[4], value); + this.maybe_double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'maybe_double_column' field has been set */ + public boolean hasMaybeDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeDoubleColumn() { + maybe_double_column = null; + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'maybe_binary_column' field */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** Sets the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.maybe_binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'maybe_binary_column' field has been set */ + public boolean hasMaybeBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBinaryColumn() { + maybe_binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'maybe_string_column' field */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** Sets the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.maybe_string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'maybe_string_column' field has been set */ + public boolean hasMaybeStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeStringColumn() { + maybe_string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroOptionalPrimitives build() { + try { + AvroOptionalPrimitives record = new AvroOptionalPrimitives(); + record.maybe_bool_column = fieldSetFlags()[0] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.maybe_int_column = fieldSetFlags()[1] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.maybe_long_column = fieldSetFlags()[2] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[2]); + record.maybe_float_column = fieldSetFlags()[3] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[3]); + record.maybe_double_column = fieldSetFlags()[4] ? 
this.maybe_double_column : (java.lang.Double) defaultValue(fields()[4]); + record.maybe_binary_column = fieldSetFlags()[5] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.maybe_string_column = fieldSetFlags()[6] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java new file mode 100644 index 000000000000..1c2afed16781 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java @@ -0,0 +1,461 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean bool_column; + @Deprecated public int int_column; + @Deprecated public long long_column; + @Deprecated public float float_column; + @Deprecated public double double_column; + @Deprecated public java.nio.ByteBuffer binary_column; + @Deprecated public java.lang.String string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroPrimitives() {} + + /** + * All-args constructor. + */ + public AvroPrimitives(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column) { + this.bool_column = bool_column; + this.int_column = int_column; + this.long_column = long_column; + this.float_column = float_column; + this.double_column = double_column; + this.binary_column = binary_column; + this.string_column = string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return bool_column; + case 1: return int_column; + case 2: return long_column; + case 3: return float_column; + case 4: return double_column; + case 5: return binary_column; + case 6: return string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: bool_column = (java.lang.Boolean)value$; break; + case 1: int_column = (java.lang.Integer)value$; break; + case 2: long_column = (java.lang.Long)value$; break; + case 3: float_column = (java.lang.Float)value$; break; + case 4: double_column = (java.lang.Double)value$; break; + case 5: binary_column = (java.nio.ByteBuffer)value$; break; + case 6: string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'bool_column' field. + */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** + * Sets the value of the 'bool_column' field. + * @param value the value to set. + */ + public void setBoolColumn(java.lang.Boolean value) { + this.bool_column = value; + } + + /** + * Gets the value of the 'int_column' field. + */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** + * Sets the value of the 'int_column' field. + * @param value the value to set. + */ + public void setIntColumn(java.lang.Integer value) { + this.int_column = value; + } + + /** + * Gets the value of the 'long_column' field. + */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** + * Sets the value of the 'long_column' field. + * @param value the value to set. + */ + public void setLongColumn(java.lang.Long value) { + this.long_column = value; + } + + /** + * Gets the value of the 'float_column' field. + */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** + * Sets the value of the 'float_column' field. + * @param value the value to set. + */ + public void setFloatColumn(java.lang.Float value) { + this.float_column = value; + } + + /** + * Gets the value of the 'double_column' field. + */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** + * Sets the value of the 'double_column' field. + * @param value the value to set. + */ + public void setDoubleColumn(java.lang.Double value) { + this.double_column = value; + } + + /** + * Gets the value of the 'binary_column' field. + */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** + * Sets the value of the 'binary_column' field. + * @param value the value to set. + */ + public void setBinaryColumn(java.nio.ByteBuffer value) { + this.binary_column = value; + } + + /** + * Gets the value of the 'string_column' field. + */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** + * Sets the value of the 'string_column' field. + * @param value the value to set. 
+ */ + public void setStringColumn(java.lang.String value) { + this.string_column = value; + } + + /** Creates a new AvroPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing AvroPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroPrimitives instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean bool_column; + private int int_column; + private long long_column; + private float float_column; + private double double_column; + private java.nio.ByteBuffer binary_column; + private java.lang.String string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + 
fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'bool_column' field */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** Sets the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBoolColumn(boolean value) { + validate(fields()[0], value); + this.bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'bool_column' field has been set */ + public boolean hasBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBoolColumn() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int_column' field */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** Sets the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setIntColumn(int value) { + validate(fields()[1], value); + this.int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int_column' field has been set */ + public boolean hasIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearIntColumn() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long_column' field */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** Sets the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setLongColumn(long value) { + validate(fields()[2], value); + this.long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long_column' field has been set */ + public boolean hasLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearLongColumn() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float_column' field */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** Sets the value of the 'float_column' field */ + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setFloatColumn(float value) { + validate(fields()[3], value); + this.float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float_column' field has been set */ + public boolean hasFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearFloatColumn() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double_column' field */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** Sets the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setDoubleColumn(double value) { + validate(fields()[4], value); + this.double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double_column' field has been set */ + public boolean hasDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearDoubleColumn() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'binary_column' field */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** Sets the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'binary_column' field has been set */ + public boolean hasBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBinaryColumn() { + binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'string_column' field */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** Sets the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'string_column' field has been set */ + public boolean hasStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearStringColumn() { + string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroPrimitives build() { + try { + AvroPrimitives record = new AvroPrimitives(); + record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); + record.float_column = fieldSetFlags()[3] ? 
this.float_column : (java.lang.Float) defaultValue(fields()[3]); + record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); + record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java new file mode 100644 index 000000000000..28fdc1dfb911 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -0,0 +1,17 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; + +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public interface CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]},{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]},{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]},{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"Parque
tAvroCompat\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + + @SuppressWarnings("all") + public interface Callback extends CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.execution.datasources.parquet.test.avro.CompatibilityTest.PROTOCOL; + } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java new file mode 100644 index 000000000000..a7bf4841919c --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List nested_ints_column; + @Deprecated public java.lang.String nested_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public Nested() {} + + /** + * All-args constructor. + */ + public Nested(java.util.List nested_ints_column, java.lang.String nested_string_column) { + this.nested_ints_column = nested_ints_column; + this.nested_string_column = nested_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return nested_ints_column; + case 1: return nested_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: nested_ints_column = (java.util.List)value$; break; + case 1: nested_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'nested_ints_column' field. + */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** + * Sets the value of the 'nested_ints_column' field. + * @param value the value to set. 
+ */ + public void setNestedIntsColumn(java.util.List value) { + this.nested_ints_column = value; + } + + /** + * Gets the value of the 'nested_string_column' field. + */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** + * Sets the value of the 'nested_string_column' field. + * @param value the value to set. + */ + public void setNestedStringColumn(java.lang.String value) { + this.nested_string_column = value; + } + + /** Creates a new Nested RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(); + } + + /** Creates a new Nested RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); + } + + /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); + } + + /** + * RecordBuilder for Nested instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List nested_ints_column; + private java.lang.String nested_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + super(other); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing Nested instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'nested_ints_column' field */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** Sets the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + validate(fields()[0], value); + this.nested_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 
'nested_ints_column' field has been set */ + public boolean hasNestedIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + nested_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'nested_string_column' field */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** Sets the value of the 'nested_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + validate(fields()[1], value); + this.nested_string_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'nested_string_column' field has been set */ + public boolean hasNestedStringColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'nested_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + nested_string_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public Nested build() { + try { + Nested record = new Nested(); + record.nested_ints_column = fieldSetFlags()[0] ? this.nested_ints_column : (java.util.List) defaultValue(fields()[0]); + record.nested_string_column = fieldSetFlags()[1] ? this.nested_string_column : (java.lang.String) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java new file mode 100644 index 000000000000..ef12d193f916 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -0,0 +1,250 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.Map string_to_int_column; + @Deprecated public java.util.Map> complex_column; + + /** + * 
Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetAvroCompat() {} + + /** + * All-args constructor. + */ + public ParquetAvroCompat(java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + this.strings_column = strings_column; + this.string_to_int_column = string_to_int_column; + this.complex_column = complex_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return string_to_int_column; + case 2: return complex_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: string_to_int_column = (java.util.Map)value$; break; + case 2: complex_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'string_to_int_column' field. + */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** + * Sets the value of the 'string_to_int_column' field. + * @param value the value to set. + */ + public void setStringToIntColumn(java.util.Map value) { + this.string_to_int_column = value; + } + + /** + * Gets the value of the 'complex_column' field. + */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** + * Sets the value of the 'complex_column' field. + * @param value the value to set. + */ + public void setComplexColumn(java.util.Map> value) { + this.complex_column = value; + } + + /** Creates a new ParquetAvroCompat RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** + * RecordBuilder for ParquetAvroCompat instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.Map string_to_int_column; + private java.util.Map> complex_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Creates a Builder by copying an existing ParquetAvroCompat instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'string_to_int_column' field */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** Sets the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + validate(fields()[1], value); + this.string_to_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'string_to_int_column' field has been set */ + public boolean hasStringToIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'string_to_int_column' field */ + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + string_to_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'complex_column' field */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** Sets the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + validate(fields()[2], value); + this.complex_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'complex_column' field has been set */ + public boolean hasComplexColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + complex_column = null; + fieldSetFlags()[2] = false; + return this; + } + + @Override + public ParquetAvroCompat build() { + try { + ParquetAvroCompat record = new ParquetAvroCompat(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.string_to_int_column = fieldSetFlags()[1] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[1]); + record.complex_column = fieldSetFlags()[2] ? this.complex_column : (java.util.Map>) defaultValue(fields()[2]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java new file mode 100644 index 000000000000..05fefe4cee75 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetEnum extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetEnum\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"suit\",\"type\":{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetEnum() {} + + /** + * All-args constructor. + */ + public ParquetEnum(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit) { + this.suit = suit; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. 
+ public java.lang.Object get(int field$) { + switch (field$) { + case 0: return suit; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: suit = (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'suit' field. + */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** + * Sets the value of the 'suit' field. + * @param value the value to set. + */ + public void setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + this.suit = value; + } + + /** Creates a new ParquetEnum RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing ParquetEnum instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** + * RecordBuilder for ParquetEnum instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + super(other); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ParquetEnum instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** Sets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + validate(fields()[0], value); + this.suit = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'suit' field has been set */ + public boolean hasSuit() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder clearSuit() { + suit = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ParquetEnum build() { + try { + ParquetEnum record = new ParquetEnum(); + record.suit = fieldSetFlags()[0] ? 
this.suit : (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java new file mode 100644 index 000000000000..00711a0c2a26 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum Suit { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"Suit\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index fcb8f5499cf8..7b50aad4ad49 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,21 +18,27 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.spark.sql.test.TestSQLContext$; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.*; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -43,14 +49,16 @@ public class JavaApplySchemaSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); } @After public void tearDown() { - javaCtx = null; + sqlContext.sparkContext().stop(); sqlContext = null; + javaCtx = null; } public static class Person implements Serializable { @@ -76,7 +84,7 @@ public void setAge(int age) { @Test public void applySchema() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -88,12 +96,13 @@ public void applySchema() { JavaRDD rowRDD = 
javaCtx.parallelize(personList).map( new Function<Person, Row>() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List<StructField> fields = new ArrayList<StructField>(2); + List<StructField> fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -111,7 +120,7 @@ public Row call(Person person) throws Exception { @Test public void dataFrameRDDOperations() { - List<Person> personList = new ArrayList<Person>(2); + List<Person> personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -122,27 +131,28 @@ public void dataFrameRDDOperations() { personList.add(person2); JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( - new Function<Person, Row>() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List<StructField> fields = new ArrayList<StructField>(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function<Person, Row>() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List<StructField> fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List<String> expected = new ArrayList<String>(2); + List<String> expected = new ArrayList<>(2); expected.add("Michael_29"); expected.add("Yin_28"); @@ -158,8 +168,9 @@ public void applySchemaToJSON() { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List<StructField> fields = new ArrayList<StructField>(7); - fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); + List<StructField> fields = new ArrayList<>(7); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), + true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); @@ -167,10 +178,10 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List<Row> expectedResult = new ArrayList<Row>(2); + List<Row> expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -179,7 +190,7 @@ public void applySchemaToJSON() { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 72c42f4fe376..5f9abd4999ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,51 +17,51 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import scala.collection.JavaConverters; +import scala.collection.Seq; + import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; +import org.junit.*; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; -import org.junit.*; - -import scala.collection.JavaConversions; -import scala.collection.Seq; -import scala.collection.mutable.Buffer; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; - -import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; - private transient SQLContext context; + private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - TestData$.MODULE$.testData(); - jsc = new JavaSparkContext(TestSQLContext.sparkContext()); - context = TestSQLContext$.MODULE$; + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); } @After public void tearDown() { - jsc = null; + context.sparkContext().stop(); context = null; + jsc = null; } @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } /** @@ -90,13 +90,14 @@ public void testVarargMethods() { df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); + df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); - + // Varargs with mathfunctions DataFrame df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); @@ -119,7 +120,7 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); @@ -161,17 +162,18 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. 
Seq result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } - Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + @SuppressWarnings("unchecked") + Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); + Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava())); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { @@ -179,7 +181,8 @@ public void testCreateDataFrameFromJavaBeans() { } } - private static Comparator CrosstabRowComparator = new Comparator() { + private static final Comparator crosstabRowComparator = new Comparator() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -192,24 +195,24 @@ public void testCrosstab() { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } - + @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -218,13 +221,22 @@ public void testFrequentItems() { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); + } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26..3ab4db2a035d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import 
java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public void setUp() { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 +124,8 @@ public void constructSimpleRow() { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public void constructComplexRow() { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map simpleMap = new HashMap(); + Map simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public void constructComplexRow() { List arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map, Row> complexMap = new HashMap, Row>(); + Map, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public void constructComplexRow() { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 79d92734ff37..4a78dca7fea6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,15 +20,16 @@ import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -40,12 +41,16 @@ public class JavaUDFSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); } @After public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; } @SuppressWarnings("unchecked") @@ -57,13 +62,13 @@ public void udf1Test() { sqlContext.udf().register("stringLengthTest", new UDF1() { @Override - 
public Integer call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -77,12 +82,12 @@ public void udf2Test() { sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 2706e01bd28a..9e241f20987c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -21,13 +21,14 @@ import java.io.IOException; import java.util.*; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -43,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List<Row> expected) { + private static void checkAnswer(DataFrame actual, List<Row> expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -52,8 +53,9 @@ private void checkAnswer(DataFrame actual, List<Row> expected) { @Before public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = @@ -62,7 +64,7 @@ public void setUp() throws IOException { path.delete(); } - List<String> jsonObjects = new ArrayList<String>(10); + List<String> jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -71,9 +73,16 @@ public void setUp() throws IOException { df.registerTempTable("jsonTable"); } + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + @Test public void saveAndLoad() { - Map<String, String> options = new HashMap<String, String>(); + Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -82,11 +91,11 @@ public void saveAndLoad() { @Test public void saveAndLoadWithSchema() { - Map<String, String> options = new HashMap<String, String>(); + Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List<StructField> fields = new ArrayList<StructField>(); + List<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", 
DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..cfd7889b4ac2 --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.sources.FakeSourceOne +org.apache.spark.sql.sources.FakeSourceTwo +org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet new file mode 100644 index 000000000000..41a43fa35d39 Binary files /dev/null and b/sql/core/src/test/resources/nested-array-struct.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet new file mode 100644 index 000000000000..520922f73ebb Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-int.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/old-repeated-message.parquet new file mode 100644 index 000000000000..548db9916277 Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-message.parquet differ diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet new file mode 100644 index 000000000000..213f1a90291b Binary files /dev/null and b/sql/core/src/test/resources/old-repeated.parquet differ diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet new file mode 100644 index 000000000000..837e4876eea6 Binary files /dev/null and b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet differ diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/proto-repeated-string.parquet new file mode 100644 index 000000000000..8a7eea601d01 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-string.parquet differ diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet new file mode 100644 index 000000000000..c29eee35c350 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-struct.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/proto-struct-with-array-many.parquet new file mode 100644 index 000000000000..ff9809675fc0 Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array-many.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/proto-struct-with-array.parquet new file mode 100644 index 000000000000..325a8370ad20 Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array.parquet differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index eb3e91332206..356d4ff3fa83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ 
-17,27 +17,26 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.PhysicalRDD + import scala.concurrent.duration._ -import scala.language.{implicitConversions, postfixOps} +import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.{StorageLevel, RDDBlockId} -case class BigData(s: String) - -class CachedTableSuite extends QueryTest { - TestData // Load test tables. +private case class BigData(s: String) - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql +class CachedTableSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = ctx.table(tableName).queryExecution.executedPlan + val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -47,47 +46,66 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + } + + test("withColumn doesn't invalidate cached dataframe") { + var evalCount = 0 + val myUDF = udf((x: String) => { evalCount += 1; "result" }) + val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) + df.cache() + + df.collect() + assert(evalCount === 1) + + df.collect() + assert(evalCount === 1) + + val df2 = df.withColumn("newColumn", lit(1)) + df2.collect() + + // We should not reevaluate the cached dataframe + assert(evalCount === 1) } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - ctx.cacheTable("tempTable") + sqlContext.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != ctx.cacheManager.lookupCachedData(testData)) + assert(None != sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") 
testData.select('key).registerTempTable("tempTable2") - ctx.cacheTable("tempTable1") + sqlContext.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - ctx.uncacheTable("tempTable2") + sqlContext.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -95,103 +113,103 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 1000 - ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(ctx.table("bigData").count() === 200000L) - ctx.table("bigData").unpersist(blocking = true) + sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(sqlContext.table("bigData").count() === 200000L) + sqlContext.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - ctx.table("testData").cache() - assertCached(ctx.table("testData")) - ctx.table("testData").unpersist(blocking = true) + sqlContext.table("testData").cache() + assertCached(sqlContext.table("testData")) + sqlContext.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - ctx.table("testData").cache() - ctx.table("testData").count() - ctx.table("testData").unpersist(blocking = true) - assertCached(ctx.table("testData"), 0) + sqlContext.table("testData").cache() + sqlContext.table("testData").count() + sqlContext.table("testData").unpersist(blocking = true) + assertCached(sqlContext.table("testData"), 0) } test("isCached") { - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") - assertCached(ctx.table("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + assertCached(sqlContext.table("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - ctx.uncacheTable("testData") - assert(!ctx.isCached("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + sqlContext.uncacheTable("testData") + assert(!sqlContext.isCached("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - ctx.cacheTable("testData") - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + assertCached(sqlContext.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") 
} test("read from cached table and uncache") { - ctx.cacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData")) - ctx.uncacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData"), 0) + sqlContext.uncacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - ctx.cacheTable("selectStar") + sqlContext.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - ctx.uncacheTable("selectStar") + sqlContext.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -199,7 +217,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") + assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -208,14 +226,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -223,14 +241,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -238,7 +256,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(ctx.table("testData")) + 
assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -250,7 +268,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -258,7 +276,7 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -267,50 +285,48 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - ctx.table("t1") - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + sqlContext.table("t1") + sqlContext.dropTempTable("t1") + assert( + intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - ctx.cacheTable("t1") + sqlContext.cacheTable("t1") - assert(ctx.isCached("t1")) - assert(ctx.isCached("t2")) + assert(sqlContext.isCached("t1")) + assert(sqlContext.isCached("t2")) - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) - assert(!ctx.isCached("t2")) + sqlContext.dropTempTable("t1") + assert( + intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!sqlContext.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") - ctx.clearCache() - assert(ctx.cacheManager.isEmpty) + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + sqlContext.clearCache() + assert(sqlContext.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("Clear CACHE") - assert(ctx.cacheManager.isEmpty) + assert(sqlContext.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - Accumulators.synchronized { - val accsSize = Accumulators.originals.size - ctx.cacheTable("t1") - ctx.cacheTable("t2") - assert((accsSize + 2) == Accumulators.originals.size) - } + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() @@ -319,9 +335,23 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - ctx.uncacheTable("t1") - ctx.uncacheTable("t2") + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } + + test("SPARK-10327 Cache Table is not working while 
subquery has alias in its project list") { + sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) + .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") + sqlContext.cacheTable("abc") + + val sparkPlan = sql( + """select a.key, b.key, c.key from + |abc a join abc b on a.key=b.key + |join abc c on a.key=c.key""".stripMargin).queryExecution.sparkPlan + + assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3) + assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 88bb743ab0bc..4e988f074b11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,17 +17,93 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class ColumnExpressionSuite extends QueryTest { - import org.apache.spark.sql.TestData._ +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + private lazy val booleanData = { + sqlContext.createDataFrame(sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } + + test("column names with space") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot") + + checkAnswer( + df.select(df("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select($"name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(col("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select("name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(expr("`name with space`")), + Row(1) :: Nil) + } + + test("column names with dot") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot").as("a") + + checkAnswer( + df.select(df("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select(df("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("a.`name.with.dot`")), + Row("a") :: Nil) + } test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") @@ -35,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest { assert(df.select(df("a").alias("b")).columns.head === "b") } + test("as propagates metadata") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val origCol = $"a".as("b", metadata.build()) + val newCol = origCol.as("c") + 
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + test("single explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( @@ -187,7 +271,7 @@ class ColumnExpressionSuite extends QueryTest { nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) checkAnswer( - ctx.sql("select isnull(null), isnull(1)"), + sql("select isnull(null), isnull(1)"), Row(true, false)) } @@ -197,10 +281,54 @@ class ColumnExpressionSuite extends QueryTest { nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) checkAnswer( - ctx.sql("select isnotnull(null), isnotnull('a')"), + sql("select isnotnull(null), isnotnull('a')"), Row(false, true)) } + test("isNaN") { + val testData = sqlContext.createDataFrame(sparkContext.parallelize( + Row(Double.NaN, Float.NaN) :: + Row(math.log(-1), math.log(-3).toFloat) :: + Row(null, null) :: + Row(Double.MaxValue, Float.MinValue):: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", FloatType)))) + + checkAnswer( + testData.select($"a".isNaN, $"b".isNaN), + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) + + checkAnswer( + testData.select(isNaN($"a"), isNaN($"b")), + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) + + checkAnswer( + sql("select isnan(15), isnan('invalid')"), + Row(false, false)) + } + + test("nanvl") { + val testData = sqlContext.createDataFrame(sparkContext.parallelize( + Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), + StructField("c", DoubleType), StructField("d", DoubleType), + StructField("e", FloatType), StructField("f", IntegerType)))) + + checkAnswer( + testData.select( + nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"), + nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)), + nanvl($"b", $"e"), nanvl($"e", $"f")), + Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) + ) + testData.registerTempTable("t") + checkAnswer( + sql( + "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + + " nanvl(b, e), nanvl(e, f) from t"), + Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) + ) + } + test("===") { checkAnswer( testData2.filter($"a" === 1), @@ -222,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val nullData = sqlContext.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -283,7 +411,7 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -298,26 +426,25 @@ class ColumnExpressionSuite extends QueryTest { test("in") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".in(1, 2)), + checkAnswer(df.filter($"a".isin(1, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 2)), + checkAnswer(df.filter($"a".isin(3, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 1)), + checkAnswer(df.filter($"a".isin(3, 1)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".in("y", "x")), + checkAnswer(df.filter($"b".isin("y", "x")), df.collect().toSeq.filter(r => 
r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "x")), + checkAnswer(df.filter($"b".isin("z", "x")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "y")), + checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) - } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + intercept[AnalysisException] { + df2.filter($"a".isin($"b")) + } + } test("&&") { checkAnswer( @@ -402,7 +529,7 @@ class ColumnExpressionSuite extends QueryTest { ) checkAnswer( - ctx.sql("SELECT upper('aB'), ucase('cDe')"), + sql("SELECT upper('aB'), ucase('cDe')"), Row("AB", "CDE")) } @@ -423,13 +550,13 @@ class ColumnExpressionSuite extends QueryTest { ) checkAnswer( - ctx.sql("SELECT lower('aB'), lcase('cDe')"), + sql("SELECT lower('aB'), lcase('cDe')"), Row("ab", "cde")) } test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -439,13 +566,28 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + // Make sure we have 2 partitions, each with 2 records. + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") checkAnswer( df.select(sparkPartitionId()), - Row(0) + Row(0) :: Row(0) :: Row(1) :: Row(1) :: Nil ) } + test("InputFileName") { + withTempPath { dir => + val data = sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + .head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(inputFileName()).limit(1), Row("")) + } + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, @@ -480,6 +622,7 @@ class ColumnExpressionSuite extends QueryTest { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { case project: Project => project + case tungstenProject: TungstenProject => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b26d3ab253a1..f5ef9ffd7f4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameAggregateSuite extends QueryTest with 
SharedSQLContext { + import testImplicits._ test("groupBy") { checkAnswer( @@ -68,12 +66,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { @@ -177,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest { Row(0, null)) } + test("stddev") { + val testData2ADev = math.sqrt(4/5.0) + + checkAnswer( + testData2.agg(stddev('a)), + Row(testData2ADev)) + + checkAnswer( + testData2.agg(stddev_pop('a)), + Row(math.sqrt(4/6.0))) + + checkAnswer( + testData2.agg(stddev_samp('a)), + Row(testData2ADev)) + } + + test("zero stddev") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() == 0) + + checkAnswer( + emptyTableData.agg(stddev('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_pop('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_samp('a)), + Row(null)) + } + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( @@ -190,5 +221,4 @@ class DataFrameAggregateSuite extends QueryTest { emptyTableData.agg(sumDistinct('a)), Row(null)) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala new file mode 100644 index 000000000000..09f7b507670c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). 
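The struct/array helpers this suite exercises come from org.apache.spark.sql.functions. As a minimal sketch of the pattern the tests rely on (assuming an existing sqlContext with its implicits imported, and hypothetical columns id/name), wrapping columns into a struct and resolving its fields back by dotted name looks like:

    import org.apache.spark.sql.functions._
    import sqlContext.implicits._

    val df = sqlContext.sparkContext.parallelize(Seq((1, "a"))).toDF("id", "name")
    val toUpper = udf((s: String) => s.toUpperCase)   // simple UDF, as in the tests below

    // Pack columns into a struct column, then reach its fields via dotted names.
    df.select(struct($"id", $"name").as("s"))
      .select($"s.id", toUpper($"s.name"))
      .collect()

The tests that follow apply the same idea through select, selectExpr with named_struct, and array indexing.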
+ */ +class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("UDF on struct") { + val f = udf((a: String) => a) + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(struct($"a").as("s")).select(f($"s.a")).collect() + } + + test("UDF on named_struct") { + val f = udf((a: String) => a) + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() + } + + test("UDF on array") { + val f = udf((a: String) => a) + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala deleted file mode 100644 index a4719a38de1d..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Date, Timestamp} - -class DataFrameDateTimeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - test("timestamp comparison with date strings") { - val df = Seq( - (1, Timestamp.valueOf("2015-01-01 00:00:00")), - (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) - - - checkAnswer( - df.select("t").filter($"t" >= "2014-06-01"), - Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) - } - - test("date comparison with date strings") { - val df = Seq( - (1, Date.valueOf("2015-01-01")), - (2, Date.valueOf("2014-01-01"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Date.valueOf("2014-01-01")) :: Nil) - - - checkAnswer( - df.select("t").filter($"t" >= "2015"), - Row(Date.valueOf("2015-01-01")) :: Nil) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8b53b384a22f..3a3f19af1473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
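Several of the functions newly exported in this change (sha1, sha2, crc32, sort_array, size, array_contains) are exercised both through Column expressions and through selectExpr. A hedged sketch of the two call styles the suite compares, assuming a DataFrame df with a string column a and an array column xs:

    import org.apache.spark.sql.functions._

    // Column-based form ...
    df.select(sha2(col("a"), 256), crc32(col("a")), sort_array(col("xs")), size(col("xs")))
    // ... and the equivalent SQL-expression form; the suite checks that both agree.
    df.selectExpr("sha2(a, 256)", "crc32(a)", "sort_array(xs)", "size(xs)")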
*/ -class DataFrameFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") @@ -79,27 +77,51 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row.getAs[Row](0) === Row(2, "str")) } - test("struct: must use named column expression") { - intercept[IllegalArgumentException] { - struct(col("a") * 2) - } + test("struct with column expression to be automatically named") { + val df = Seq((1, "str")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), col("b"))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Row(Row(2, "str"))) + } + + test("struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))) + } + + test("struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct(lit("v"), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", StringType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))) } test("constant functions") { checkAnswer( - testData2.select(e()).limit(1), + sql("SELECT E()"), Row(scala.math.E) ) checkAnswer( - testData2.select(pi()).limit(1), - Row(scala.math.Pi) - ) - checkAnswer( - ctx.sql("SELECT E()"), - Row(scala.math.E) - ) - checkAnswer( - ctx.sql("SELECT PI()"), + sql("SELECT PI()"), Row(scala.math.Pi) ) } @@ -129,14 +151,14 @@ class DataFrameFunctionsSuite extends QueryTest { test("nvl function") { checkAnswer( - ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), Row("x", "y", null)) } test("misc md5 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(md5($"a"), md5("b")), + df.select(md5($"a"), md5($"b")), Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) checkAnswer( @@ -144,21 +166,225 @@ class DataFrameFunctionsSuite extends QueryTest { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } - test("string length function") { + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1($"b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + + test("misc sha2 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(sha2($"a", 256), sha2($"b", 256)), + 
Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + checkAnswer( + df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + intercept[IllegalArgumentException] { + df.select(sha2($"a", 1024)) + } + } + + test("misc crc32 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(crc32($"a"), crc32($"b")), + Row(2743272264L, 2180413220L)) + + checkAnswer( + df.selectExpr("crc32(a)", "crc32(b)"), + Row(2743272264L, 2180413220L)) + } + + test("string function find_in_set") { + val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") + + checkAnswer( + df.selectExpr("find_in_set('ab', a)", "find_in_set('x', b)"), + Row(3, 0)) + } + + test("conditional function: least") { + checkAnswer( + testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), + Row(-1) + ) + checkAnswer( + sql("SELECT least(a, 2) as l from testData2 order by l"), + Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) + ) + } + + test("conditional function: greatest") { + checkAnswer( + testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1), + Row(3) + ) + checkAnswer( + sql("SELECT greatest(a, 2) as g from testData2 order by g"), + Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) + ) + } + + test("pmod") { + val intData = Seq((7, 3), (-7, 3)).toDF("a", "b") + checkAnswer( + intData.select(pmod('a, 'b)), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod('a, lit(3))), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod(lit(-7), 'b)), + Seq(Row(2), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, b)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, 3)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(-7, b)"), + Seq(Row(2), Row(2)) + ) + val doubleData = Seq((7.2, 4.1)).toDF("a", "b") checkAnswer( - nullStrings.select(strlen($"s"), strlen("s")), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l, l) - }) + doubleData.select(pmod('a, 'b)), + Seq(Row(3.1000000000000005)) // same as hive + ) + checkAnswer( + doubleData.select(pmod(lit(2), lit(Int.MaxValue))), + Seq(Row(2)) + ) + } + test("sort_array function") { + val df = Seq( + (Array[Int](2, 1, 3), Array("b", "c", "a")), + (Array[Int](), Array[String]()), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(sort_array($"a"), sort_array($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) + checkAnswer( + df.select(sort_array($"a", false), sort_array($"b", false)), + Seq( + Row(Seq(3, 2, 1), Seq("c", "b", "a")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) checkAnswer( - nullStrings.selectExpr("length(s)"), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l) - }) + df.selectExpr("sort_array(a)", "sort_array(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), + Seq( + Row(Seq(1, 2, 3), Seq("c", "b", "a")), + Row(Seq[Int](), Seq[String]()), + Row(null, null)) + ) + + val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b") + 
assert(intercept[AnalysisException] { + df2.selectExpr("sort_array(a)").collect() + }.getMessage().contains("does not support sorting array of type array")) + + val df3 = Seq(("xxx", "x")).toDF("a", "b") + assert(intercept[AnalysisException] { + df3.selectExpr("sort_array(a)").collect() + }.getMessage().contains("only supports array input")) + } + + test("array size function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "y"), + (Seq[Int](1, 2, 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size($"a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } + + test("map size function") { + val df = Seq( + (Map[Int, Int](1 -> 1, 2 -> 2), "x"), + (Map[Int, Int](), "y"), + (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size($"a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } + + test("array contains function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df.select(array_contains(df("a"), 1)), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, 1)"), + Seq(Row(true), Row(false)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df.select(array_contains(df("a"), null)) + } + intercept[AnalysisException] { + df.selectExpr("array_contains(a, null)") + } + intercept[AnalysisException] { + df.selectExpr("array_contains(null, 1)") + } + + checkAnswer( + df.selectExpr("array_contains(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true)) + ) + checkAnswer( + df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true)) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb30706a494..094efbaeadcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("RDD of tuples") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -36,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), + sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), + sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 6165764632c2..e2716d7841d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -58,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") .collect().toSeq) } @@ -93,4 +92,20 @@ class DataFrameJoinSuite extends QueryTest { left.join(right, left("key") === right("key")), Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } + + test("broadcast join hint") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + + // equijoin - should be converted into broadcast join + val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) + + // no join key -- should not be a broadcast join + val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) + + // planner should not crash without a join + broadcast(df1).queryExecution.executedPlan + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 495701d4f616..329ffb66083b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -17,21 +17,23 @@ package org.apache.spark.sql -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( ("Bob", 16, 176.5), ("Alice", null, 164.3), ("David", 60, null), + ("Nina", 25, Double.NaN), ("Amy", null, null), - (null, null, null)).toDF("name", "age", "height") + (null, null, null) + ).toDF("name", "age", "height") } test("drop") { @@ -39,12 +41,12 @@ class DataFrameNaFunctionsSuite extends QueryTest { val rows = input.collect() checkAnswer( - input.na.drop("name" :: Nil), - rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + input.na.drop("name" :: Nil).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) checkAnswer( - input.na.drop("age" :: Nil), - rows(0) :: rows(2) :: Nil) + input.na.drop("age" :: Nil).select("name"), + Row("Bob") :: Row("David") :: Row("Nina") :: Nil) 
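For context, the expectations above and below follow the DataFrameNaFunctions API; a small sketch of the calls being exercised, where input is the DataFrame produced by createDF() above (name, age, height):

    // Drop rows with a null (or, for doubles, NaN) in the listed columns.
    input.na.drop("age" :: "height" :: Nil)
    // Fill nulls per column ...
    input.na.fill(Map("age" -> 50, "height" -> 50.6))
    // ... and replace specific values in selected columns.
    input.na.replace(Seq("age", "height"), Map(16 -> 61, 164.3 -> 461.3))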
checkAnswer( input.na.drop("age" :: "height" :: Nil), @@ -67,8 +69,8 @@ class DataFrameNaFunctionsSuite extends QueryTest { val rows = input.collect() checkAnswer( - input.na.drop("all"), - rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + input.na.drop("all").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) checkAnswer( input.na.drop("any"), @@ -79,8 +81,8 @@ class DataFrameNaFunctionsSuite extends QueryTest { rows(0) :: Nil) checkAnswer( - input.na.drop("all", Seq("age", "height")), - rows(0) :: rows(1) :: rows(2) :: Nil) + input.na.drop("all", Seq("age", "height")).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Nil) } test("drop with threshold") { @@ -108,6 +110,7 @@ class DataFrameNaFunctionsSuite extends QueryTest { Row("Bob", 16, 176.5) :: Row("Alice", 50, 164.3) :: Row("David", 60, 50.6) :: + Row("Nina", 25, 50.6) :: Row("Amy", 50, 50.6) :: Row(null, 50, 50.6) :: Nil) @@ -117,17 +120,19 @@ class DataFrameNaFunctionsSuite extends QueryTest { // string checkAnswer( input.na.fill("unknown").select("name"), - Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + Row("Bob") :: Row("Alice") :: Row("David") :: + Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) // fill double with subset columns checkAnswer( - input.na.fill(50.6, "age" :: Nil), - Row("Bob", 16, 176.5) :: - Row("Alice", 50, 164.3) :: - Row("David", 60, null) :: - Row("Amy", 50, null) :: - Row(null, 50, null) :: Nil) + input.na.fill(50.6, "age" :: Nil).select("name", "age"), + Row("Bob", 16) :: + Row("Alice", 50) :: + Row("David", 60) :: + Row("Nina", 25) :: + Row("Amy", 50) :: + Row(null, 50) :: Nil) // fill string with subset columns checkAnswer( @@ -148,11 +153,11 @@ class DataFrameNaFunctionsSuite extends QueryTest { // Test Java version checkAnswer( - df.na.fill(mapAsJavaMap(Map( + df.na.fill(Map( "a" -> "test", "c" -> 1, "d" -> 2.2 - ))), + ).asJava), Row("test", null, 1, 2.2)) } @@ -164,29 +169,27 @@ class DataFrameNaFunctionsSuite extends QueryTest { 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall - )) + )).collect() - checkAnswer( - out, - Row("Bob", 61, 176.5) :: - Row("Alice", null, 461.3) :: - Row("David", 6, null) :: - Row("Amy", null, null) :: - Row(null, null, null) :: Nil) + assert(out(0) === Row("Bob", 61, 176.5)) + assert(out(1) === Row("Alice", null, 461.3)) + assert(out(2) === Row("David", 6, null)) + assert(out(3).get(2).asInstanceOf[Double].isNaN) + assert(out(4) === Row("Amy", null, null)) + assert(out(5) === Row(null, null, null)) // Replace only the age column val out1 = input.na.replace("age", Map( 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall - )) - - checkAnswer( - out1, - Row("Bob", 61, 176.5) :: - Row("Alice", null, 164.3) :: - Row("David", 6, null) :: - Row("Amy", null, null) :: - Row(null, null, null) :: Nil) + )).collect() + + assert(out1(0) === Row("Bob", 61, 176.5)) + assert(out1(1) === Row("Alice", null, 164.3)) + assert(out1(2) === Row("David", 6, null)) + assert(out1(3).get(2).asInstanceOf[Double].isNaN) + assert(out1(4) === Row("Amy", null, null)) + assert(out1(5) === Row(null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899dad7..6524abcf5e97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,17 +17,51 @@ package org.apache.spark.sql -import org.scalatest.Matchers._ +import java.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends SparkFunSuite { - - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ +class DataFrameStatSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + Seq(5, 10, 52, 73).map(Row(_)) + ) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + Seq(16, 23, 88, 100).map(Row(_)) + ) + } + + test("randomSplit") { + val n = 600 + val data = sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") @@ -65,22 +99,52 @@ class DataFrameStatSuite extends SparkFunSuite { } test("crosstab") { - val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b") + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "0") - assert(columnNames(2) === "1") - val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) - assert(rows(0).get(0).toString === "0") - assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === 0L) - assert(rows(1).get(0).toString === "1") - assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === 0L) - assert(rows(2).get(0).toString === "2") - assert(rows(2).getLong(1) === 2L) - assert(rows(2).getLong(2) === 1L) + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 until columnNames.length) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } + } + + test("special crosstab elements (., '', null, ``)") { + val data = Seq( + ("a", Double.NaN, "ho"), + (null, 2.0, "ho"), + ("a.b", Double.NegativeInfinity, ""), + ("b", Double.PositiveInfinity, "`ha`"), + ("a", 1.0, null) + ) + val df = data.toDF("1", "2", "3") + val ct1 = df.stat.crosstab("1", "2") + // column fields should be 1 + distinct elements of second column + assert(ct1.schema.fields.length === 6) + assert(ct1.collect().length === 4) + val ct2 = df.stat.crosstab("1", "3") + assert(ct2.schema.fields.length === 5) + 
assert(ct2.schema.fieldNames.contains("ha")) + assert(ct2.collect().length === 4) + val ct3 = df.stat.crosstab("3", "2") + assert(ct3.schema.fields.length === 6) + assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("Infinity")) + assert(ct3.schema.fieldNames.contains("-Infinity")) + assert(ct3.collect().length === 4) + val ct4 = df.stat.crosstab("3", "1") + assert(ct4.schema.fields.length === 5) + assert(ct4.schema.fieldNames.contains("null")) + assert(ct4.schema.fieldNames.contains("a.b")) + assert(ct4.collect().length === 4) } test("Frequent Items") { @@ -91,11 +155,37 @@ class DataFrameStatSuite extends SparkFunSuite { val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) val items = results.collect().head - items.getSeq[Int](0) should contain (1) - items.getSeq[String](1) should contain (toLetter(1)) + assert(items.getSeq[Int](0).contains(1)) + assert(items.getSeq[String](1).contains(toLetter(1))) val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) val items2 = singleColResults.collect().head - items2.getSeq[Double](0) should contain (-1.0) + assert(items2.getSeq[Double](0).contains(-1.0)) + } + + test("Frequent Items 2") { + val rows = sparkContext.parallelize(Seq.empty[Int], 4) + // this is a regression test, where when merging partitions, we omitted values with higher + // counts than those that existed in the map when the map was full. This test should also fail + // if anything like SPARK-9614 is observed once again + val df = rows.mapPartitionsWithIndex { (idx, iter) => + if (idx == 3) { // must come from one of the later merges, therefore higher partition index + Iterator("3", "3", "3", "3", "3") + } else { + Iterator("0", "1", "2", "3", "4") + } + }.toDF("a") + val results = df.stat.freqItems(Array("a"), 0.25) + val items = results.collect().head.getSeq[String](0) + assert(items.contains("3")) + assert(items.length === 1) + } + + test("sampleBy") { + val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ba1d020f22f1..c167999af580 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,41 +17,40 @@ package org.apache.spark.sql +import java.io.File + import scala.language.postfixOps +import scala.util.Random + +import org.scalatest.Matchers._ +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} - +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} -class DataFrameSuite extends QueryTest { - import org.apache.spark.sql.TestData._ - - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("analysis error should be eagerly reported") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. 
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) - } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) - } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + } + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - testData.select('nonExistentName) - - // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { + testData.select('nonExistentName) + } } test("dataframe toString") { @@ -69,21 +68,18 @@ class DataFrameSuite extends QueryTest { } test("invalid plan toString, debug mode") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - ctx.debug() - val badPlan = testData.select('badColumn) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + sqlContext.debug() - assert(badPlan.toString contains badPlan.queryExecution.toString, - "toString on bad query plans should include the query execution but was:\n" + - badPlan.toString) + val badPlan = testData.select('badColumn) - // Set the flag back to original value before this test. 
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) + } } test("access complex data") { @@ -99,8 +95,8 @@ class DataFrameSuite extends QueryTest { } test("empty data frame") { - assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(ctx.emptyDataFrame.count() === 0) + assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(sqlContext.emptyDataFrame.count() === 0) } test("head and take") { @@ -134,6 +130,21 @@ class DataFrameSuite extends QueryTest { ) } + test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains( + "Cannot explode *, explode can only be applied on a specific column.")) + + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + test("explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") @@ -160,6 +171,12 @@ class DataFrameSuite extends QueryTest { testData.collect().filter(_.getInt(0) > 90).toSeq) } + test("filterExpr using where") { + checkAnswer( + testData.where("key > 50"), + testData.collect().filter(_.getInt(0) > 50).toSeq) + } + test("repartition") { checkAnswer( testData.select('key).repartition(10).select('key), @@ -301,7 +318,7 @@ class DataFrameSuite extends QueryTest { ) } - test("call udf in SQLContext") { + test("deprecated callUdf in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") val sqlctx = df.sqlContext sqlctx.udf.register("simpleUdf", (v: Int) => v * v) @@ -310,6 +327,15 @@ class DataFrameSuite extends QueryTest { Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) } + test("callUDF in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUDF", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUDF("simpleUDF", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( @@ -321,7 +347,7 @@ class DataFrameSuite extends QueryTest { } test("replace column using withColumn") { - val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -400,23 +426,6 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("randomSplit") { - val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") - for (seed <- 1 to 5) { - val splits = data.randomSplit(Array[Double](1, 2, 3), seed) - assert(splits.length == 3, "wrong number of splits") - - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.collect().toList, "incomplete or wrong split") - - val s = splits.map(_.count()) - assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) 
- 300) < 50) // std = 12.25 - } - } - test("describe") { val describeTestData = Seq( ("Bob", 16, 176), @@ -427,7 +436,7 @@ class DataFrameSuite extends QueryTest { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) @@ -471,12 +480,53 @@ class DataFrameSuite extends QueryTest { checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) } + test("inputFiles") { + withTempDir { dir => + val df = Seq((1, 22)).toDF("a", "b") + + val parquetDir = new File(dir, "parquet").getCanonicalPath + df.write.parquet(parquetDir) + val parquetDF = sqlContext.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + + val jsonDir = new File(dir, "json").getCanonicalPath + df.write.json(jsonDir) + val jsonDF = sqlContext.read.json(jsonDir) + assert(parquetDF.inputFiles.nonEmpty) + + val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + assert(unioned === allFiles) + } + } + ignore("show") { // This test case is intended ignored, but to make sure it compiles correctly testData.select($"*").show() testData.select($"*").show(1000) } + test("showString: truncate = [true, false]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = """+---------------------+ + ||_1 | + |+---------------------+ + ||1 | + ||111111111111111111111| + |+---------------------+ + |""".stripMargin + assert(df.showString(10, false) === expectedAnswerForFalse) + val expectedAnswerForTrue = """+--------------------+ + || _1| + |+--------------------+ + || 1| + ||11111111111111111...| + |+--------------------+ + |""".stripMargin + assert(df.showString(10, true) === expectedAnswerForTrue) + } + test("showString(negative)") { val expectedAnswer = """+---+-----+ ||key|value| @@ -548,21 +598,17 @@ class DataFrameSuite extends QueryTest { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() } - test("SPARK-6899") { - val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, true) - try{ + test("SPARK-6899: type should match when using codegen") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) - } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -571,17 +617,18 @@ class DataFrameSuite extends QueryTest { assert(complexData.filter(complexData("m")("1") === 1).count() == 1) assert(complexData.filter(complexData("s")("key") === 1).count() == 1) assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1) + assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1) } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = ctx.read.json(ctx.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) 
checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = ctx.read.json(ctx.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -601,7 +648,7 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7324 dropDuplicates") { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -635,63 +682,229 @@ class DataFrameSuite extends QueryTest { Seq(Row(2, 1, 2), Row(1, 1, 1))) } - test("SPARK-7276: Project collapse for continuous select") { - var df = testData - for (i <- 1 to 5) { - df = df.select($"*") - } - - import org.apache.spark.sql.catalyst.plans.logical.Project - // make sure df have at most two Projects - val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] - assert(!p.child.isInstanceOf[Project]) - } - test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = ctx.range(0, 10, 1, 15).select("id") + val res1 = sqlContext.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = ctx.range(3, 15, 3, 2).select("id") + val res2 = sqlContext.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = ctx.range(1, -2).select("id") + val res3 = sqlContext.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = ctx.range(1, -2, -2, 6).select("id") + val res4 = sqlContext.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = ctx.range(-3, -8, -2, 1).select("id") + val res5 = sqlContext.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = ctx.range(-8, -4, 2, 1).select("id") + val res6 = sqlContext.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = ctx.range(-10, -9, -20, 1).select("id") + val res7 = sqlContext.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = ctx.range(10).select("id") + val res10 = sqlContext.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = ctx.range(-1).select("id") + val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) } + + test("SPARK-8621: support empty string column name") { + val df = Seq(Tuple1(1)).toDF("").as("t") + // We should allow empty string as column name + df.col("") + df.col("t.``") + } + + test("SPARK-8797: sort by float column 
containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } + + test("SPARK-8797: sort by double column containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } + + test("NaN is greater than all other non-NaN numeric values") { + val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Double.isNaN(maxDouble.getDouble(0))) + val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Float.isNaN(maxFloat.getFloat(0))) + } + + test("SPARK-8072: Better Exception for Duplicate Columns") { + // only one duplicate column present + val e = intercept[org.apache.spark.sql.AnalysisException] { + Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") + .write.format("parquet").save("temp") + } + assert(e.getMessage.contains("Duplicate column(s)")) + assert(e.getMessage.contains("parquet")) + assert(e.getMessage.contains("column1")) + assert(!e.getMessage.contains("column2")) + + // multiple duplicate columns present + val f = intercept[org.apache.spark.sql.AnalysisException] { + Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) + .toDF("column1", "column2", "column3", "column1", "column3") + .write.format("json").save("temp") + } + assert(f.getMessage.contains("Duplicate column(s)")) + assert(f.getMessage.contains("JSON")) + assert(f.getMessage.contains("column1")) + assert(f.getMessage.contains("column3")) + assert(!f.getMessage.contains("column2")) + } + + test("SPARK-6941: Better error message for inserting into RDD-based Table") { + withTempDir { dir => + + val tempParquetFile = new File(dir, "tmp_parquet") + val tempJsonFile = new File(dir, "tmp_json") + + val df = Seq(Tuple1(1)).toDF() + val insertion = Seq(Tuple1(2)).toDF("col") + + // pass case: parquet table (HadoopFsRelation) + df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) + val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) + pdf.registerTempTable("parquet_base") + insertion.write.insertInto("parquet_base") + + // pass case: json table (InsertableRelation) + df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) + val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) + jdf.registerTempTable("json_base") + insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") + + // error cases: insert into an RDD + df.registerTempTable("rdd_base") + val e1 = intercept[AnalysisException] { + insertion.write.insertInto("rdd_base") + } + assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + + // error case: insert into a logical plan that is not a LeafNode + val indirectDS = pdf.select("_1").filter($"_1" > 5) + indirectDS.registerTempTable("indirect_ds") + val e2 = intercept[AnalysisException] { + insertion.write.insertInto("indirect_ds") + } + assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + + // error case: insert into an OneRowRelation + new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") + val e3 = intercept[AnalysisException] { + insertion.write.insertInto("one_row") + } + 
assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + } + } + + test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + test("SPARK-8609: local DataFrame with random columns should return same value after sort") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF() + checkAnswer(df.sort(rand(33)), df.sort(rand(33))) + } + + test("SPARK-9083: sort with non-deterministic expressions") { + import org.apache.spark.util.random.XORShiftRandom + + val seed = 33 + val df = (1 to 100).map(Tuple1.apply).toDF("i") + val random = new XORShiftRandom(seed) + val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) + val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) + assert(expected === actual) + } + + test("SPARK-9323: DataFrame.orderBy should support nested column name") { + val df = sqlContext.read.json(sparkContext.makeRDD( + """{"a": {"b": 1}}""" :: Nil)) + checkAnswer(df.orderBy("a.b"), Row(Row(1))) + } + + test("SPARK-9950: correctly analyze grouping/aggregating on struct fields") { + val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b") + checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil) + } + + test("SPARK-10093: Avoid transformations on executors") { + val df = Seq((1, 1)).toDF("a", "b") + df.where($"a" === 1) + .select($"a", $"b", struct($"b")) + .orderBy("a") + .select(struct($"b")) + .collect() + } + + test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + val df = Seq(1 -> 2).toDF("i", "j") + val query = df.groupBy('i) + .agg(max('j).as("aggOrdering")) + .orderBy(sum('j)) + checkAnswer(query, Row(1, 2)) + } + + test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { + val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + (1 to 10).map(i => s"""{"id": $i}"""))) + + val df = input.select($"id", rand(0).as('r)) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala new file mode 100644 index 000000000000..7ae12a7895f7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +/** + * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode. + * + * This is here for now so I can make sure Tungsten project is tested without refactoring existing + * end-to-end test infra. In the long run this should just go away. + */ +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("test simple types") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + } + } + + test("test struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sparkContext.parallelize(Seq(Row(1, struct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) + } + } + + test("test nested struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala new file mode 100644 index 000000000000..9080c53c491a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Timestamp, Date} +import java.text.SimpleDateFormat + +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval + +class DateFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("function current_date") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) + val d2 = DateTimeUtils.fromJavaDate( + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) + } + + // This is a bad test. SPARK-9196 will fix it and re-enable it. + ignore("function current_timestamp") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + Row(true)) + assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + 0).getTime - System.currentTimeMillis()) < 5000) + } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) + + test("timestamp comparison with date strings") { + val df = Seq( + (1, Timestamp.valueOf("2015-01-01 00:00:00")), + (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2014-06-01"), + Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + } + + test("date comparison with date strings") { + val df = Seq( + (1, Date.valueOf("2015-01-01")), + (2, Date.valueOf("2014-01-01"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Date.valueOf("2014-01-01")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2015"), + Row(Date.valueOf("2015-01-01")) :: Nil) + } + + test("date format") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(date_format($"a", "y"), date_format($"b", "y"), date_format($"c", "y")), + Row("2015", "2015", "2013")) + + checkAnswer( + df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), + Row("2015", "2015", "2013")) + } + + test("year") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(year($"a"), year($"b"), year($"c")), + Row(2015, 2015, 2013)) + + checkAnswer( + df.selectExpr("year(a)", "year(b)", "year(c)"), + Row(2015, 2015, 2013)) + } + + test("quarter") { + val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + 
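+    // Both d and its formatted string fall on 2015-04-08 (second quarter), while the locally
+    // overridden ts is 2013-11-08 (fourth quarter), which is what Row(2, 2, 4) below encodes.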
+ val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(quarter($"a"), quarter($"b"), quarter($"c")), + Row(2, 2, 4)) + + checkAnswer( + df.selectExpr("quarter(a)", "quarter(b)", "quarter(c)"), + Row(2, 2, 4)) + } + + test("month") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(month($"a"), month($"b"), month($"c")), + Row(4, 4, 4)) + + checkAnswer( + df.selectExpr("month(a)", "month(b)", "month(c)"), + Row(4, 4, 4)) + } + + test("dayofmonth") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(dayofmonth($"a"), dayofmonth($"b"), dayofmonth($"c")), + Row(8, 8, 8)) + + checkAnswer( + df.selectExpr("day(a)", "day(b)", "dayofmonth(c)"), + Row(8, 8, 8)) + } + + test("dayofyear") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(dayofyear($"a"), dayofyear($"b"), dayofyear($"c")), + Row(98, 98, 98)) + + checkAnswer( + df.selectExpr("dayofyear(a)", "dayofyear(b)", "dayofyear(c)"), + Row(98, 98, 98)) + } + + test("hour") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(hour($"a"), hour($"b"), hour($"c")), + Row(0, 13, 13)) + + checkAnswer( + df.selectExpr("hour(a)", "hour(b)", "hour(c)"), + Row(0, 13, 13)) + } + + test("minute") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(minute($"a"), minute($"b"), minute($"c")), + Row(0, 10, 10)) + + checkAnswer( + df.selectExpr("minute(a)", "minute(b)", "minute(c)"), + Row(0, 10, 10)) + } + + test("second") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(second($"a"), second($"b"), second($"c")), + Row(0, 15, 15)) + + checkAnswer( + df.selectExpr("second(a)", "second(b)", "second(c)"), + Row(0, 15, 15)) + } + + test("weekofyear") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(weekofyear($"a"), weekofyear($"b"), weekofyear($"c")), + Row(15, 15, 15)) + + checkAnswer( + df.selectExpr("weekofyear(a)", "weekofyear(b)", "weekofyear(c)"), + Row(15, 15, 15)) + } + + test("function date_add") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_add(col("d"), 1)), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + checkAnswer( + df.select(date_add(col("t"), 3)), + Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05")))) + checkAnswer( + df.select(date_add(col("s"), 5)), + Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07")))) + checkAnswer( + df.select(date_add(col("ss"), 7)), + Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09")))) + + checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_ADD(d, 1)"""), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + } + + test("function date_sub") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", 
"d", "s", "ss") + checkAnswer( + df.select(date_sub(col("d"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("t"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("s"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("ss"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(lit(null), 1)).limit(1), Row(null)) + + checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_SUB(d, 1)"""), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + } + + test("time_add") { + val t1 = Timestamp.valueOf("2015-07-31 23:59:59") + val t2 = Timestamp.valueOf("2015-12-31 00:00:00") + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-12-31") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d + $i"), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2016-02-29")))) + checkAnswer( + df.selectExpr(s"t + $i"), + Seq(Row(Timestamp.valueOf("2015-10-01 00:00:01")), + Row(Timestamp.valueOf("2016-02-29 00:00:02")))) + } + + test("time_sub") { + val t1 = Timestamp.valueOf("2015-10-01 00:00:01") + val t2 = Timestamp.valueOf("2016-02-29 00:00:02") + val d1 = Date.valueOf("2015-09-30") + val d2 = Date.valueOf("2016-02-29") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d - $i"), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30")))) + checkAnswer( + df.selectExpr(s"t - $i"), + Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")), + Row(Timestamp.valueOf("2015-12-31 00:00:00")))) + } + + test("function add_months") { + val d1 = Date.valueOf("2015-08-31") + val d2 = Date.valueOf("2015-02-28") + val df = Seq((1, d1), (2, d2)).toDF("n", "d") + checkAnswer( + df.select(add_months(col("d"), 1)), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31")))) + checkAnswer( + df.selectExpr("add_months(d, -1)"), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31")))) + } + + test("function months_between") { + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-02-16") + val t1 = Timestamp.valueOf("2014-09-30 23:30:00") + val t2 = Timestamp.valueOf("2015-09-16 12:00:00") + val s1 = "2014-09-15 11:30:00" + val s2 = "2015-10-01 00:00:00" + val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") + checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + } + + test("function last_day") { + val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d") + val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t") + checkAnswer( + df1.select(last_day(col("d"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + checkAnswer( + df2.select(last_day(col("t"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + } + + test("function next_day") { + val df1 = Seq(("mon", "2015-07-23"), ("tuesday", "2015-07-20")).toDF("dow", "d") + val df2 = Seq(("th", "2015-07-23 00:11:22"), ("xx", "2015-07-24 11:22:33")).toDF("dow", "t") + checkAnswer( + 
df1.select(next_day(col("d"), "MONDAY")), + Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27")))) + checkAnswer( + df2.select(next_day(col("t"), "th")), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) + } + + test("function to_date") { + val d1 = Date.valueOf("2015-07-22") + val d2 = Date.valueOf("2015-07-01") + val t1 = Timestamp.valueOf("2015-07-22 10:00:00") + val t2 = Timestamp.valueOf("2014-12-31 23:59:59") + val s1 = "2015-07-22 10:00:00" + val s2 = "2014-12-31" + val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s") + + checkAnswer( + df.select(to_date(col("t"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.select(to_date(col("d"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.select(to_date(col("s"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + + checkAnswer( + df.selectExpr("to_date(t)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.selectExpr("to_date(d)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.selectExpr("to_date(s)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + } + + test("function trunc") { + val df = Seq( + (1, Timestamp.valueOf("2015-07-22 10:00:00")), + (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select(trunc(col("t"), "YY")), + Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) + + checkAnswer( + df.selectExpr("trunc(t, 'Month')"), + Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) + } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + + test("unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + 
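+    // Without an explicit pattern, unix_timestamp is assumed to use the default
+    // "yyyy-MM-dd HH:mm:ss" format, so the plain string column ss should produce the same
+    // epoch seconds as the timestamp column; the fractional part of ts1/ts2 is dropped by
+    // the integer division above.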
checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + } + + test("datediff") { + val df = Seq( + (Date.valueOf("2015-07-24"), Timestamp.valueOf("2015-07-24 01:00:00"), + "2015-07-23", "2015-07-23 03:00:00"), + (Date.valueOf("2015-07-25"), Timestamp.valueOf("2015-07-25 02:00:00"), + "2015-07-24", "2015-07-24 04:00:00") + ).toDF("a", "b", "c", "d") + checkAnswer(df.select(datediff(col("a"), col("b"))), Seq(Row(0), Row(0))) + checkAnswer(df.select(datediff(col("a"), col("c"))), Seq(Row(1), Row(1))) + checkAnswer(df.select(datediff(col("d"), col("b"))), Seq(Row(-1), Row(-1))) + checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) + } + + test("from_utc_timestamp") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") + ).toDF("a", "b") + checkAnswer( + df.select(from_utc_timestamp(col("a"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-23 17:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-23 17:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") + ).toDF("a", "b") + checkAnswer( + df.select(to_utc_timestamp(col("a"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-25 07:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), "PST")), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-25 07:00:00")))) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala new file mode 100644 index 000000000000..78a98798eff6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.{Row, Strategy, QueryTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.UTF8String + +case class FastOperator(output: Seq[Attribute]) extends SparkPlan { + + override protected def doExecute(): RDD[InternalRow] = { + val str = Literal("so fast").value + val row = new GenericInternalRow(Array[Any](str)) + sparkContext.parallelize(Seq(row)) + } + + override def children: Seq[SparkPlan] = Nil +} + +object TestStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Project(Seq(attr), _) if attr.name == "a" => + FastOperator(attr.toAttribute :: Nil) :: Nil + case _ => Nil + } +} + +class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("insert an extraStrategy") { + try { + sqlContext.experimental.extraStrategies = TestStrategy :: Nil + + val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") + checkAnswer( + df.select("a"), + Row("so fast")) + + checkAnswer( + df.select("a", "b"), + Row("so slow", 1)) + } finally { + sqlContext.experimental.extraStrategies = Nil + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 20390a554430..7a027e13089e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,42 +17,39 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.test.SharedSQLContext -class JoinSuite extends QueryTest with BeforeAndAfterEach { - // Ensures tables are loaded. 
- TestData +class JoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.logicalPlanToSparkQuery + setupTestData() test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j - case j: HashOuterJoin => j + case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -62,9 +59,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -78,15 +74,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[HashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -94,50 +91,83 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - 
ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + } + + test("SortMergeJoin shouldn't work on unsortable columns") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("broadcasted hash join operator selection") { - ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sqlContext.cacheManager.clearCache() + sql("CACHE TABLE testData") + for (sortMergeJoinEnabled <- Seq(true, false)) { + withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + } + } + sql("UNCACHE TABLE testData") + } - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + test("broadcasted hash outer join operator selection") { + sqlContext.cacheManager.clearCache() + sql("CACHE TABLE testData") + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) 
).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } - - ctx.sql("UNCACHE TABLE testData") + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -242,7 +272,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -256,7 +286,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -302,7 +332,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -311,7 +341,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -363,7 +393,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
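+    // A null join key never matches, so the full outer join keeps every row from both sides
+    // with l.a = null, and grouping by l.a should collapse them into the single
+    // Row(null, 10) expected below.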
checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -372,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -387,7 +417,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -402,7 +432,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -412,32 +442,31 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold + sqlContext.cacheManager.clearCache() + sql("CACHE TABLE testData") - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: @@ -445,6 +474,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala new file mode 100644 index 000000000000..045fea82e4c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("function get_json_object") { + val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") + checkAnswer( + df.selectExpr("get_json_object(a, '$.name')", "get_json_object(a, '$.age')"), + Row("alice", "5")) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660c52bf..eab0fbb196eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { + import testImplicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") @@ -33,33 +32,33 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { } after { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -67,20 +66,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter 
{ StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( + sql( "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - ctx.dropTempTable("tables") + sqlContext.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 2768d7dfc803..30289c3c1d09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,19 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} - +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest { - +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import testImplicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() @@ -69,12 +66,7 @@ class MathExpressionsSuite extends QueryTest { if (f(-1) === math.log1p(-1)) { checkAnswer( nnDoubleData.select(c('b)), - (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) - ) - } else { - checkAnswer( - nnDoubleData.select(c('b)), - (1 to 10).map(n => Row(null)) + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) ) } @@ -155,7 +147,7 @@ class MathExpressionsSuite extends QueryTest { test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) checkAnswer( - ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) ) } @@ -163,7 +155,7 @@ class MathExpressionsSuite extends QueryTest { test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) checkAnswer( - ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + sql("SELECT radians(0), radians(1), radians(1.5)"), Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) ) } @@ -175,18 +167,58 @@ class MathExpressionsSuite extends QueryTest { test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) checkAnswer( - ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), Row(0.0, 1.0, 2.0)) } + test("conv") { + val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") + checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) + checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) + checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) + checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) + checkAnswer(df.selectExpr("""conv("100", 2, 
10)"""), Row("4")) + checkAnswer(df.selectExpr("""conv("-10", 16, -10)"""), Row("-16")) + checkAnswer( + df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("-1")) // for overflow + } + test("floor") { testOneToOneMathFunction(floor, math.floor) } + test("factorial") { + val df = (0 to 5).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(factorial('a)), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + checkAnswer( + df.selectExpr("factorial(a)"), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + } + test("rint") { testOneToOneMathFunction(rint, math.rint) } + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } @@ -199,7 +231,7 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction[Double](signum, math.signum) checkAnswer( - ctx.sql("SELECT sign(10), signum(-11)"), + sql("SELECT sign(10), signum(-11)"), Row(1, -1)) } @@ -207,11 +239,34 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(pow, pow, math.pow) checkAnswer( - ctx.sql("SELECT pow(1, 2), power(2, 1)"), + sql("SELECT pow(1, 2), power(2, 1)"), Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) ) } + test("hex") { + val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") + checkAnswer(data.select(hex('a)), Seq(Row("1C"))) + checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) + checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) + checkAnswer(data.selectExpr("hex(d)"), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) + } + + test("unhex") { + val data = Seq(("1C", "737472696E67")).toDF("a", "b") + checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) + checkAnswer(data.select(unhex('b)), Row("string".getBytes)) + checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) + checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) + checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) + checkAnswer(data.selectExpr("""unhex("G123")"""), Row(null)) + } + test("hypot") { testTwoToOneMathFunction(hypot, hypot, math.hypot) } @@ -223,7 +278,7 @@ class MathExpressionsSuite extends QueryTest { test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) checkAnswer( - ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + sql("SELECT ln(0), ln(1), ln(1.5)"), Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) ) } @@ -236,6 +291,57 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("shift left") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null)) + .toDF("a", "b", "c", "d", "e", "f") + + 
checkAnswer( + df.select( + shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), + shiftLeft('f, 1)), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)", + "shiftLeft(f, 1)"), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + } + + test("shift right") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1), + shiftRight('f, 1)), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)", + "shiftRight(f, 1)"), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + } + + test("shift right unsigned") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1), + shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)", + "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + } + test("binary log") { val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") checkAnswer( @@ -267,7 +373,7 @@ class MathExpressionsSuite extends QueryTest { df.select(log2("b") + log2("a")), Row(1)) - checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } test("sqrt") { @@ -276,13 +382,13 @@ class MathExpressionsSuite extends QueryTest { df.select(sqrt("a"), sqrt("b")), Row(1.0, 2.0)) - checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) } test("negative") { checkAnswer( - ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } @@ -290,6 +396,5 @@ class MathExpressionsSuite extends QueryTest { val df = Seq((1, -1, "abc")).toDF("a", "b", "c") checkAnswer(df.selectExpr("positive(a)"), Row(1)) checkAnswer(df.selectExpr("positive(b)"), Row(-1)) - checkAnswer(df.selectExpr("positive(c)"), Row("abc")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c99283a..cada03e9ac6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation -class QueryTest extends PlanTest { +abstract class QueryTest extends PlanTest { + + protected def sqlContext: SQLContext // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) 
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -56,27 +58,36 @@ class QueryTest extends PlanTest { * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ - protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(df, expectedAnswer) match { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + val currentValue = sqlContext.conf.dataFrameEagerAnalysis + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) + val partiallyAnalzyedPlan = df.queryExecution.analyzed + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) + fail( + s""" + |Failed to analyze query: $ae + |$partiallyAnalzyedPlan + | + |${stackTraceToString(ae)} + |""".stripMargin) + } + + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => } } - protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { checkAnswer(df, Seq(expectedAnswer)) } - protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. */ @@ -151,7 +162,7 @@ object QueryTest { } def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(df, expectedAnswer.toSeq) match { + checkAnswer(df, expectedAnswer.asScala) match { case Some(errorMessage) => errorMessage case None => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index d84b57af9c88..3ba14d7602a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -19,34 +19,34 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer - import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class RowSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) - expected.update(0, 2147483647) - expected.setString(1, "this is a string") - expected.update(2, false) - expected.update(3, null) + expected.setInt(0, 2147483647) + expected.update(1, UTF8String.fromString("this is a string")) + expected.setBoolean(2, false) + expected.setNullAt(3) + val actual1 = Row(2147483647, "this is a string", false, null) - assert(expected.size === actual1.size) + assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) - 
assert(expected(3) === actual1(3)) + assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) - assert(expected.size === actual2.size) + assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) - assert(expected(3) === actual2(3)) + assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { @@ -57,7 +57,7 @@ class RowSuite extends SparkFunSuite { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) + val serializer = new SparkSqlSerializer(sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] @@ -73,4 +73,25 @@ class RowSuite extends SparkFunSuite { row.getAs[Int]("c") } } + + test("float NaN == NaN") { + val r1 = Row(Float.NaN) + val r2 = Row(Float.NaN) + assert(r1 === r2) + } + + test("double NaN == NaN") { + val r1 = Row(Double.NaN) + val r2 = Row(Double.NaN) + assert(r1 === r2) + } + + test("equals and hashCode") { + val r1 = Row("Hello") + val r2 = Row("Hello") + assert(r1 === r2) + assert(r1.hashCode() === r2.hashCode()) + val r3 = Row("World") + assert(r3.hashCode() != r1.hashCode()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 75791e9d53c2..3d2bd236ceea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,71 +17,79 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext} -class SQLConfSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(ctx.sparkContext) + val newContext = new SQLContext(sparkContext) assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - ctx.conf.clear() - assert(ctx.getAllConfs.size === 0) - - ctx.setConf(testKey, testVal) - assert(ctx.getConf(testKey) === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + // Set a conf first. + sqlContext.setConf(testKey, testVal) + // Clear the conf. + sqlContext.conf.clear() + // After clear, only overrideConfs used by unit test should be in the SQLConf. + assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + + sqlContext.setConf(testKey, testVal) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
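The Float/Double NaN tests above exist because IEEE-754 comparison never considers NaN equal to itself, so Row equality has to rely on a total ordering instead. A minimal plain-Scala sketch of that distinction (no Spark APIs assumed):

object NaNEqualitySketch {
  def main(args: Array[String]): Unit = {
    // IEEE-754 comparison: NaN is not equal to itself.
    assert(!(Double.NaN == Double.NaN))
    assert(!(Float.NaN == Float.NaN))
    // The total ordering used by java.lang.*.compare treats NaN as equal to itself,
    // which is the behaviour the Row equality tests above depend on.
    assert(java.lang.Double.compare(Double.NaN, Double.NaN) == 0)
    assert(java.lang.Float.compare(Float.NaN, Float.NaN) == 0)
  }
}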
- assert(ctx.getConf(testKey) == testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) - ctx.conf.clear() + sqlContext.conf.clear() } test("parse SQL set commands") { - ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) + sqlContext.conf.clear() + sql(s"set $testKey=$testVal") + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - ctx.sql("set some.property=20") - assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") - assert(ctx.getConf("some.property", "0") === "40") + sql("set some.property=20") + assert(sqlContext.getConf("some.property", "0") === "20") + sql("set some.property = 40") + assert(sqlContext.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") - assert(ctx.getConf(key, "0") === vs) + sql(s"set $key=$vs") + assert(sqlContext.getConf(key, "0") === vs) - ctx.sql(s"set $key=") - assert(ctx.getConf(key, "0") === "") + sql(s"set $key=") + assert(sqlContext.getConf(key, "0") === "") - ctx.conf.clear() + sqlContext.conf.clear() } test("deprecated property") { - ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.conf.numShufflePartitions === 10) + sqlContext.conf.clear() + val original = sqlContext.conf.numShufflePartitions + try{ + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(sqlContext.conf.numShufflePartitions === 10) + } finally { + sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") + } } test("invalid conf value") { - ctx.conf.clear() + sqlContext.conf.clear() val e = intercept[IllegalArgumentException] { - ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796568a4..dd88ae3700ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,32 +17,33 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext -class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(ctx) + try { + SQLContext.setLastInstantiatedContext(sqlContext) + } finally { + super.afterAll() + } } test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + val sqlContext = SQLContext.getOrCreate(sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), "SQLContext created by 
SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(ctx.sparkContext) - assert(SQLContext.getOrCreate(ctx.sparkContext) != null, + val sqlContext = new SQLContext(sparkContext) + assert(SQLContext.getOrCreate(sparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4441afd6bd81..f9981356f364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,28 +17,26 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - +import java.math.MathContext import java.sql.Timestamp +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.execution.GeneratedAggregate +import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - // Make sure the tables are loaded. 
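The deprecated-property test above restores the original setting in a finally block so a failing assertion cannot leak configuration into later tests; the withSQLConf helper used further down in these suites follows the same idea. A minimal loan-pattern sketch of that restore idiom, using an ordinary mutable map rather than SQLConf:

import scala.collection.mutable

object ConfRestoreSketch {
  private val conf = mutable.Map("spark.sql.shuffle.partitions" -> "200")

  // Apply the overrides, run the body, then restore the previous values in a
  // finally block so a failing body cannot leak settings into other tests.
  def withConf[T](overrides: (String, String)*)(body: => T): T = {
    val saved = overrides.map { case (k, _) => k -> conf.get(k) }
    overrides.foreach { case (k, v) => conf(k) = v }
    try body finally saved.foreach {
      case (k, Some(v)) => conf(k) = v
      case (k, None) => conf.remove(k)
    }
  }

  def main(args: Array[String]): Unit = {
    withConf("spark.sql.shuffle.partitions" -> "10") {
      assert(conf("spark.sql.shuffle.partitions") == "10")
    }
    assert(conf("spark.sql.shuffle.partitions") == "200")
  }
}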
- TestData +class SQLQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + setupTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") @@ -57,6 +55,32 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(queryCoalesce, Row("1") :: Nil) } + test("show functions") { + checkAnswer(sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + } + + test("describe functions") { + checkExistence(sql("describe function extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql');", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38), @@ -111,15 +135,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } + test("SPARK-8668 expr function") { + checkAnswer(Seq((1, "Bobby G.")) + .toDF("id", "name") + .select(expr("length(name)"), expr("abs(id)")), Row(8, 1)) + + checkAnswer(Seq((1, "building burrito tunnels"), (1, "major projects")) + .toDF("id", "saying") + .groupBy(expr("length(saying)")) + .count(), Row(24, 1) :: Row(14, 1) :: Nil) + } + test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -137,13 +172,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-7158 collect and take return different results") { import java.util.UUID - import org.apache.spark.sql.types._ val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - def id: () => String = () => { UUID.randomUUID().toString() } + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) - val dfWithId = df.withColumn("id", callUDF(id, StringType)) + val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) val cached = dfWithId.cache() // Trigger the cache @@ -162,7 +196,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("grouping on nested fields") { - sqlContext.read.json(sqlContext.sparkContext.parallelize( + sqlContext.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: 
Nil)) .registerTempTable("rows") @@ -181,7 +215,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-6201 IN type conversion") { sqlContext.read.json( - sqlContext.sparkContext.parallelize( + sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") @@ -190,6 +224,54 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Seq(Row("1"), Row("2"))) } + test("SPARK-8828 sum should return null if all input values are null") { + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + } + + private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + val hasGeneratedAgg = df.queryExecution.executedPlan + .collect { case _: aggregate.TungstenAggregate => true } + .nonEmpty + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -199,25 +281,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") - def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = sql(sqlText) - // First, check if we have GeneratedAggregate. - var hasGeneratedAgg = false - df.queryExecution.executedPlan.foreach { - case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case _ => - } - if (!hasGeneratedAgg) { - fail( - s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. - |${df.queryExecution.simpleString} - """.stripMargin) - } - // Then, check results. - checkAnswer(df, expectedResults) - } - try { // Just to group rows. testCodeGen( @@ -265,6 +328,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) + // STDDEV + testCodeGen( + "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) + testCodeGen( + "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", + Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. 
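The SPARK-8828 cases above pin down the SQL semantics that SUM and AVG over an all-NULL column return NULL while COUNT of that column returns 0. A small model of those semantics in plain Scala, with Option standing in for nullable values:

object NullAggregateSketch {
  def main(args: Array[String]): Unit = {
    val column: Seq[Option[Int]] = Seq(None, None, None, None)  // the "allNulls" table
    val nonNull = column.flatten
    val sum: Option[Int] = nonNull.reduceOption(_ + _)          // None, i.e. SQL NULL
    val avg: Option[Double] =
      if (nonNull.isEmpty) None else Some(nonNull.sum.toDouble / nonNull.size)
    val count: Int = nonNull.size                               // 0, never NULL
    assert(sum.isEmpty && avg.isEmpty && count == 0)
  }
}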
testCodeGen( """ @@ -285,8 +355,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", + Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -299,7 +369,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(1)) checkAnswer( sql("SELECT COALESCE(null, 1, 1.5)"), - Row(1.toDouble)) + Row(BigDecimal(1))) checkAnswer( sql("SELECT COALESCE(null, null, null)"), Row(null)) @@ -396,6 +466,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) } + test("left semi greater than predicate and equal operator") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), + Seq(Row(3, 1), Row(3, 2)) + ) + + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), + Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)) + ) + } + test("index into array of arrays") { checkAnswer( sql( @@ -413,18 +495,35 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("literal in agg grouping expressions") { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + def literalInAggTest(): Unit = { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) + } + + literalInAggTest() + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + literalInAggTest() + } } test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 6, 3) + sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(1, 3, 2, 1, 6, 3) ) } @@ -483,42 +582,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("sorting") { - val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) - sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { + sortTest() + } } test("external sorting") { - val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) - sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { + sortTest() + } } test("SPARK-6927 sorting with codegen on") { - val externalbefore = sqlContext.conf.externalSortEnabled - val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) - 
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) - try{ + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() - } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) } } test("SPARK-6927 external sorting with codegen on") { - val externalbefore = sqlContext.conf.externalSortEnabled - val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) - sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) - try { + withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() - } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) } } @@ -631,12 +716,46 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer( sql( - """ - |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 - """.stripMargin), + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } + test("count of empty table") { + withTempTable("t") { + Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t") + checkAnswer( + sql("select count(a) from t"), + Row(0)) + } + } + + test("stddev") { + checkAnswer( + sql("SELECT STDDEV(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev_pop") { + checkAnswer( + sql("SELECT STDDEV_POP(a) FROM testData2"), + Row(math.sqrt(4/6.0)) + ) + } + + test("stddev_samp") { + checkAnswer( + sql("SELECT STDDEV_SAMP(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev agg") { + checkAnswer( + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), @@ -906,21 +1025,30 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val nonexistentKey = "nonexistent" // "set" itself returns all config variables currently specified in SQLConf. 
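The stddev expectations above can be reproduced by hand: testData2's column a holds (1, 1, 2, 2, 3, 3), so the sum of squared deviations from the mean is 4, giving sqrt(4/5.0) for the sample flavours and sqrt(4/6.0) for the population flavour. A short verification sketch in plain Scala:

object StddevSketch {
  def main(args: Array[String]): Unit = {
    // testData2's "a" column holds (1, 1, 2, 2, 3, 3).
    val a = Seq(1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
    val mean = a.sum / a.size                          // 2.0
    val ssd = a.map(x => (x - mean) * (x - mean)).sum  // 4.0
    val stddevSamp = math.sqrt(ssd / (a.size - 1))     // STDDEV, STDDEV_SAMP
    val stddevPop = math.sqrt(ssd / a.size)            // STDDEV_POP
    assert(stddevSamp == math.sqrt(4 / 5.0))
    assert(stddevPop == math.sqrt(4 / 6.0))
  }
}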
- assert(sql("SET").collect().size == 0) + assert(sql("SET").collect().size === TestSQLContext.overrideConfs.size) + sql("SET").collect().foreach { row => + val key = row.getString(0) + val value = row.getString(1) + assert( + TestSQLContext.overrideConfs.contains(key), + s"$key should exist in SQLConf.") + assert( + TestSQLContext.overrideConfs(key) === value, + s"The value of $key should be ${TestSQLContext.overrideConfs(key)} instead of $value.") + } + val overrideConfs = sql("SET").collect() // "set key=val" sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(testKey, testVal) + overrideConfs ++ Seq(Row(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), - Seq( - Row(testKey, testVal), - Row(testKey + testKey, testVal + testVal)) + overrideConfs ++ Seq(Row(testKey, testVal), Row(testKey + testKey, testVal + testVal)) ) // "set key" @@ -1071,7 +1199,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { @@ -1127,19 +1256,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(0.3) + sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - sql("SELECT -0.8"), Row(-0.8) + sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - sql("SELECT .5"), Row(0.5) + sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - sql("SELECT -.18"), Row(-0.18) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } @@ -1172,11 +1301,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) checkAnswer( - sql("SELECT -5.2"), Row(-5.2) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(6.8) + sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( @@ -1256,7 +1385,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3483 Special chars in column names") { - val data = sqlContext.sparkContext.parallelize( + val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") @@ -1299,13 +1428,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) sqlContext.dropTempTable("data") sqlContext.read.json( - sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) sqlContext.dropTempTable("data") } @@ -1326,10 +1455,10 @@ class 
SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1338,7 +1467,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } @@ -1346,14 +1475,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-4699 case sensitivity SQL query") { sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") @@ -1366,14 +1495,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6145: special cases") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( - """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) - checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + sqlContext.read.json(sparkContext.makeRDD( + """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1404,6 +1533,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { """.stripMargin), Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b), max(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( sql( """ @@ -1423,6 +1562,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { |ORDER BY 
sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) + """.stripMargin), + Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) + """.stripMargin), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { @@ -1447,9 +1606,197 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempTable("t") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } + + test("SPARK-8782: ORDER BY NULL") { + withTempTable("t") { + Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") + checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) + } + } + + test("SPARK-8837: use keyword in column name") { + withTempTable("t") { + val df = Seq(1 -> "a").toDF("count", "sort") + checkAnswer(df.filter("count > 0"), Row(1, "a")) + df.registerTempTable("t") + checkAnswer(sql("select count, sort from t"), Row(1, "a")) + } + } + + test("SPARK-8753: add interval type") { + import org.apache.spark.unsafe.types.CalendarInterval + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds") + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) + withTempPath(f => { + // Currently we don't yet support saving out values of interval data type. 
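The CalendarInterval expected in the interval tests above stores the literal as a (months, microseconds) pair: "3 years -3 month" collapses to 33 months, and "7 week 123 microseconds" to seven weeks' worth of microseconds plus 123. A plain-arithmetic sketch of those constants (no Spark classes needed; names are illustrative):

object IntervalLiteralSketch {
  def main(args: Array[String]): Unit = {
    val months = 12 * 3 - 3                          // "3 years -3 month" -> 33 months
    val microsPerWeek = 7L * 24 * 3600 * 1000 * 1000 // MICROS_PER_WEEK
    val micros = 7L * microsPerWeek + 123            // "7 week 123 microseconds"
    assert(months == 33)
    assert(microsPerWeek == 604800000000L)
    assert(micros == 4233600000123L)
  }
}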
+ val e = intercept[AnalysisException] { + df.write.json(f.getCanonicalPath) + } + e.message.contains("Cannot save interval data type into external storage") + }) + + def checkIntervalParseError(s: String): Unit = { + val e = intercept[AnalysisException] { + sql(s) + } + e.message.contains("at least one time unit should be given for interval literal") + } + + checkIntervalParseError("select interval") + // Currently we don't yet support nanosecond + checkIntervalParseError("select interval 23 nanosecond") + } + + test("SPARK-8945: add and subtract expressions for interval type") { + import org.apache.spark.unsafe.types.CalendarInterval + import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) + + checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) + + checkAnswer(df.select(df("i") - new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) + + // unary minus + checkAnswer(df.select(-df("i")), + Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) + } + + test("aggregation with codegen updates peak execution memory") { + withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) + } + } + } + + test("decimal precision with multiply/division") { + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), + Row(null)) + + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + } + + test("SPARK-10215 Div of Decimal returns null") { + val d = Decimal(1.12321) + val df = Seq((d, 1)).toDF("a", "b") + + checkAnswer( + df.selectExpr("b * a / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a / b / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a + b"), + Seq(Row(BigDecimal(2.12321)))) + checkAnswer( + df.selectExpr("b * a - b"), + Seq(Row(BigDecimal(0.12321)))) + checkAnswer( + df.selectExpr("b * a * b"), + Seq(Row(d.toBigDecimal))) + } + + test("precision smaller than scale") { + checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) + checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00"))) + checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10"))) + checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01"))) + checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001"))) + checkAnswer(sql("select 
-0.01"), Row(BigDecimal("-0.01"))) + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + } + + test("external sorting updates peak execution memory") { + withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() + } + } + } + + test("SPARK-9511: error with table starting with number") { + withTempTable("1one") { + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + .toDF("num", "str") + .registerTempTable("1one") + checkAnswer(sql("select count(num) from 1one"), Row(10)) + } + } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } + + test("SPARK-10130 type coercion for IF should have children resolved first") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } + } + + test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), + Seq(Row(1), Row(1))) + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), + Seq(Row(1), Row(1))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ece3d6fdf2af..295f02f9a7b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -72,17 +72,15 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + 
assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -92,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = NullReflectData(null, null, null, null, null, null, null) Seq(data).toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -100,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = OptionalReflectData(None, None, None, None, None, None, None) Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -108,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query binary data") { Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") + val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -127,17 +125,17 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Nested(None, "abc"))) Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === - new GenericRow(Array[Any]( + assert(sql("SELECT * FROM reflectComplexData").collect().head === + Row( Seq(1, 2, 3), Seq(1, 2, null), Map(1 -> 10L, 2 -> 20L), Map(1 -> 10L, 2 -> 20L, 3 -> null), - new GenericRow(Array[Any]( + Row( Seq(10, 20, 30), Seq(10, 20, null), Map(10 -> 100L, 20 -> 200L), Map(10 -> 100L, 20 -> 200L, 30 -> null), - new GenericRow(Array[Any](null, "abc"))))))) + Row(null, "abc")))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index e55c9e460b79..ddab91862964 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.SharedSQLContext -class SerializationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + val _sqlContext = new SQLContext(sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala new file mode 100644 index 000000000000..e12e6bea3026 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.Decimal + + +class StringFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("string concat") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat($"a", $"b"), concat($"a", $"b", $"c")), + Row("ab", null)) + + checkAnswer( + df.selectExpr("concat(a, b)", "concat(a, b, c)"), + Row("ab", null)) + } + + test("string concat_ws") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat_ws("||", $"a", $"b", $"c")), + Row("a||b")) + + checkAnswer( + df.selectExpr("concat_ws('||', a, b, c)"), + Row("a||b")) + } + + test("string Levenshtein distance") { + val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") + checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1))) + checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) + } + + test("string regex_replace / regex_extract") { + val df = Seq( + ("100-200", "(\\d+)-(\\d+)", "300"), + ("100-200", "(\\d+)-(\\d+)", "400"), + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") + + checkAnswer( + df.select( + regexp_replace($"a", "(\\d+)", "num"), + regexp_extract($"a", "(\\d+)-(\\d+)", 1)), + Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil) + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection followed by a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + checkAnswer( + df.filter("isnotnull(a)").selectExpr( + "regexp_replace(a, b, c)", + "regexp_extract(a, b, 1)"), + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) + } + + test("string ascii function") { + val df = Seq(("abc", "")).toDF("a", "b") + checkAnswer( + df.select(ascii($"a"), ascii($"b")), + Row(97, 0)) + + checkAnswer( + df.selectExpr("ascii(a)", "ascii(b)"), + Row(97, 0)) + } + + test("string base64/unbase64 function") { + val bytes = Array[Byte](1, 2, 3, 4) + val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") + checkAnswer( + df.select(base64($"a"), unbase64($"b")), + Row("AQIDBA==", bytes)) + + checkAnswer( + df.selectExpr("base64(a)", "unbase64(b)"), + Row("AQIDBA==", bytes)) + } + + test("string / binary substring function") { + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. 
+ val df = Seq(("1世3", Array[Byte](1, 2, 3, 4))).toDF("a", "b") + checkAnswer(df.select(substring($"a", 1, 2)), Row("1世")) + checkAnswer(df.select(substring($"b", 2, 2)), Row(Array[Byte](2,3))) + checkAnswer(df.selectExpr("substring(a, 1, 2)"), Row("1世")) + // scalastyle:on + } + + test("string encode/decode function") { + val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") + checkAnswer( + df.select(encode($"a", "utf-8"), decode($"c", "utf-8")), + Row(bytes, "大千世界")) + + checkAnswer( + df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), + Row(bytes, "大千世界")) + // scalastyle:on + } + + test("string translate") { + val df = Seq(("translate", "")).toDF("a", "b") + checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) + checkAnswer(df.selectExpr("""translate(a, "rnlt", "")"""), Row("asae")) + } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(format_string("aa%d%s", $"b", $"c")), + Row("aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("soundex function") { + val df = Seq(("MARY", "SU")).toDF("l", "r") + checkAnswer( + df.select(soundex($"l"), soundex($"r")), Row("M600", "S000")) + + checkAnswer( + df.selectExpr("SoundEx(l)", "SoundEx(r)"), Row("M600", "S000")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", "aa")), + Row(1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string substring_index function") { + val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c") + checkAnswer( + df.select(substring_index($"a", ".", 2)), + Row("www.apache")) + checkAnswer( + df.selectExpr("substring_index(a, '.', 2)"), + Row("www.apache") + ) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select(locate("aa", $"a"), locate("aa", $"a", 1)), + Row(1, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select(lpad($"a", 1, "c"), lpad($"a", 5, "??"), rpad($"a", 1, "c"), rpad($"a", 5, "??")), + Row("h", "???hi", "h", "hi???")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select(repeat($"a", 2)), + Row("hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse($"b")), + Row("ih", "ihhh")) + + checkAnswer( + df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + 
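The Levenshtein expectations earlier in this suite (3 for kitten/sitting, 1 for frog/fog) follow from the standard dynamic-programming edit distance; a minimal reference implementation, independent of Spark:

object LevenshteinSketch {
  // Edit distance by dynamic programming: dist(i)(j) is the cost of turning the
  // first i characters of a into the first j characters of b.
  def levenshtein(a: String, b: String): Int = {
    val dist = Array.tabulate(a.length + 1, b.length + 1) { (i, j) =>
      if (i == 0) j else if (j == 0) i else 0
    }
    for (i <- 1 to a.length; j <- 1 to b.length) {
      val cost = if (a(i - 1) == b(j - 1)) 0 else 1
      dist(i)(j) = math.min(math.min(dist(i - 1)(j) + 1, dist(i)(j - 1) + 1),
        dist(i - 1)(j - 1) + cost)
    }
    dist(a.length)(b.length)
  }

  def main(args: Array[String]): Unit = {
    assert(levenshtein("kitten", "sitting") == 3)
    assert(levenshtein("frog", "fog") == 1)
  }
}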
checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length($"b")), + Row(3, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + df.selectExpr("length(c)") // int type of the argument is unacceptable + } + } + + test("initcap function") { + val df = Seq(("ab", "a B")).toDF("l", "r") + checkAnswer( + df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B")) + + checkAnswer( + df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B")) + } + + test("number format function") { + val df = sqlContext.range(1) + + checkAnswer( + df.select(format_number(lit(5L), 4)), + Row("5.0000")) + + checkAnswer( + df.select(format_number(lit(1.toByte), 4)), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.select(format_number(lit(2.toShort), 4)), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.select(format_number(lit(3.1322.toFloat), 4)), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.select(format_number(lit(4), 4)), // not convert anything + Row("4.0000")) + + checkAnswer( + df.select(format_number(lit(5L), 4)), // not convert anything + Row("5.0000")) + + checkAnswer( + df.select(format_number(lit(6.48173), 4)), // not convert anything + Row("6.4817")) + + checkAnswer( + df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + df.select(format_number(lit("aa"), 4)) // string type of the 1st argument is unacceptable + } + + intercept[AnalysisException] { + df.selectExpr("format_number(4, 6.48173)") // non-integral type 2nd argument is unacceptable + } + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection follows by a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + val df2 = Seq((5L, 4), (4L, 3), (4L, 3), (4L, 3), (3L, 2)).toDF("a", "b") + checkAnswer( + df2.filter("b>0").selectExpr("format_number(a, b)"), + Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index 520a862ea083..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.Timestamp - -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test._ - - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDF() - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDF() - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDF() - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDF() - binaryData.registerTempTable("binaryData") - - case class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDF() - testData3.registerTempTable("testData3") - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDF() - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDF() - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: - ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - arrayData.toDF().registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, 
String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.toDF().registerTempTable("mapData") - - case class StringData(s: String) - val repeatedData = - TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.toDF().registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = - TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ).toDF() - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil).toDF() - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDF() - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext - .sparkContext - .parallelize(TableName("test") :: Nil) - .toDF() - .registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.toDF().registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil).toDF() - salary.registerTempTable("salary") - - case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true) - :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) - :: Nil).toDF() - complexData.registerTempTable("complexData") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 703a34c47ec2..e0435a0dba6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -case class FunctionResult(f1: String, f2: String) +private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class UDFSuite extends QueryTest with SharedSQLContext { + import testImplicits._ 
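The "TwoArgument UDF" test below registers (_: String).length + (_: Int), which Scala's placeholder syntax expands into a two-argument function adding the string length and the integer. A tiny sketch of that expansion (plain Scala; the value name is illustrative):

object PlaceholderUdfSketch {
  def main(args: Array[String]): Unit = {
    // Expands to (s: String, n: Int) => s.length + n
    val strLenPlus = (_: String).length + (_: Int)
    assert(strLenPlus("test", 1) == 5)  // matches the "TwoArgument UDF" expectation
  }
}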
test("built-in fixed arity expressions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -51,8 +51,27 @@ class UDFSuite extends QueryTest { df.selectExpr("count(distinct a)") } + test("SPARK-8003 spark_partition_id") { + val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") + df.registerTempTable("tmp_table") + checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + sqlContext.dropTempTable("tmp_table") + } + + test("SPARK-8005 input_file_name") { + withTempPath { dir => + val data = sparkContext.parallelize(0 to 10, 2).toDF("id") + data.write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = sql("select input_file_name() from test_table").head().getString(0) + assert(answer.contains(dir.getCanonicalPath)) + assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) + sqlContext.dropTempTable("test_table") + } + } + test("error reporting for incorrect number of arguments") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -60,7 +79,7 @@ class UDFSuite extends QueryTest { } test("error reporting for undefined functions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -68,32 +87,108 @@ class UDFSuite extends QueryTest { } test("Simple UDF") { - ctx.udf.register("strLenScala", (_: String).length) - assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) + sqlContext.udf.register("strLenScala", (_: String).length) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - ctx.udf.register("random0", () => { Math.random()}) - assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) + sqlContext.udf.register("random0", () => { Math.random()}) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - ctx.udf.register("strLenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + } + + test("UDF in a WHERE") { + sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + + val df = sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("integerData") + + val result = + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } + + test("UDF in a HAVING") { + sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } + + test("UDF in a GROUP BY") { + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + 
""".stripMargin) + assert(result.count() === 2) + } + + test("UDFs everywhere") { + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) + sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) + sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) } test("struct UDF") { - ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - ctx.sql("SELECT returnStruct('test', 'test2') as ret") + sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } + + test("type coercion for udf inputs") { + sqlContext.udf.register("intExpected", (x: Int) => x) + // pass a decimal to intExpected. + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala new file mode 100644 index 000000000000..2476b10e3cf9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.io.ByteArrayOutputStream + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.MemoryAllocator +import org.apache.spark.unsafe.types.UTF8String + +class UnsafeRowSuite extends SparkFunSuite { + + test("bitset width calculation") { + assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0) + assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(32) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(64) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(65) === 16) + assert(UnsafeRow.calculateBitSetWidthInBytes(128) === 16) + } + + test("writeToStream") { + val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) + val arrayBackedUnsafeRow: UnsafeRow = + UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) + assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) + val (bytesFromArrayBackedRow, field0StringFromArrayBackedRow): (Array[Byte], String) = { + val baos = new ByteArrayOutputStream() + arrayBackedUnsafeRow.writeToStream(baos, null) + (baos.toByteArray, arrayBackedUnsafeRow.getString(0)) + } + val (bytesFromOffheapRow, field0StringFromOffheapRow): (Array[Byte], String) = { + val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) + try { + Platform.copyMemory( + arrayBackedUnsafeRow.getBaseObject, + arrayBackedUnsafeRow.getBaseOffset, + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + arrayBackedUnsafeRow.getSizeInBytes + ) + val offheapUnsafeRow: UnsafeRow = new UnsafeRow() + offheapUnsafeRow.pointTo( + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + 3, // num fields + arrayBackedUnsafeRow.getSizeInBytes + ) + assert(offheapUnsafeRow.getBaseObject === null) + val baos = new ByteArrayOutputStream() + val writeBuffer = new Array[Byte](1024) + offheapUnsafeRow.writeToStream(baos, writeBuffer) + (baos.toByteArray, offheapUnsafeRow.getString(0)) + } finally { + MemoryAllocator.UNSAFE.free(offheapRowPage) + } + } + + assert(bytesFromArrayBackedRow === bytesFromOffheapRow) + assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) + } + + test("calling getDouble() and getFloat() on null columns") { + val row = InternalRow.apply(null, null) + val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row) + assert(unsafeRow.getFloat(0) === row.getFloat(0)) + assert(unsafeRow.getDouble(1) === row.getDouble(1)) + } + + test("calling get(ordinal, datatype) on null columns") { + val row = InternalRow.apply(null) + val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row) + for (dataType <- DataTypeTestUtils.atomicTypes) { + assert(unsafeRow.get(0, dataType) === null) + } + } + + test("createFromByteArray and copyFrom") { + val row = InternalRow(1, UTF8String.fromString("abc")) + val converter = UnsafeProjection.create(Array[DataType](IntegerType, StringType)) + val unsafeRow = converter.apply(row) + + val emptyRow = UnsafeRow.createFromByteArray(64, 2) + val buffer = emptyRow.getBaseObject + + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + 
assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + // make sure we reuse the buffer. + assert(emptyRow.getBaseObject === buffer) + + // make sure we really copied the input row. + unsafeRow.setInt(0, 2) + assert(emptyRow.getInt(0) === 1) + + val longString = UTF8String.fromString((1 to 100).map(_ => "abc").reduce(_ + _)) + val row2 = InternalRow(3, longString) + val unsafeRow2 = converter.apply(row2) + + // make sure we can resize. + emptyRow.copyFrom(unsafeRow2) + assert(emptyRow.getSizeInBytes() === unsafeRow2.getSizeInBytes) + assert(emptyRow.getInt(0) === 3) + assert(emptyRow.getUTF8String(1) === longString) + // make sure we really resized. + assert(emptyRow.getBaseObject != buffer) + + // make sure we can still handle small rows after resize. + emptyRow.copyFrom(unsafeRow) + assert(emptyRow.getSizeInBytes() === unsafeRow.getSizeInBytes) + assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) + assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 45c9f06941c1..46d87843dfa4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -47,17 +48,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): ArrayData = { obj match { case features: MyDenseVector => - features.data.toSeq + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) } } override def deserialize(datum: Any): MyDenseVector = { datum match { - case data: Seq[_] => - new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + case data: ArrayData => + new MyDenseVector(data.toDoubleArray()) } } @@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private lazy val pointsRDD = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), @@ -91,10 +90,10 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - ctx.sql("SELECT testType(features) from points"), + sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } @@ -138,4 +137,30 @@ class UserDefinedTypeSuite extends QueryTest { val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) assert(actual.iterator.toSet === set.iterator.toSet) } + + test("UDTs with JSON") 
{ + val data = Seq( + "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}", + "{\"id\":2,\"vec\":[2.25,4.5,8.75]}" + ) + val schema = StructType(Seq( + StructField("id", IntegerType, false), + StructField("vec", new MyDenseVectorUDT, false) + )) + + val stringRDD = sparkContext.parallelize(data) + val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) + checkAnswer( + jsonRDD, + Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: + Row(2, new MyDenseVector(Array(2.25, 4.5, 8.75))) :: + Nil + ) + } + + test("SPARK-10472 UserDefinedType.typeName") { + assert(IntegerType.typeName === "integer") + assert(new MyDenseVectorUDT().typeName === "mydensevector") + assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 1f37455dd0bc..d0430d2a60e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -18,54 +18,93 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.InternalRow +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, - InternalRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], - FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) + createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, + createRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) + testDecimalColumnStats(createRow(null, null, 0)) + + def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], 
columnType: NativeColumnType[T], - initialStatistics: InternalRow): Unit = { + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } test(s"$columnStatsName: non-empty") { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType]) + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] + val stats = columnStats.collectedStatistics + + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } + + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( + initialStatistics: GenericInternalRow): Unit = { + + val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName + val columnType = FIXED_DECIMAL(15, 10) + + test(s"$columnStatsName: empty") { + val columnStats = new FixedDecimalColumnStats(15, 10) + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName: non-empty") { + import org.apache.spark.sql.columnar.ColumnarTestUtils._ + + val columnStats = new FixedDecimalColumnStats(15, 10) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) - assertResult(10, "Wrong null count")(stats(2)) - assertResult(20, "Wrong row count")(stats(3)) - assertResult(stats(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 6daddfb2c480..8f024690efd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -32,13 
+32,15 @@ import org.apache.spark.unsafe.types.UTF8String class ColumnTypeSuite extends SparkFunSuite with Logging { - val DEFAULT_BUFFER_SIZE = 512 + private val DEFAULT_BUFFER_SIZE = 512 + private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType)) test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8, - BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, + LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, + MAP_GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -48,8 +50,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } test("actualSize") { - def checkActualSize[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def checkActualSize[JvmType]( + columnType: ColumnType[JvmType], value: JvmType, expected: Int): Unit = { @@ -60,27 +62,24 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(DATE, Int.MaxValue, 4) checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(TIMESTAMP, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - checkActualSize(BOOLEAN, true, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, 0L, 8) - - val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) + checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) + checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } - testNativeColumnType[BooleanType.type]( - BOOLEAN, + testNativeColumnType(BOOLEAN)( (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -88,18 +87,23 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { buffer.get() == 1 }) - testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + testNativeColumnType(BYTE)(_.put(_), _.get) + + testNativeColumnType(SHORT)(_.putShort(_), _.getShort) + + testNativeColumnType(INT)(_.putInt(_), _.getInt) + + testNativeColumnType(DATE)(_.putInt(_), _.getInt) - testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + testNativeColumnType(LONG)(_.putLong(_), _.getLong) - testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - testNativeColumnType[DecimalType]( - FIXED_DECIMAL(15, 10), + testNativeColumnType(FIXED_DECIMAL(15, 10))( (buffer: ByteBuffer, decimal: Decimal) => { 
buffer.putLong(decimal.toUnscaledLong) }, @@ -107,10 +111,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { Decimal(buffer.getLong(), 15, 10) }) - testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) - testNativeColumnType[StringType.type]( - STRING, + testNativeColumnType(STRING)( (buffer: ByteBuffer, string: UTF8String) => { val bytes = string.getBytes buffer.putInt(bytes.length) @@ -123,7 +125,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { UTF8String.fromBytes(bytes) }) - testColumnType[BinaryType.type, Array[Byte]]( + testColumnType[Array[Byte]]( BINARY, (buffer: ByteBuffer, bytes: Array[Byte]) => { buffer.putInt(bytes.length).put(bytes) @@ -140,7 +142,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val obj = Map(1 -> "spark", 2 -> "sql") val serializedObj = SparkSqlSerializer.serialize(obj) - GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) + MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) buffer.rewind() val length = buffer.getInt() @@ -157,7 +159,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(obj, "Deserialized object didn't equal to the original object") { buffer.rewind() - SparkSqlSerializer.deserialize(GENERIC.extract(buffer)) + SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer)) } } @@ -170,7 +172,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val obj = CustomClass(Int.MaxValue, Long.MaxValue) val serializedObj = serializer.serialize(obj).array() - GENERIC.append(serializer.serialize(obj).array(), buffer) + MAP_GENERIC.append(serializer.serialize(obj).array(), buffer) buffer.rewind() val length = buffer.getInt @@ -192,20 +194,20 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(obj, "Custom deserialized object didn't equal the original object") { buffer.rewind() - serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer))) } } def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T], - putter: (ByteBuffer, T#InternalType) => Unit, + columnType: NativeColumnType[T]) + (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { - testColumnType[T, T#InternalType](columnType, putter, getter) + testColumnType[T#InternalType](columnType, putter, getter) } - def testColumnType[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def testColumnType[JvmType]( + columnType: ColumnType[JvmType], putter: (ByteBuffer, JvmType) => Unit, getter: (ByteBuffer) => JvmType): Unit = { @@ -262,7 +264,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - assertResult(GENERIC) { + assertResult(GENERIC(DecimalType(19, 0))) { ColumnType(DecimalType(19, 0)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 7c86eae3f77f..79bb7d072feb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -31,7 +31,7 @@ object ColumnarTestUtils { row } - def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = { + def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) @@ -39,18 +39,18 @@ object 
ColumnarTestUtils { } (columnType match { + case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() + case DATE => Random.nextInt() case LONG => Random.nextLong() + case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) - case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) - case DATE => Random.nextInt() - case TIMESTAMP => Random.nextLong() + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) @@ -58,15 +58,15 @@ object ColumnarTestUtils { } def makeRandomValues( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) - def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = { + def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } - def makeUniqueRandomValues[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def makeUniqueRandomValues[JvmType]( + columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => @@ -75,10 +75,10 @@ object ColumnarTestUtils { } def makeRandomRow( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) - def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = { + def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericMutableRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 01bc23277fa8..cd3644eb9c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,21 +19,19 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest { - // Make sure the tables are loaded. 
- TestData +class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.{logicalPlanToSparkQuery, sql} + setupTestData() test("simple columnar query") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -41,16 +39,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - ctx.cacheTable("sizeTst") + sqlContext.cacheTable("sizeTst") assert( - ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - ctx.conf.autoBroadcastJoinThreshold) + sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + sqlContext.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -59,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -71,7 +69,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("repeatedData") + sqlContext.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -83,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("nullableRepeatedData") + sqlContext.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -98,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - ctx.cacheTable("timestamps") + sqlContext.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -110,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("withEmptyParts") + sqlContext.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -148,7 +146,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct) val fields = dataTypes.zipWithIndex.map { case (dataType, 
index) => @@ -159,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Create a RDD for the schema val rdd = - ctx.sparkContext.parallelize((1 to 100), 10).map { i => + sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -179,18 +177,39 @@ class InMemoryColumnarQuerySuite extends QueryTest { (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan + sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - ctx.isCached("InMemoryCache_different_data_types"), + sqlContext.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - ctx.table("InMemoryCache_different_data_types").collect()) - ctx.dropTempTable("InMemoryCache_different_data_types") + sqlContext.table("InMemoryCache_different_data_types").collect()) + sqlContext.dropTempTable("InMemoryCache_different_data_types") + } + + test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { + val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + val cached = df.cache() + // count triggers the caching action. It should not throw. + cached.count() + + // Make sure, the DataFrame is indeed cached. + assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + + // Check result. + checkAnswer( + cached, + sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + ) + + // Drop the cache. 
+ cached.unpersist() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 2a6e0c376551..f4f6c7649bfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -21,17 +21,17 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} -class TestNullableColumnAccessor[T <: DataType, JvmType]( +class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, - columnType: ColumnType[T, JvmType]) + columnType: ColumnType[JvmType]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) - : TestNullableColumnAccessor[T, JvmType] = { + def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) + : TestNullableColumnAccessor[JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) @@ -42,14 +42,14 @@ class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + .foreach { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType]): Unit = { + def testNullableColumnAccessor[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) @@ -75,7 +75,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite { (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row(0) === randomRow(0)) + assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index cb4e9f1eb7f4..241d09ea205e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) +class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) + extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) - : TestNullableColumnBuilder[T, JvmType] = { + def apply[JvmType](columnType: ColumnType[JvmType], 
initialSize: Int = 0) + : TestNullableColumnBuilder[JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder @@ -38,14 +38,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + .foreach { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType]): Unit = { + def testNullableColumnBuilder[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -92,13 +92,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite { // For non-null values (0 until 4).foreach { _ => - val actual = if (columnType == GENERIC) { - SparkSqlSerializer.deserialize[Any](GENERIC.extract(buffer)) + val actual = if (columnType.isInstanceOf[GENERIC]) { + SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) } else { columnType.extract(buffer) } - assert(actual === randomRow(0), "Extracted value didn't equal to the original one") + assert(actual === randomRow.get(0, columnType.dataType), + "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2c0879927a12..6b7401464f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,46 +17,43 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} - import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ - private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { + super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => + val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators - 
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - } - - before { - ctx.cacheTable("pruningData") - } - - after { - ctx.uncacheTable("pruningData") + try { + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + sqlContext.uncacheTable("pruningData") + } finally { + super.afterAll() + } } // Comparisons @@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = ctx.sql(query) + val df = sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index f606e2133bed..9a2948c59ba4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -33,7 +33,7 @@ class BooleanBitSetSuite extends SparkFunSuite { val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_(0)) + val values = rows.map(_.getBoolean(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala new file mode 100644 index 000000000000..911d12e93e50 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.test.SharedSQLContext + +class ExchangeSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + test("shuffling UnsafeRows in exchange") { + val input = (1 to 1000).map(Tuple1.apply) + checkAnswer( + input.toDF(), + plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + input.map(Row.fromTuple) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5854ab48db55..cafa1d515478 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,20 +17,43 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, SQLConf, execution} -class PlannerSuite extends SparkFunSuite { +class PlannerSuite extends SharedSQLContext { + import testImplicits._ + + setupTestData() + + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val planner = sqlContext.planner + import planner._ + val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val planned = + plannedOption.getOrElse( + fail(s"Could query play aggregation query $query. Is it an aggregation query?")) + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } + + // For the new aggregation code path, there will be three aggregate operator for + // distinct aggregations. 
+ assert( + aggregations.size == 2 || aggregations.size == 3, + s"The plan of query $query does not have partial aggregations.") + } + test("unions are collapsed") { + val planner = sqlContext.planner + import planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -42,52 +65,44 @@ class PlannerSuite extends SparkFunSuite { test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed - val planned = HashAggregation(query).head - val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - - assert(aggregations.size === 2) + testPartialAggregationPlan(query) } test("count distinct is partially aggregated") { val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed - val planned = HashAggregation(query) - assert(planned.nonEmpty) + testPartialAggregationPlan(query) } test("mixed aggregates are partially aggregated") { val query = testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed - val planned = HashAggregation(query) - assert(planned.nonEmpty) + testPartialAggregationPlan(query) } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) - val fields = fieldTypes.zipWithIndex.map { - case (dataType, index) => StructField(s"c${index}", dataType, true) - } :+ StructField("key", IntegerType, true) - val schema = StructType(fields) - val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - createDataFrame(rowRDD, schema).registerTempTable("testLimit") - - val planned = sql( - """ - |SELECT l.a, l.b - |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) - """.stripMargin).queryExecution.executedPlan - - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - - dropTempTable("testLimit") - } + def checkPlan(fieldTypes: Seq[DataType]): Unit = { + withTempTable("testLimit") { + val fields = fieldTypes.zipWithIndex.map { + case (dataType, index) => StructField(s"c${index}", dataType, true) + } :+ StructField("key", IntegerType, true) + val schema = StructType(fields) + val row = Row.fromSeq(Seq.fill(fields.size)(null)) + val rowRDD = sparkContext.parallelize(row :: Nil) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + + val planned = sql( + """ + |SELECT l.a, l.b + |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) + """.stripMargin).queryExecution.executedPlan - val origThreshold = conf.autoBroadcastJoinThreshold + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + } + } val simpleTypes = NullType :: @@ -99,13 +114,15 @@ class PlannerSuite extends SparkFunSuite { FloatType :: 
DoubleType :: DecimalType(10, 5) :: - DecimalType.Unlimited :: + DecimalType.SYSTEM_DEFAULT :: DateType :: TimestampType :: StringType :: BinaryType :: Nil - checkPlan(simpleTypes, newThreshold = 16434) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "16434") { + checkPlan(simpleTypes) + } val complexTypes = ArrayType(DoubleType, true) :: @@ -117,28 +134,237 @@ class PlannerSuite extends SparkFunSuite { StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false))) :: Nil - checkPlan(complexTypes, newThreshold = 901617) - - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "901617") { + checkPlan(complexTypes) + } } test("InMemoryRelation statistics propagation") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "81920") { + withTempTable("tiny") { + testData.limit(3).registerTempTable("tiny") + sql("CACHE TABLE tiny") + + val a = testData.as("a") + val b = sqlContext.table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan - testData.limit(3).registerTempTable("tiny") - sql("CACHE TABLE tiny") + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - val a = testData.as("a") - val b = table("tiny").as("b") - val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + sqlContext.clearCache() + } + } + } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + test("efficient limit -> project -> sort") { + { + val query = + testData.select('key, 'value).sort('key).limit(2).logicalPlan + val planned = sqlContext.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) + } - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + { + // We need to make sure TakeOrderedAndProject's output is correct when we push a project + // into it. 
+ val query = + testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan + val planned = sqlContext.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) + } } + + test("PartitioningCollection") { + withTempTable("normal", "small", "tiny") { + testData.registerTempTable("normal") + testData.limit(10).registerTempTable("small") + testData.limit(3).registerTempTable("tiny") + + // Disable broadcast join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + { + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = small.key) + | JOIN tiny ON (small.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + { + // This second query joins on different keys: + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = small.key) + | JOIN tiny ON (normal.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + } + } + } + + // --- Unit tests of EnsureRequirements --------------------------------------------------------- + + // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, + // there two dimensions that need to be considered: are the child partitionings compatible and + // do they satisfy the distribution requirements? As a result, we need at least four test cases. + + private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { + if (outputPlan.children.length > 1 + && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { + val childPartitionings = outputPlan.children.map(_.outputPartitioning) + if (!Partitioning.allCompatible(childPartitionings)) { + fail(s"Partitionings are not compatible: $childPartitionings") + } + } + outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { + case (child, requiredDist) => + assert(child.outputPartitioning.satisfies(requiredDist), + s"$child output partitioning does not satisfy $requiredDist:\n$outputPlan") + } + } + + test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { + // Consider an operator that requires inputs that are clustered by two expressions (e.g. 
+ // sort merge join where there are multiple columns in the equi-join condition) + val clusteringA = Literal(1) :: Nil + val clusteringB = Literal(2) :: Nil + val distribution = ClusteredDistribution(clusteringA ++ clusteringB) + // Say that the left and right inputs are each partitioned by _one_ of the two join columns: + val leftPartitioning = HashPartitioning(clusteringA, 1) + val rightPartitioning = HashPartitioning(clusteringB, 1) + // Individually, each input's partitioning satisfies the clustering distribution: + assert(leftPartitioning.satisfies(distribution)) + assert(rightPartitioning.satisfies(distribution)) + // However, these partitionings are not compatible with each other, so we still need to + // repartition both inputs prior to performing the join: + assert(!leftPartitioning.compatibleWith(rightPartitioning)) + assert(!rightPartitioning.compatibleWith(leftPartitioning)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = leftPartitioning), + DummySparkPlan(outputPartitioning = rightPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with child partitionings with different numbers of output partitions") { + // This is similar to the previous test, except it checks that partitionings are not compatible + // unless they produce the same number of partitions. + val clustering = Literal(1) :: Nil + val distribution = ClusteredDistribution(clustering) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 1)), + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 2)) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + } + + test("EnsureRequirements with compatible child partitionings that do not satisfy distribution") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // The left and right inputs have compatible partitionings but they do not satisfy the + // distribution because they are clustered on different columns. Thus, we need to shuffle. + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with compatible child partitionings that satisfy distribution") { + // In this case, all requirements are satisfied and no exchange should be added. 
+ val distribution = ClusteredDistribution(Literal(1) :: Nil) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"Exchange should not have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-9703 + test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") { + // Consider an operator that imposes both output distribution and ordering requirements on its + // children, such as sort sort merge join. If the distribution requirements are satisfied but + // the output ordering requirements are unsatisfied, then the planner should only add sorts and + // should not need to add additional shuffles / exchanges. + val outputOrdering = Seq(SortOrder(Literal(1), Ascending)) + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = SinglePartition), + DummySparkPlan(outputPartitioning = SinglePartition) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(outputOrdering, outputOrdering) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"No Exchanges should have been added:\n$outputPlan") + } + } + + // --------------------------------------------------------------------------------------------- +} + +// Used for unit-testing EnsureRequirements +private case class DummySparkPlan( + override val children: Seq[SparkPlan] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, + override val outputPartitioning: Partitioning = UnknownPartitioning(0), + override val requiredChildDistribution: Seq[Distribution] = Nil, + override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil + ) extends SparkPlan { + override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala new file mode 100644 index 000000000000..4492e37ad01f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { + + private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { + case c: ConvertToUnsafe => c + case c: ConvertToSafe => c + } + + private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + assert(!outputsSafe.outputsUnsafeRows) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + assert(outputsUnsafe.outputsUnsafeRows) + + test("planner should insert unsafe->safe conversions when required") { + val plan = Limit(10, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) + } + + test("filter can process unsafe rows") { + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("filter can process safe rows") { + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).isEmpty) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("execute() fails an assertion if inputs rows are of different formats") { + val e = intercept[AssertionError] { + Union(Seq(outputsSafe, outputsUnsafe)).execute() + } + assert(e.getMessage.contains("format")) + } + + test("union requires all of its input rows' formats to agree") { + val plan = Union(Seq(outputsSafe, outputsUnsafe)) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("union can process safe rows") { + val plan = Union(Seq(outputsSafe, outputsSafe)) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("union can process unsafe rows") { + val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("round trip with ConvertToUnsafe and ConvertToSafe") { + val input = Seq(("hello", 1), ("world", 2)) + checkAnswer( + sqlContext.createDataFrame(input), + plan => ConvertToSafe(ConvertToUnsafe(plan)), + input.map(Row.fromTuple) + ) + } + + test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { + SparkPlan.currentContext.set(sqlContext) + val schema = ArrayType(StringType) + val rows = (1 to 100).map { i => + InternalRow(new 
GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) + } + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + + val plan = + DummyPlan( + ConvertToSafe( + ConvertToUnsafe(relation))) + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) + } +} + +case class DummyPlan(child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some + // values gotten from the incoming rows. + // we cache all strings here to make sure we have deep copied UTF8String inside incoming + // safe InternalRow. + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] + iter.foreach { row => + strings += row.getArray(0).getUTF8String(0) + } + strings.map(InternalRow(_)).iterator + } + } + + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 000000000000..63639681ef80 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. 
+ sc.getLocalProperties + + // Set up a thread that executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to child threads. + */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a1e3ca11b1ad..3073d492e613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext -class SortSuite extends SparkPlanTest { +class SortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort.
@@ -33,12 +36,14 @@ class SortSuite extends SparkPlanTest { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), - input.sorted) + ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), + sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), - input.sortBy(t => (t._2, t._1))) + ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), + sortAnswers = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 13f3be8ca28d..3d218f01c9ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -22,26 +22,16 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite - +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.BoundReference -import org.apache.spark.sql.catalyst.util._ - -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame} +import org.apache.spark.sql.test.SQLTestUtils /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -class SparkPlanTest extends SparkFunSuite { - - /** - * Creates a DataFrame from a local Seq of Product. - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - TestSQLContext.implicits.localSeqToDataFrameHolder(data) - } +private[sql] abstract class SparkPlanTest extends SparkFunSuite { + protected def sqlContext: SQLContext /** * Runs the plan and makes sure the answer matches the expected result. @@ -49,30 +39,83 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[Row]): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + input :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans.head), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. 
+ */ + protected def checkAnswer2( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to + * instantiate the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def doCheckAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } } /** - * Runs the plan and makes sure the answer matches the expected result. + * Runs the plan and makes sure the answer matches the result produced by a reference plan. * @param input the input data to be used. * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ - protected def checkAnswer[A <: Product : TypeTag]( + protected def checkThatPlansAgree( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[A]): Unit = { - val expectedRows = expectedAnswer.map(Row.fromTuple) - SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean = true): Unit = { + SparkPlanTest.checkAnswer( + input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -85,54 +128,87 @@ class SparkPlanTest extends SparkFunSuite { object SparkPlanTest { /** - * Runs the plan and makes sure the answer matches the expected result. + * Runs the plan and makes sure the answer matches the result produced by a reference plan. * @param input the input data to be used. * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. 
*/ def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[Row]): Option[String] = { + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) + val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = outputPlan transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { - case (a, i) => - (a.name, BoundReference(i, a.dataType, a.nullable)) - }.toMap - - plan.transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } + val expectedAnswer: Seq[Row] = try { + executePlan(expectedOutputPlan, sqlContext) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan to calculate expected answer: + | $expectedOutputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) } - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - // This function is copied from Catalyst's QueryTest - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } - converted.sortBy(_.toString()) + val actualAnswer: Seq[Row] = try { + executePlan(outputPlan, sqlContext) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match. + | Actual result Spark plan: + | $outputPlan + | Expected result Spark plan: + | $expectedOutputPlan + | $errorMessage + """.stripMargin } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. 
+ */ + def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row], + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { + + val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - resolvedPlan.executeCollect().toSeq + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -146,22 +222,30 @@ object SparkPlanTest { return Some(errorMessage) } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - | Results do not match for Spark plan: - | $outputPlan - | == Results == - | ${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) + SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match for Spark plan: + | $outputPlan + | $errorMessage + """.stripMargin } + } - None + private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = sqlContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) + resolvedPlan.executeCollect().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala deleted file mode 100644 index 8631e247c6c0..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import java.sql.{Timestamp, Date} - -import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.Serializer -import org.apache.spark.{ShuffleDependency, SparkFunSuite} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row -import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} - -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { - // Make sure that we will not use serializer2 for unsupported data types. - def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { - val testName = - s"${if (dataType == null) null else dataType.toString} is " + - s"${if (isSupported) "supported" else "unsupported"}" - - test(testName) { - assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) - } - } - - checkSupported(null, isSupported = true) - checkSupported(NullType, isSupported = true) - checkSupported(BooleanType, isSupported = true) - checkSupported(ByteType, isSupported = true) - checkSupported(ShortType, isSupported = true) - checkSupported(IntegerType, isSupported = true) - checkSupported(LongType, isSupported = true) - checkSupported(FloatType, isSupported = true) - checkSupported(DoubleType, isSupported = true) - checkSupported(DateType, isSupported = true) - checkSupported(TimestampType, isSupported = true) - checkSupported(StringType, isSupported = true) - checkSupported(BinaryType, isSupported = true) - checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.Unlimited, isSupported = true) - - // For now, ArrayType, MapType, and StructType are not supported. - checkSupported(ArrayType(DoubleType, true), isSupported = false) - checkSupported(ArrayType(StringType, false), isSupported = false) - checkSupported(MapType(IntegerType, StringType, true), isSupported = false) - checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) - checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) - // UDTs are not supported right now. - checkSupported(new MyDenseVectorUDT, isSupported = false) -} - -abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { - var allColumns: String = _ - val serializerClass: Class[Serializer] = - classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] - var numShufflePartitions: Int = _ - var useSerializer2: Boolean = _ - - protected lazy val ctx = TestSQLContext - - override def beforeAll(): Unit = { - numShufflePartitions = ctx.conf.numShufflePartitions - useSerializer2 = ctx.conf.useSqlSerializer2 - - ctx.sql("set spark.sql.useSerializer2=true") - - val supportedTypes = - Seq(StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), - DateType, TimestampType) - - val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, true) - } - allColumns = fields.map(_.name).mkString(",") - val schema = StructType(fields) - - // Create a RDD with all data types supported by SparkSqlSerializer2. 
- val rdd = - ctx.sparkContext.parallelize((1 to 1000), 10).map { i => - Row( - s"str${i}: test serializer2.", - s"binary${i}: test serializer2.".getBytes("UTF-8"), - null, - i % 2 == 0, - i.toByte, - i.toShort, - i, - Long.MaxValue - i.toLong, - (i + 0.25).toFloat, - (i + 0.75), - BigDecimal(Long.MaxValue.toString + ".12345"), - new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), - new Date(i), - new Timestamp(i)) - } - - ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") - - super.beforeAll() - } - - override def afterAll(): Unit = { - ctx.dropTempTable("shuffle") - ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") - super.afterAll() - } - - def checkSerializer[T <: Serializer]( - executedPlan: SparkPlan, - expectedSerializerClass: Class[T]): Unit = { - executedPlan.foreach { - case exchange: Exchange => - val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] - val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - val serializerNotSetMessage = - s"Expected $expectedSerializerClass as the serializer of Exchange. " + - s"However, the serializer was not set." - val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) - assert(serializer.getClass === expectedSerializerClass) - case _ => // Ignore other nodes. - } - } - - test("key schema and value schema are not nulls") { - val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - ctx.table("shuffle").collect()) - } - - test("key schema is null") { - val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = ctx.sql(s"SELECT $aggregations FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) - } - - test("value schema is null") { - val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert(df.map(r => r.getString(0)).collect().toSeq === - ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) - } - - test("no map output field") { - val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - } -} - -/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ -class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { - override def beforeAll(): Unit = { - super.beforeAll() - // Sort merge will not be triggered. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") - } -} - -/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ -class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { - - override def beforeAll(): Unit = { - super.beforeAll() - // To trigger the sort merge. 
- val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala new file mode 100644 index 000000000000..48c3938ff87b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.shuffle.ShuffleMemoryManager + +/** + * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. + */ +class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1024 * 1024) { + private var oom = false + + override def tryToAcquire(numBytes: Long): Long = { + if (oom) { + oom = false + 0 + } else { + // Uncomment the following to trace memory allocations. + // println(s"tryToAcquire $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + val acquired = super.tryToAcquire(numBytes) + acquired + } + } + + override def release(numBytes: Long): Unit = { + // Uncomment the following to trace memory releases. + // println(s"release $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + super.release(numBytes) + } + + def markAsOutOfMemory(): Unit = { + oom = true + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala new file mode 100644 index 000000000000..7a0f0dfd2b7f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark.AccumulatorSuite +import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +/** + * A test suite that generates randomized data to test the [[TungstenSort]] operator. + */ +class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + try { + sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED) + } finally { + super.afterAll() + } + } + + test("sort followed by limit") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + test("sorting updates peak execution memory") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()) + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + assert(TungstenSort.supportsSchema(inputDf.schema)) + checkThatPlansAgree( + inputDf, + plan => ConvertToSafe( + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala new file mode 100644 index 000000000000..d1f0b2b1fc52 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.control.NonFatal +import scala.collection.mutable +import scala.util.{Try, Random} + +import org.scalatest.Matchers + +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Test suite for [[UnsafeFixedWidthAggregationMap]]. + * + * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. + */ +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with SharedSQLContext { + + import UnsafeFixedWidthAggregationMap._ + + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) + private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyAggregationBuffer: InternalRow = InternalRow(0) + private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes + + private var taskMemoryManager: TaskMemoryManager = null + private var shuffleMemoryManager: TestShuffleMemoryManager = null + + def testWithMemoryLeakDetection(name: String)(f: => Unit) { + def cleanup(): Unit = { + if (taskMemoryManager != null) { + val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() + assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) + assert(leakedShuffleMemory === 0) + taskMemoryManager = null + } + TaskContext.unset() + } + + test(name) { + taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + shuffleMemoryManager = new TestShuffleMemoryManager + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = Random.nextInt(10000), + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = null, + internalAccumulators = Seq.empty)) + + try { + f + } catch { + case NonFatal(e) => + Try(cleanup()) + throw e + } + cleanup() + } + } + + private def randomStrings(n: Int): Seq[String] = { + val rand = new Random(42) + Seq.fill(512) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + }.distinct + } + + testWithMemoryLeakDetection("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + + testWithMemoryLeakDetection("empty map") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 1024, // initial 
capacity, + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + assert(!map.iterator().next()) + map.free() + } + + testWithMemoryLeakDetection("updating values for a single key") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 1024, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + val groupKey = InternalRow(UTF8String.fromString("cats")) + + // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) + assert(map.getAggregationBuffer(groupKey) != null) + val iter = map.iterator() + assert(iter.next()) + iter.getKey.getString(0) should be ("cats") + iter.getValue.getInt(0) should be (0) + assert(!iter.next()) + + // Modifications to rows retrieved from the map should update the values in the map + iter.getValue.setInt(0, 42) + map.getAggregationBuffer(groupKey).getInt(0) should be (42) + + map.free() + } + + testWithMemoryLeakDetection("inserting large random keys") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet + groupKeys.foreach { keyString => + assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null) + } + + val seenKeys = new mutable.HashSet[String] + val iter = map.iterator() + while (iter.next()) { + seenKeys += iter.getKey.getString(0) + } + assert(seenKeys.size === groupKeys.size) + assert(seenKeys === groupKeys) + map.free() + } + + testWithMemoryLeakDetection("test external sorting") { + // Memory consumption in the beginning of the task. + val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + val keys = randomStrings(1024).take(512) + keys.foreach { keyString => + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) + buf.setInt(0, keyString.length) + assert(buf != null) + } + + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. 
+ val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) + out += iter.getKey.getString(0) + } + + assert(out === (keys ++ additionalKeys).sorted) + + map.free() + } + + testWithMemoryLeakDetection("test external sorting with an empty map") { + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) + } + + assert(out === (additionalKeys).sorted) + + map.free() + } + + testWithMemoryLeakDetection("test external sorting with empty records") { + + // Memory consumption in the beginning of the task. + val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + StructType(Nil), + StructType(Nil), + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + (1 to 10).foreach { i => + val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) + assert(buf != null) + } + + // Convert the map into a sorter. Right now, it contains one record. + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. 
+ (1 to 4096).foreach { i => + sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + var count = 0 + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + iter.getKey.copy() + iter.getValue.copy() + count += 1; + } + + // 1 record was from the map and 4096 records were explicitly inserted. + assert(count === 4097) + + map.free() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala new file mode 100644 index 000000000000..d3be568a8758 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} + +/** + * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. + */ +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { + private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) + private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) + + testKVSorter(new StructType, new StructType, spill = true) + testKVSorter(new StructType().add("c1", IntegerType), new StructType, spill = true) + testKVSorter(new StructType, new StructType().add("c1", IntegerType), spill = true) + + private val rand = new Random(42) + for (i <- 0 until 6) { + val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes) + val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes) + testKVSorter(keySchema, valueSchema, spill = i > 3) + } + + + /** + * Create a test case using randomly generated data for the given key and value schema. 
+ * + * The approach works as follows: + * + * - Create input by randomly generating data based on the given schema + * - Run [[UnsafeKVExternalSorter]] on the generated data + * - Collect the output from the sorter, and make sure the keys are sorted in ascending order + * - Sort the input by both key and value, and sort the sorter output also by both key and value. + * Compare the sorted input and sorted output together to make sure all the key/values match. + * + * If spill is set to true, the sorter will spill probabilistically roughly every 100 records. + */ + private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = { + // Create the data converters + val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) + val kConverter = UnsafeProjection.create(keySchema) + val vConverter = UnsafeProjection.create(valueSchema) + + val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get + val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get + + val inputData = Seq.fill(1024) { + val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow]) + val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow]) + (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) + } + + val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]") + val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]") + + test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") { + testKVSorter( + keySchema, + valueSchema, + inputData, + pageSize = 16 * 1024 * 1024, + spill + ) + } + } + + /** + * Create a test case using the given input data for the given key and value schema. + * + * The approach works as follows: + * + * - Create input by randomly generating data based on the given schema + * - Run [[UnsafeKVExternalSorter]] on the input data + * - Collect the output from the sorter, and make sure the keys are sorted in ascending order + * - Sort the input by both key and value, and sort the sorter output also by both key and value. + * Compare the sorted input and sorted output together to make sure all the key/values match. + * + * If spill is set to true, the sorter will spill probabilistically roughly every 100 records. 
+ */ + private def testKVSorter( + keySchema: StructType, + valueSchema: StructType, + inputData: Seq[(InternalRow, InternalRow)], + pageSize: Long, + spill: Boolean): Unit = { + + val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val shuffleMemMgr = new TestShuffleMemoryManager + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 98456, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null, + internalAccumulators = Seq.empty)) + + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize) + + // Insert the keys and values into the sorter + inputData.foreach { case (k, v) => + sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow]) + // 1% chance we will spill + if (rand.nextDouble() < 0.01 && spill) { + shuffleMemMgr.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + // Collect the sorted output + val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)] + val iter = sorter.sortedIterator() + while (iter.next()) { + out += Tuple2(iter.getKey.copy(), iter.getValue.copy()) + } + sorter.cleanupResources() + + val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType)) + val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType)) + val kvOrdering = new Ordering[(InternalRow, InternalRow)] { + override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { + keyOrdering.compare(x._1, y._1) match { + case 0 => valueOrdering.compare(x._2, y._2) + case cmp => cmp + } + } + } + + // Testing to make sure output from the sorter is sorted by key + var prevK: InternalRow = null + out.zipWithIndex.foreach { case ((k, v), i) => + if (prevK != null) { + assert(keyOrdering.compare(prevK, k) <= 0, + s""" + |key is not in sorted order: + |previous key: $prevK + |current key : $k + """.stripMargin) + } + prevK = k + } + + // Testing to make sure the key/value in output matches input + assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering)) + + // Make sure there is no memory leak + val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory + if (shuffleMemMgr != null) { + val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask() + assert(0L === leakedShuffleMemory) + } + assert(0 === leakedUnsafeMemory) + TaskContext.unset() + } + + test("kv sorting with records that exceed page size") { + val pageSize = 128 + + val schema = StructType(StructField("b", BinaryType) :: Nil) + val externalConverter = CatalystTypeConverters.createToCatalystConverter(schema) + val converter = UnsafeProjection.create(schema) + + val rand = new Random() + val inputData = Seq.fill(1024) { + val kBytes = new Array[Byte](rand.nextInt(pageSize)) + val vBytes = new Array[Byte](rand.nextInt(pageSize)) + rand.nextBytes(kBytes) + rand.nextBytes(vBytes) + val k = converter(externalConverter.apply(Row(kBytes)).asInstanceOf[InternalRow]) + val v = converter(externalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow]) + (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy()) + } + + testKVSorter( + schema, + schema, + inputData, + pageSize, + spill = true + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala new file mode 100644 index 000000000000..0113d052e338 --- 
/dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.types._
+import org.apache.spark._
+
+
+/**
+ * Used to test that UnsafeRowSerializer closes its InputStream.
+ */
+class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) {
+  var closed: Boolean = false
+  override def close(): Unit = {
+    closed = true
+    super.close()
+  }
+}
+
+class UnsafeRowSerializerSuite extends SparkFunSuite {
+
+  private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
+    val converter = unsafeRowConverter(schema)
+    converter(row)
+  }
+
+  private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = {
+    val converter = UnsafeProjection.create(schema)
+    (row: Row) => {
+      converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow])
+    }
+  }
+
+  test("toUnsafeRow() test helper method") {
+    // This currently doesn't work because the generic getter throws an exception.
+ val row = Row("Hello", 123) + val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) + assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) + assert(row.getInt(1) === unsafeRow.getInt(1)) + } + + test("basic row serialization") { + val rows = Seq(Row("Hello", 1), Row("World", 2)) + val unsafeRows = rows.map(row => toUnsafeRow(row, Array(StringType, IntegerType))) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val baos = new ByteArrayOutputStream() + val serializerStream = serializer.serializeStream(baos) + for (unsafeRow <- unsafeRows) { + serializerStream.writeKey(0) + serializerStream.writeValue(unsafeRow) + } + serializerStream.close() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + for (expectedRow <- unsafeRows) { + val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 + assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) + assert(expectedRow.getString(0) === actualRow.getString(0)) + assert(expectedRow.getInt(1) === actualRow.getInt(1)) + } + assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val baos = new ByteArrayOutputStream() + val dout = new DataOutputStream(baos) + dout.writeInt(-1) // EOF + dout.flush() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("SPARK-10466: external sorter spilling with unsafe row serializer") { + var sc: SparkContext = null + var outputFile: File = null + val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten + Utils.tryWithSafeFinally { + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1024") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.shuffle.memoryFraction", "0.0001") + + sc = new SparkContext("local", "test", conf) + outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 1000).iterator.map { i => + (i, converter(Row(i))) + } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + partitioner = Some(new HashPartitioner(10)), + serializer = Some(new UnsafeRowSerializer(numFields = 1))) + + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) + + // Merging spilled files should not throw assertion error + val taskContext = + new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc)) + taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile) + } { + // Clean up + if (sc != null) { + sc.stop() + } + + // restore the spark env + SparkEnv.set(oldEnv) + + if (outputFile != null) { + outputFile.delete() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala new file mode 100644 index 000000000000..afda0d29f6d9 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.memory.TaskMemoryManager + +class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext { + + test("memory acquired on construction") { + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val numPages = iter.getHashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala similarity index 76% rename from sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index c32d9f88dd6e..6a18cc6d2713 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -15,27 +15,25 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json -import java.io.StringWriter +import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory +import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.json.InferSchema.compatibleType -import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with TestJsonData { - - protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.sql - import ctx.implicits._ +class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { + import testImplicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -63,39 +61,39 @@ class JsonSuite extends QueryTest with TestJsonData { checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) checkTypePromotion( - Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT)) val longNumber: Long = 9223372036854775807L checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) checkTypePromotion( - Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT)) val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion( - Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) - checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" - checkTypePromotion(DateUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" checkTypePromotion( - DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) + DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(3601000)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) + 
checkTypePromotion(DateTimeUtils.millisToDays(3601000), + enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" - checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(10801000)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) + checkTypePromotion(DateTimeUtils.millisToDays(10801000), + enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { @@ -113,7 +111,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(NullType, IntegerType, IntegerType) checkDataType(NullType, LongType, LongType) checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(NullType, StringType, StringType) checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) checkDataType(NullType, StructType(Nil), StructType(Nil)) @@ -124,7 +122,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(BooleanType, IntegerType, StringType) checkDataType(BooleanType, LongType, StringType) checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType.Unlimited, StringType) + checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType) checkDataType(BooleanType, StringType, StringType) checkDataType(BooleanType, ArrayType(IntegerType), StringType) checkDataType(BooleanType, StructType(Nil), StringType) @@ -133,7 +131,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(IntegerType, IntegerType, IntegerType) checkDataType(IntegerType, LongType, LongType) checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(IntegerType, StringType, StringType) checkDataType(IntegerType, ArrayType(IntegerType), StringType) checkDataType(IntegerType, StructType(Nil), StringType) @@ -141,23 +139,24 @@ class JsonSuite extends QueryTest with TestJsonData { // LongType checkDataType(LongType, LongType, LongType) checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(LongType, StringType, StringType) checkDataType(LongType, ArrayType(IntegerType), StringType) checkDataType(LongType, StructType(Nil), StringType) // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DoubleType) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) - // DoubleType - checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(DecimalType.Unlimited, StringType, StringType) - checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) - checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) + // DecimalType + checkDataType(DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT, + DecimalType.SYSTEM_DEFAULT) + 
checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType) // StringType checkDataType(StringType, StringType, StringType) @@ -211,12 +210,12 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType( StructType( StructField("f1", IntegerType, true) :: Nil), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, StringType) } test("Complex field and type inferring with null in sampling") { - val jsonDF = ctx.read.json(jsonNullStruct) + val jsonDF = sqlContext.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -235,10 +234,10 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = ctx.read.json(primitiveFieldAndType) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType(20, 0), true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -263,12 +262,12 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType(21, 0), true), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: StructField("arrayOfInteger", ArrayType(LongType, true), true) :: @@ -282,7 +281,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("field3", StringType, true) :: Nil), true), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: + StructField("field2", DecimalType(20, 0), true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(LongType, true), true) :: StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) @@ -362,7 +361,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -378,12 +377,12 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType.Unlimited, true) :: + StructField("num_num_2", DoubleType, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -395,11 
+394,9 @@ class JsonSuite extends QueryTest with TestJsonData { checkAnswer( sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: - Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, - new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, - new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("12", null, 21474836470.9, null, null, "true") :: + Row("false", 21474836470L, 92233720368547758070d, 100, "str1", "false") :: + Row(null, 21474836570L, 1.1, 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. @@ -421,12 +418,12 @@ class JsonSuite extends QueryTest with TestJsonData { // Widening to DecimalType checkAnswer( - sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: - Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), + Row(21474836472.2) :: + Row(92233720368547758071.3) :: Nil ) - // Widening to DoubleType + // Widening to Double checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), Row(101.2) :: Row(21474836471.2) :: Nil @@ -435,13 +432,13 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(92233720368547758071.2) + Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + Row(new java.math.BigDecimal("92233720368547758071.2")) ) // String and Boolean conflict: resolve the type as string. @@ -452,7 +449,7 @@ class JsonSuite extends QueryTest with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -487,9 +484,9 @@ class JsonSuite extends QueryTest with TestJsonData { // in the Project. checkAnswer( jsonDF. - where('num_str > BigDecimal("92233720368547758060")). + where('num_str >= BigDecimal("92233720368547758060")). select(('num_str + 1.2).as("num")), - Row(new java.math.BigDecimal("92233720368547758061.2")) + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue()) ) // The following test will fail. The type of num_str is StringType. @@ -500,12 +497,12 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. 
checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Row(14.3) :: Row(92233720368547758071.2) :: Nil + Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil ) } test("Type conflict in complex field values") { - val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -529,7 +526,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = ctx.read.json(arrayElementTypeConflict) + val jsonDF = sqlContext.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -557,7 +554,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Handling missing fields") { - val jsonDF = ctx.read.json(missingFields) + val jsonDF = sqlContext.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -575,10 +572,10 @@ class JsonSuite extends QueryTest with TestJsonData { test("jsonFile should be based on JSONRelation") { val dir = Utils.createTempDir() dir.delete() - val path = dir.getCanonicalPath - ctx.sparkContext.parallelize(1 to 100) + val path = dir.getCanonicalFile.toURI.toString + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + val jsonDF = sqlContext.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -588,14 +585,15 @@ class JsonSuite extends QueryTest with TestJsonData { assert( relation.isInstanceOf[JSONRelation], "The DataFrame returned by jsonFile should be based on JSONRelation.") - assert(relation.asInstanceOf[JSONRelation].path === Some(path)) + assert(relation.asInstanceOf[JSONRelation].paths === Array(path)) assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + sqlContext.read.schema(schema).json(path) + .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] - assert(relationWithSchema.path === Some(path)) + assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) assert(relationWithSchema.samplingRatio > 0.99) } @@ -605,10 +603,10 @@ class JsonSuite extends QueryTest with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = ctx.read.json(path) + val jsonDF = sqlContext.read.json(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType(20, 0), true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -666,7 +664,7 @@ class JsonSuite extends QueryTest with TestJsonData { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: 
StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -674,7 +672,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = ctx.read.schema(schema).json(path) + val jsonDF1 = sqlContext.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -691,7 +689,7 @@ class JsonSuite extends QueryTest with TestJsonData { "this is a simple string.") ) - val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -712,7 +710,7 @@ class JsonSuite extends QueryTest with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -740,7 +738,7 @@ class JsonSuite extends QueryTest with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -766,7 +764,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -784,7 +782,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -807,7 +805,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = ctx.read.json(jsonArray) + val jsonDF = sqlContext.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -825,64 +823,63 @@ class JsonSuite extends QueryTest with TestJsonData { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val jsonDF = ctx.read.json(corruptRecords) - jsonDF.registerTempTable("jsonTable") - - val schema = StructType( - StructField("_unparsed", StringType, true) :: - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) - - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. 
- checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, "") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val jsonDF = sqlContext.read.json(corruptRecords) + jsonDF.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) + } + } } test("SPARK-4068: nulls in arrays") { - val jsonDF = ctx.read.json(nullsInArrays) + val jsonDF = sqlContext.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -928,7 +925,7 @@ class JsonSuite extends QueryTest with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = ctx.createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -951,7 +948,7 @@ class JsonSuite extends QueryTest with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = ctx.createDataFrame(rowRDD2, schema2) + val df3 = sqlContext.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -959,8 +956,8 @@ class JsonSuite extends QueryTest with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = ctx.read.json(primitiveFieldAndType) - val primTable = ctx.read.json(jsonDF.toJSON) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val primTable = sqlContext.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -972,8 +969,8 @@ class JsonSuite extends 
QueryTest with TestJsonData { "this is a simple string.") ) - val complexJsonDF = ctx.read.json(complexFieldAndType1) - val compTable = ctx.read.json(complexJsonDF.toJSON) + val complexJsonDF = sqlContext.read.json(complexFieldAndType1) + val compTable = sqlContext.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1037,26 +1034,35 @@ class JsonSuite extends QueryTest with TestJsonData { } test("JSONRelation equality test") { - val context = org.apache.spark.sql.test.TestSQLContext + val relation0 = new JSONRelation( + Some(empty), + 1.0, + Some(StructType(StructField("a", IntegerType, true) :: Nil)), + None, None)(sqlContext) + val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( - "path", + Some(singleRow), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - context) + None, None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( - "path", + Some(singleRow), 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - context) + None, None)(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( - "path", + Some(singleRow), 1.0, - Some(StructType(StructField("b", StringType, true) :: Nil)), - context) + Some(StructType(StructField("b", IntegerType, true) :: Nil)), + None, None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) + assert(relation0 !== relation1) + assert(!logicalRelation0.sameResult(logicalRelation1), + s"$logicalRelation0 and $logicalRelation1 should be considered not having the same result.") + assert(relation1 === relation2) assert(logicalRelation1.sameResult(logicalRelation2), s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.") @@ -1068,6 +1074,27 @@ class JsonSuite extends QueryTest with TestJsonData { assert(relation2 !== relation3) assert(!logicalRelation2.sameResult(logicalRelation3), s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.") + + withTempPath(dir => { + val path = dir.getCanonicalFile.toURI.toString + sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + + val d1 = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = None, + partitionColumns = Array.empty[String], + provider = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)) + + val d2 = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = None, + partitionColumns = Array.empty[String], + provider = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)) + assert(d1 === d2) + }) } test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { @@ -1077,29 +1104,21 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = ctx.conf.useJacksonStreamingAPI - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val schemaWithSimpleMap = StructType( - StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - try{ - for (useStreaming <- List(true, false)) { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - val temp = Utils.createTempDir().getPath - - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) - df.write.mode("overwrite").parquet(temp) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + 
withTempDir { dir => + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + + val path = dir.getAbsolutePath + df.write.mode("overwrite").parquet(path) // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) + assert(sqlContext.read.parquet(path).count() == 5) - val df2 = ctx.read.json(corruptRecords) - df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) + val df2 = sqlContext.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(path) + checkAnswer(sqlContext.read.parquet(path), df2.collect()) } - } finally { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } @@ -1107,4 +1126,37 @@ class JsonSuite extends QueryTest with TestJsonData { val emptySchema = InferSchema(emptyRecords, 1.0, "") assert(StructType(Seq()) === emptySchema) } + + test("JSON with Partition") { + def makePartition(rdd: RDD[String], parent: File, partName: String, partValue: Any): File = { + val p = new File(parent, s"$partName=${partValue.toString}") + rdd.saveAsTextFile(p.getCanonicalPath) + p + } + + withTempPath(root => { + val d1 = new File(root, "d1=1") + // root/dt=1/col1=abc + val p1_col1 = makePartition( + sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), + d1, + "col1", + "abc") + + // root/dt=1/col1=abd + val p2 = makePartition( + sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), + d1, + "col1", + "abd") + + sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) + }) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala similarity index 88% rename from sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index eb62066ac643..713d1da1cb51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -15,17 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -trait TestJsonData { - - protected def ctx: SQLContext +private[json] trait TestJsonData { + protected def sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -36,7 +35,7 @@ trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -47,14 +46,14 @@ trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -65,14 +64,14 @@ trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -80,7 +79,7 @@ trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -96,7 +95,7 @@ trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -150,7 +149,7 @@ trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -158,7 +157,7 @@ trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -167,21 +166,21 @@ trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - 
ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -190,7 +189,7 @@ trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - ctx.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,5 +197,8 @@ trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) + + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + + def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala new file mode 100644 index 000000000000..36b929ee1f40 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord +import org.apache.hadoop.fs.Path +import org.apache.parquet.avro.AvroParquetWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.parquet.test.avro._ +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + private def withWriter[T <: IndexedRecord] + (path: String, schema: Schema) + (f: AvroParquetWriter[T] => Unit): Unit = { + logInfo( + s"""Writing Avro records with the following Avro schema into Parquet file: + | + |${schema.toString(true)} + """.stripMargin) + + val writer = new AvroParquetWriter[T](new Path(path), schema) + try f(writer) finally writer.close() + } + + test("required primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroPrimitives](path, AvroPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write( + AvroPrimitives.newBuilder() + .setBoolColumn(i % 2 == 0) + .setIntColumn(i) + .setLongColumn(i.toLong * 10) + .setFloatColumn(i.toFloat + 0.1f) + .setDoubleColumn(i.toDouble + 0.2d) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setStringColumn(s"val_$i") + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + }) + } + } + + test("optional primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroOptionalPrimitives](path, AvroOptionalPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = if (i % 3 == 0) { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(null) + .setMaybeIntColumn(null) + .setMaybeLongColumn(null) + .setMaybeFloatColumn(null) + .setMaybeDoubleColumn(null) + .setMaybeBinaryColumn(null) + .setMaybeStringColumn(null) + .build() + } else { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(i % 2 == 0) + .setMaybeIntColumn(i) + .setMaybeLongColumn(i.toLong * 10) + .setMaybeFloatColumn(i.toFloat + 0.1f) + .setMaybeDoubleColumn(i.toDouble + 0.2d) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setMaybeStringColumn(s"val_$i") + .build() + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + if (i % 3 == 0) { + Row.apply(Seq.fill(7)(null): _*) + } else { + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + } + }) + } + } + + test("non-nullable arrays") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroNonNullableArrays](path, AvroNonNullableArrays.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = { + val builder = + AvroNonNullableArrays.newBuilder() + .setStringsColumn(Seq.tabulate(3)(i => s"val_$i").asJava) + + if (i % 3 == 0) { + builder.setMaybeIntsColumn(null).build() + } else { + builder.setMaybeIntsColumn(Seq.tabulate(3)(Int.box).asJava).build() + } + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 
until 10).map { i => + Row( + Seq.tabulate(3)(i => s"val_$i"), + if (i % 3 == 0) null else Seq.tabulate(3)(identity)) + }) + } + } + + ignore("nullable arrays (parquet-avro 1.7.0 does not properly support this)") { + // TODO Complete this test case after upgrading to parquet-mr 1.8+ + } + + test("SPARK-10136 array of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroArrayOfArray](path, AvroArrayOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroArrayOfArray.newBuilder() + .setIntArraysColumn( + Seq.tabulate(3, 3)((i, j) => i * 3 + j: Integer).map(_.asJava).asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) + }) + } + } + + test("map of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroMapOfArray](path, AvroMapOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroMapOfArray.newBuilder() + .setStringToIntsColumn( + Seq.tabulate(3) { i => + i.toString -> Seq.tabulate(3)(j => i + j: Integer).asJava + }.toMap.asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) + }) + } + } + + test("various complex types") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + } + + def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { + def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Nested + .newBuilder() + .setNestedIntsColumn(Seq.tabulate(3)(j => i + j + m: Integer).asJava) + .setNestedStringColumn(s"val_${i + m}") + .build() + }.asJava + }.toMap.asJava + } + + ParquetAvroCompat + .newBuilder() + .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}").asJava) + .setStringToIntColumn(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap.asJava) + .setComplexColumn(makeComplexColumn(i)) + .build() + } + + test("SPARK-9407 Push down predicates involving Parquet ENUM columns") { + import testImplicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetEnum](path, ParquetEnum.getClassSchema) { writer => + (0 until 4).foreach { i => + writer.write(ParquetEnum.newBuilder().setSuit(Suit.values.apply(i)).build()) + } + } + + checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala new file mode 100644 index 000000000000..0835bd123049 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -0,0 +1,125 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} + +import org.apache.spark.sql.QueryTest + +/** + * Helper class for testing Parquet compatibility. + */ +private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { + protected def readParquetSchema(path: String): MessageType = { + readParquetSchema(path, { path => !path.getName.startsWith("_") }) + } + + protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { + val fsPath = new Path(path) + val fs = fsPath.getFileSystem(hadoopConfiguration) + val parquetFiles = fs.listStatus(fsPath, new PathFilter { + override def accept(path: Path): Boolean = pathFilter(path) + }).toSeq.asJava + + val footers = + ParquetFileReader.readAllFootersInParallel(hadoopConfiguration, parquetFiles, true) + footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema + } + + protected def logParquetSchema(path: String): Unit = { + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} + """.stripMargin) + } +} + +private[sql] object ParquetCompatibilityTest { + implicit class RecordConsumerDSL(consumer: RecordConsumer) { + def message(f: => Unit): Unit = { + consumer.startMessage() + f + consumer.endMessage() + } + + def group(f: => Unit): Unit = { + consumer.startGroup() + f + consumer.endGroup() + } + + def field(name: String, index: Int)(f: => Unit): Unit = { + consumer.startField(name, index) + f + consumer.endField(name, index) + } + } + + /** + * A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet + * records with arbitrary structures. 
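+   *
+   * This is used by [[writeDirect]] below; combined with the [[RecordConsumerDSL]] above, a
+   * hypothetical record writer might look like the following (the path, schema string and
+   * field value are illustrative, and the DSL implicits are assumed to be in scope):
+   * {{{
+   *   writeDirect("/tmp/direct-example.parquet", "message root { required int32 id; }", { rc =>
+   *     rc.message {
+   *       rc.field("id", 0) {
+   *         rc.addInteger(42)
+   *       }
+   *     }
+   *   })
+   * }}}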
+ */ + private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String]) + extends WriteSupport[RecordConsumer => Unit] { + + private var recordConsumer: RecordConsumer = _ + + override def init(configuration: Configuration): WriteContext = { + new WriteContext(schema, metadata.asJava) + } + + override def write(recordWriter: RecordConsumer => Unit): Unit = { + recordWriter.apply(recordConsumer) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`. + * Records are produced by `recordWriters`. + */ + def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = { + writeDirect(path, schema, Map.empty[String, String], recordWriters: _*) + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path` + * with given user-defined key-value `metadata`. Records are produced by `recordWriters`. + */ + def writeDirect( + path: String, + schema: String, + metadata: Map[String, String], + recordWriters: (RecordConsumer => Unit)*): Unit = { + val messageType = MessageTypeParser.parseMessageType(schema) + val writeSupport = new DirectWriteSupport(messageType, metadata) + val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport) + try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala similarity index 72% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index a2763c78b645..f067112cfca9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -15,18 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet -import org.scalatest.BeforeAndAfterAll import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.test.SharedSQLContext /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -40,8 +39,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. 
*/ -class ParquetFilterSuiteBase extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext +class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private def checkFilterPredicate( df: DataFrame, @@ -56,28 +54,22 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = { - val forParquetTableScan = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) + val analyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters + }.flatten + assert(analyzedPredicate.nonEmpty) - val forParquetDataSource = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters - }.flatten.reduceOption(_ && _) + val selectedFilters = DataSourceStrategy.selectFilters(analyzedPredicate) + assert(selectedFilters.nonEmpty) - forParquetTableScan.orElse(forParquetDataSource) - } - - assert(maybeAnalyzedPredicate.isDefined) - maybeAnalyzedPredicate.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(pred) + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") maybeFilter.foreach { f => // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) assert(f.getClass === filterClass) } } - checker(query, expected) } } @@ -98,7 +90,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) (implicit df: DataFrame): Unit = { def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } @@ -118,43 +110,18 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } - test("filter pushdown - short") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) - checkFilterPredicate( - Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, 
IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate( - Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) - checkFilterPredicate( - Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], - Seq(Row(1), Row(4))) - } - } - test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -163,13 +130,13 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -180,6 +147,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -188,13 +156,13 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -205,6 +173,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -213,13 +182,13 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), 
classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -230,6 +199,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -238,13 +208,13 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -256,6 +226,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") checkFilterPredicate( '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) @@ -265,13 +236,13 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], "1") checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } @@ -283,6 +254,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkBinaryFilterPredicate( @@ -297,32 +269,20 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) checkBinaryFilterPredicate( '_1 < 
2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } -} - -class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("SPARK-6554: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ + import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => @@ -338,37 +298,3 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } } } - -class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } - - test("SPARK-6742: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ - - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - - // If the "part = 1" filter gets pushed down, this query will throw an exception since - // "part" is not a valid column in the actual Parquet file - val df = DataFrame(sqlContext, org.apache.spark.sql.parquet.ParquetRelation( - path, - Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, - Seq(AttributeReference("part", IntegerType, false)()) )) - - checkAnswer( - df.filter("a = 1 or part = 1"), - (1 to 3).map(i => Row(1, i, i.toString))) - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala similarity index 74% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 284d99d4938d..cd552e83372f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -15,9 +15,11 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import java.util.Collections + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -28,16 +30,16 @@ import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.metadata.{BlockMetaData, CompressionCodecName, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport @@ -63,9 +65,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuiteBase extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. 
@@ -99,45 +100,28 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) .toDF() // Parquet doesn't allow column names with spaces, have to add an alias here .select($"_1" cast decimal as "dec") - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } - - // Decimals with precision above 18 are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } - - // Unlimited-length decimals are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } } test("date type") { def makeDateRDD(): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) - .map(i => Tuple1(DateUtils.toJavaDate(i))) + .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) .toDF() .select($"_1") @@ -158,6 +142,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { checkParquetFile(data) } + test("array and double") { + val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) + checkParquetFile(data) + } + test("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => @@ -218,9 +207,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("compression codec") { def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) + .readMetaData(new Path(path), Some(hadoopConfiguration)).getBlocks.asScala + .flatMap(_.getColumns.asScala) .map(_.getCodec.name()) .distinct @@ -289,14 +277,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("write metadata") { withTempPath { file => val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) + val fs = FileSystem.getLocal(hadoopConfiguration) val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + ParquetTypesConverter.writeMetaData(attributes, path, hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val metaData = ParquetTypesConverter.readMetaData(path, Some(hadoopConfiguration)) val actualSchema = metaData.getFileMetaData.getSchema val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) @@ -361,14 +349,16 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { """.stripMargin) withTempPath { location => - val extraMetadata = Map(RowReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val extraMetadata = Collections.singletonMap( + CatalystReadSupport.SPARK_METADATA_KEY, sparkSchema.toString) val 
fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sqlContext.sparkContext.hadoopConfiguration, + sparkContext.hadoopConfiguration, path, - new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) + Collections.singletonList( + new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( @@ -380,12 +370,36 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("SPARK-6352 DirectParquetOutputCommitter") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", + classOf[DirectParquetOutputCommitter].getCanonicalName) + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(hadoopConfiguration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + + test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { + val clonedConf = new Configuration(hadoopConfiguration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. 
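      // The committer class name used in the block below,
      // "org.apache.spark.sql.parquet.DirectParquetOutputCommitter", is the old fully-qualified
      // name from before the class moved to org.apache.spark.sql.execution.datasources.parquet.
      // The point of this test is that the old name is still accepted, so existing user
      // configurations keep working after the package rename.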
+ try { + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -393,26 +407,27 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) - configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", - classOf[BogusParquetOutputCommitter].getCanonicalName) + classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName) try { val message = intercept[SparkException] { @@ -421,31 +436,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { assert(message === "Intentional exception for testing purposes") } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } } -} - -class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - - override def commitJob(jobContext: JobContext): Unit = { - sys.error("Intentional exception for testing purposes") - } -} - -class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key, originalConf.toString) - } test("SPARK-6330 regression test") { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: @@ -458,16 +453,54 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA }.toString assert(errorMessage.contains("UnknownHostException")) } + + test("SPARK-7837 Do not close output writer twice when commitTask() fails") { + val clonedConf = new Configuration(hadoopConfiguration) + + // Using a output committer that always fail when committing a task, so that both + // `commitTask()` and `abortTask()` are invoked. + hadoopConfiguration.set( + "spark.sql.parquet.output.committer.class", + classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName) + + try { + // Before fixing SPARK-7837, the following code results in an NPE because both + // `commitTask()` and `abortTask()` try to close output writers. 
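      // The TaskCommitFailureParquetOutputCommitter configured above (defined at the end of
      // this file) makes every commitTask() throw, which forces Spark to call abortTask() as
      // well. Before the SPARK-7837 fix, both code paths closed the Parquet output writer,
      // and that double close is what produced the NPE mentioned above.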
+ + withTempPath { dir => + val m1 = intercept[SparkException] { + sqlContext.range(1).coalesce(1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m1.contains("Intentional exception for testing purposes")) + } + + withTempPath { dir => + val m2 = intercept[SparkException] { + val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1) + df.write.partitionBy("a").parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m2.contains("Intentional exception for testing purposes")) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } } -class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi +class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") } +} + +class TaskCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) + override def commitTask(context: TaskAttemptContext): Unit = { + sys.error("Intentional exception for testing purposes") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala new file mode 100644 index 000000000000..83b65fb419ed --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + test("parquet files with different physical schemas but share the same logical schema") { + import ParquetCompatibilityTest._ + + // This test case writes two Parquet files, both representing the following Catalyst schema + // + // StructType( + // StructField( + // "f", + // ArrayType(IntegerType, containsNull = false), + // nullable = false)) + // + // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the + // other one comes with parquet-protobuf style 1-level unannotated primitive field. + withTempDir { dir => + val avroStylePath = new File(dir, "avro-style").getCanonicalPath + val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath + + val avroStyleSchema = + """message avro_style { + | required group f (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin + + writeDirect(avroStylePath, avroStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("array", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + } + } + }) + + logParquetSchema(avroStylePath) + + val protobufStyleSchema = + """message protobuf_style { + | repeated int32 f; + |} + """.stripMargin + + writeDirect(protobufStylePath, protobufStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + }) + + logParquetSchema(protobufStylePath) + + checkAnswer( + sqlContext.read.parquet(dir.getCanonicalPath), + Seq( + Row(Seq(0, 1)), + Row(Seq(2, 3)))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala similarity index 85% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 01df189d1f3b..7bac8609e1b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger @@ -26,12 +26,12 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.sources.PartitioningUtils._ -import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String // The data where the partitioning key exists only in the directory structure. 
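To make the comment above concrete, a minimal standalone sketch (not part of the patch) of how partition values that exist only in the directory layout are rediscovered at read time; it assumes a Spark 1.5-era SQLContext, and partitionDiscoveryExample is an illustrative name:

import org.apache.spark.sql.SQLContext

def partitionDiscoveryExample(sqlContext: SQLContext, basePath: String): Unit = {
  import sqlContext.implicits._
  // The Parquet files themselves hold only (intField, stringField); pi and ps are encoded
  // purely in the directory names.
  (1 to 3).map(i => (i, i.toString)).toDF("intField", "stringField")
    .write.parquet(s"$basePath/pi=1/ps=foo")
  (1 to 3).map(i => (i, i.toString)).toDF("intField", "stringField")
    .write.parquet(s"$basePath/pi=2/ps=bar")
  // Reading the base path infers pi and ps as additional partition columns.
  val df = sqlContext.read.parquet(basePath)
  assert(df.columns.toSet == Set("intField", "stringField", "pi", "ps"))
  assert(df.count() == 6)
}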
@@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext { + import PartitioningUtils._ + import testImplicits._ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" @@ -447,7 +445,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t") + sqlContext + .read + .option("mergeSchema", "true") + .format("parquet") + .load(base.getCanonicalPath) + .registerTempTable("t") withTempTable("t") { checkAnswer( @@ -462,7 +465,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: ParquetRelation2) => + case LogicalRelation(relation: ParquetRelation) => assert(relation.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") @@ -504,7 +507,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { FloatType, DoubleType, DecimalType(10, 5), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, DateType, TimestampType, StringType) @@ -514,7 +517,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) @@ -538,4 +541,60 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) } } + + test("listConflictingPartitionColumns") { + def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { + val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => + s"\tPartition column name list #$index: $list" + }.mkString("\n", "\n", "\n") + + // scalastyle:off + s"""Conflicting partition column names detected: + |$conflictingColNameLists + |For partitioned table directories, data files should only live in leaf directories. + |And directories at the same level should have the same partition column name. 
+ |Please check the following directories for unexpected files or inconsistent partition column names: + |${paths.map("\t" + _).mkString("\n", "\n", "")} + """.stripMargin.trim + // scalastyle:on + } + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim === + makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim === + makeExpectedMessage( + Seq("a"), + Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), + PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1/b=foo"), + PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim === + makeExpectedMessage( + Seq("a", "a, b"), + Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo"))) + } + + test("Parallel partition discovery") { + withTempPath { dir => + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { + val path = dir.getCanonicalPath + val df = sqlContext.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + df.write.partitionBy("b", "c").parquet(path) + checkAnswer(sqlContext.read.parquet(path), df) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala new file mode 100644 index 000000000000..b290429c2a02 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + + private def readParquetProtobufFile(name: String): DataFrame = { + val url = Thread.currentThread().getContextClassLoader.getResource(name) + sqlContext.read.parquet(url.toString) + } + + test("unannotated array of primitive type") { + checkAnswer(readParquetProtobufFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + } + + test("unannotated array of struct") { + checkAnswer( + readParquetProtobufFile("old-repeated-message.parquet"), + Row( + Seq( + Row("First inner", null, null), + Row(null, "Second inner", null), + Row(null, null, "Third inner")))) + + checkAnswer( + readParquetProtobufFile("proto-repeated-struct.parquet"), + Row( + Seq( + Row("0 - 1", "0 - 2", "0 - 3"), + Row("1 - 1", "1 - 2", "1 - 3")))) + + checkAnswer( + readParquetProtobufFile("proto-struct-with-array-many.parquet"), + Seq( + Row( + Seq( + Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"), + Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))), + Row( + Seq( + Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"), + Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))), + Row( + Seq( + Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"), + Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3"))))) + } + + test("struct with unannotated array") { + checkAnswer( + readParquetProtobufFile("proto-struct-with-array.parquet"), + Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) + } + + test("unannotated array of struct with unannotated array") { + checkAnswer( + readParquetProtobufFile("nested-array-struct.parquet"), + Seq( + Row(2, Seq(Row(1, Seq(Row(3))))), + Row(5, Seq(Row(4, Seq(Row(6))))), + Row(8, Seq(Row(7, Seq(Row(9))))))) + } + + test("unannotated array of string") { + checkAnswer( + readParquetProtobufFile("proto-repeated-string.parquet"), + Seq( + Row(Seq("hello", "world")), + Row(Seq("good", "bye")), + Row(Seq("one", "two", "three")))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala new file mode 100644 index 000000000000..1c1cfa34ad04 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -0,0 +1,552 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * A test suite that tests various Parquet queries. + */ +class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ + + test("simple select queries") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withParquetTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + } + sqlContext.catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + } + sqlContext.catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } + + test("SPARK-6917 DecimalType should work with 
non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + + test("Enabling/disabling merging partfiles when merging parquet schema") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + // delete summary files, so if we don't merge part-files, one column will not be included. + Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) + Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + testSchemaMerging(2) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + testSchemaMerging(3) + } + } + + test("Enabling/disabling schema merging") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + testSchemaMerging(3) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + testSchemaMerging(2) + } + } + + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { + withTempPath { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + + // Disables the global SQL option for schema merging + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + assertResult(2) { + // Disables schema merging via data source option + sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + } + + assertResult(3) { + // Enables schema merging via data source option + sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + } + } + } + } + + test("SPARK-9119 Decimal should be correctly written into parquet") { + withTempPath { dir => + val basePath = dir.getCanonicalPath + val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) + val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) + val df = sqlContext.createDataFrame(rowRDD, schema) + df.write.parquet(basePath) + + val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0) + assert(Decimal("67123.45") === Decimal(decimal)) + } + } + + test("SPARK-10005 Schema merging for nested struct") { + withTempPath { dir => + val path = dir.getCanonicalPath 
+ + def append(df: DataFrame): Unit = { + df.write.mode(SaveMode.Append).parquet(path) + } + + // Note that both the following two DataFrames contain a single struct column with multiple + // nested fields. + append((1 to 2).map(i => Tuple1((i, i))).toDF()) + append((1 to 2).map(i => Tuple1((i, i, i))).toDF()) + + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer( + sqlContext.read.option("mergeSchema", "true").parquet(path), + Seq( + Row(Row(1, 1, null)), + Row(Row(2, 2, null)), + Row(Row(1, 1, 1)), + Row(Row(2, 2, 2)))) + } + } + } + + test("SPARK-10301 requested schema clipping - same schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L))) + } + } + + // This test case is ignored because of parquet-mr bug PARQUET-370 + ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(null, null))) + } + } + + test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L, null, null))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, null, null, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L))) + } + + withTempPath { dir => + val path = 
dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(1L, 2L, null))) + } + } + + test("SPARK-10301 requested schema clipping - deeply nested struct") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add( + "a", + ArrayType( + new StructType() + .add("b", LongType, nullable = true) + .add("d", StringType, nullable = true), + containsNull = true), + nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(Seq(Row(0, null))))) + } + } + + test("SPARK-10301 requested schema clipping - out of order") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") + .coalesce(1) + + df1.write.parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Seq( + Row(Row(0, 1, null)), + Row(Row(null, 2, 4)))) + } + } + + test("SPARK-10301 requested schema clipping - schema merging") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + df1.write.mode(SaveMode.Append).parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + checkAnswer( + sqlContext + .read + .option("mergeSchema", "true") + .parquet(path) + .selectExpr("s.a", "s.b", "s.c"), + Seq( + Row(0, null, 2), + Row(1, 2, 3))) + } + } + + test("SPARK-10301 requested schema clipping - UDT") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr( + """NAMED_STRUCT( + | 'f0', CAST(id AS STRING), + | 'f1', NAMED_STRUCT( + | 'a', CAST(id + 1 AS INT), + | 'b', CAST(id + 2 AS LONG), + | 'c', CAST(id + 3.5 AS DOUBLE) + | ) + |) AS 
s + """.stripMargin) + .coalesce(1) + + df.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("f1", new NestedStructUDT, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(NestedStruct(1, 2L, 3.5D)))) + } + } +} + +object TestingUDT { + @SQLUserDefinedType(udt = classOf[NestedStructUDT]) + case class NestedStruct(a: Integer, b: Long, c: Double) + + class NestedStructUDT extends UserDefinedType[NestedStruct] { + override def sqlType: DataType = + new StructType() + .add("a", IntegerType, nullable = true) + .add("b", LongType, nullable = false) + .add("c", DoubleType, nullable = false) + + override def serialize(obj: Any): Any = { + val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + obj match { + case n: NestedStruct => + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) + } + } + + override def userClass: Class[NestedStruct] = classOf[NestedStruct] + + override def deserialize(datum: Any): NestedStruct = { + datum match { + case row: InternalRow => + NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2)) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala new file mode 100644 index 000000000000..5a8f772c3228 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -0,0 +1,1472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.parquet.schema.MessageTypeParser + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { + + /** + * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. 
+ */ + protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( + testName: String, + messageType: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + testSchema( + testName, + StructType.fromAttributes(ScalaReflection.attributesFor[T]), + messageType, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } + + protected def testParquetToCatalyst( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql <= parquet: $testName") { + val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) + val expected = sqlSchema + assert( + actual === expected, + s"""Schema mismatch. + |Expected schema: ${expected.json} + |Actual schema: ${actual.json} + """.stripMargin) + } + } + + protected def testCatalystToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + actual.checkContains(expected) + expected.checkContains(actual) + } + } + + protected def testSchema( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + + testCatalystToParquet( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + + testParquetToCatalyst( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } +} + +class ParquetSchemaInferenceSuite extends ParquetSchemaTest { + testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])]( + "basic types", + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + | optional binary _6; + |} + """.stripMargin, + binaryAsString = false) + + testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( + "logical integral types", + """ + |message root { + | required int32 _1 (INT_8); + | required int32 _2 (INT_16); + | required int32 _3 (INT_32); + | required int64 _4 (INT_64); + | optional int32 _5 (DATE); + |} + """.stripMargin) + + testSchemaInference[Tuple1[String]]( + "string", + """ + |message root { + | optional binary _1 (UTF8); + |} + """.stripMargin, + binaryAsString = true) + + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - non-standard", + """ + 
|message root { + | optional group _1 (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group bag { + | optional int32 array; + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - non-standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Pair[Int, String]]]( + "struct", + """ + |message root { + | optional group _1 { + | required int32 _1; + | optional binary _2 (UTF8); + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - non-standard", + """ + |message root { + | optional group _1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 key; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group bag { + | optional group array { + | required int32 _1; + | required double _2; + | } + | } + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group list { + | optional group element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( + "optional types", + """ + |message root { + | optional int32 _1; + | optional group _2 (MAP) { + | repeated group key_value { + | required int32 key; + | optional double value; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + // Parquet files generated by parquet-thrift are already handled by the schema converter, but + // let's leave this test here until both read path and write path are all updated. 
+ ignore("thrift generated parquet schema") { + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchemaInference[( + Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, + isThriftDerived = true) + } +} + +class ParquetSchemaSuite extends ParquetSchemaTest { + test("DataType string parser compatibility") { + // This is the generated string from previous versions of the Spark SQL, using the following: + // val schema = StructType(List( + // StructField("c1", IntegerType, false), + // StructField("c2", BinaryType, true))) + val caseClassString = + "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" + + // scalastyle:off + val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" + // scalastyle:on + + val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) + val fromJson = ParquetTypesConverter.convertFromString(jsonString) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + assert(a.nullable === b.nullable) + } + } + + test("merge with metastore schema") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + + // Conflicting non-nullable field names + intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq(StructField("lower", StringType, nullable = false))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. 
+ assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } + + // ======================================================= + // Tests for converting Parquet LIST to Catalyst ArrayType + // ======================================================= + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 2", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | optional int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 2", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 3", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 4", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false), + StructField("num", IntegerType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required binary str (UTF8); + | required int32 num; + | } + | } + |} + """.stripMargin) + 
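Aside (editorial, not part of this patch): the LIST backwards-compatibility cases above all funnel through the CatalystSchemaConverter instantiated in testParquetToCatalyst earlier in this file. Below is a minimal standalone sketch of that conversion, assuming it lives in the same package as this suite (the converter may not be visible outside it) and using a hypothetical object name; the expected result mirrors the "LIST with non-nullable element type - 3" case above.

package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.schema.MessageTypeParser

import org.apache.spark.sql.types._

// Hypothetical driver object, for illustration only.
object LegacyListConversionSketch {
  def main(args: Array[String]): Unit = {
    // Legacy 2-level LIST encoding: the repeated primitive sits directly under the LIST
    // group, with no intermediate "list"/"element" wrapper group.
    val parquetType = MessageTypeParser.parseMessageType(
      """message root {
        |  optional group f1 (LIST) {
        |    repeated int32 element;
        |  }
        |}
      """.stripMargin)

    // Same flags the test helper defaults to, i.e. legacy (pre-format-spec) mode.
    val converter = new CatalystSchemaConverter(
      assumeBinaryIsString = true,
      assumeInt96IsTimestamp = true,
      followParquetFormatSpec = false)

    // A repeated primitive element is read back as a required element, hence
    // containsNull = false, matching the test case referenced above.
    val catalystType = converter.convert(parquetType)
    assert(catalystType == StructType(Seq(
      StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))))
  }
}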
+ testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group array { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group f1_tuple { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 7 - " + + "parquet-protobuf primitive lists", + new StructType() + .add("f1", ArrayType(IntegerType, containsNull = false), nullable = false), + """message root { + | repeated int32 f1; + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 8 - " + + "parquet-protobuf non-primitive lists", + { + val elementType = + new StructType() + .add("c1", StringType, nullable = true) + .add("c2", IntegerType, nullable = false) + + new StructType() + .add("f1", ArrayType(elementType, containsNull = false), nullable = false) + }, + """message root { + | repeated group f1 { + | optional binary c1 (UTF8); + | required int32 c2; + | } + |} + """.stripMargin) + + // ======================================================= + // Tests for converting Catalyst ArrayType to Parquet LIST + // ======================================================= + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group bag { + | optional int32 array; + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Parquet Map to Catalyst MapType + // ==================================================== + + testParquetToCatalyst( + 
"Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | optional binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Catalyst MapType to Parquet Map + // ==================================================== + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + 
StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ================================= + // Tests for conversion for decimals + // ================================= + + testSchema( + "DECIMAL(1, 0) - standard", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional int32 f1 (DECIMAL(1, 0)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(8, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional int32 f1 (DECIMAL(8, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(9, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional int32 f1 (DECIMAL(9, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(18, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional int64 f1 (DECIMAL(18, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(19, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(19, 3)))), + """message root { + | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(1, 0) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(8, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(9, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(18, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); + |} + """.stripMargin) + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String): Unit = { + test(s"Clipping - $testName") { + val expected = MessageTypeParser.parseMessageType(expectedSchema) + val actual = CatalystReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) + + try { + expected.checkContains(actual) + actual.checkContains(expected) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expected + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + } + } + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required 
group f0 { + | optional int32 f00; + | optional int32 f01; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add("f00", IntegerType, nullable = true) + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", IntegerType, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional int32 f00; + | } + | optional int32 f1; + |} + """.stripMargin) + + testSchemaClipping( + "parquet-protobuf style array", + + parquetSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional int32 f010; + | optional double f011; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f00Type = ArrayType(StringType, containsNull = false) + val f01Type = ArrayType( + new StructType() + .add("f011", DoubleType, nullable = true), + containsNull = false) + + val f0Type = new StructType() + .add("f00", f00Type, nullable = false) + .add("f01", f01Type, nullable = false) + val f1Type = ArrayType(IntegerType, containsNull = true) + + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", f1Type, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional double f011; + | } + | } + | + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-thrift style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + 
""".stripMargin) + + testSchemaClipping( + "parquet-hive style array", + + parquetSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = true), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = true), nullable = true) + + new StructType().add("f0", f0Type, nullable = true) + }, + + expectedSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "2-level list of required struct", + + parquetSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | required int32 f000; + | optional int64 f001; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f00ElementType = + new StructType() + .add("f001", LongType, nullable = true) + .add("f002", DoubleType, nullable = false) + + val f00Type = ArrayType(f00ElementType, containsNull = false) + val f0Type = new StructType().add("f00", f00Type, nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | optional int64 f001; + | required double f002; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "empty requested schema", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = new StructType(), + + expectedSchema = "message root {}") + + testSchemaClipping( + "disjoint field sets", + + parquetSchema = + 
"""message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = + new StructType() + .add( + "f0", + new StructType() + .add("f02", FloatType, nullable = true) + .add("f03", DoubleType, nullable = true), + nullable = true), + + expectedSchema = + """message root { + | required group f0 { + | optional float f02; + | optional double f03; + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int32 value_f0; + | required int64 value_f1; + | } + | required int32 value; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int64 value_f1; + | required double value_f2; + | } + | required int32 value; + | } + | } + |} + """.stripMargin) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index eb15a1609f1d..442fafb12f20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} /** * A helper trait that provides convenient facilities for Parquet testing. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala new file mode 100644 index 000000000000..88a3d878f97f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + import ParquetCompatibilityTest._ + + private val parquetFilePath = + Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") + + test("Read Parquet file generated by parquet-thrift") { + logInfo( + s"""Schema of the Parquet file written by parquet-thrift: + |${readParquetSchema(parquetFilePath.toString)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") + + val nonNullablePrimitiveValues = Seq( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always + // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume + // Thrift `STRING`s are encoded using UTF-8. 
+ s"val_$i", + s"val_$i", + // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings + suits(i % 4)) + + val nullablePrimitiveValues = if (i % 3 == 0) { + Seq.fill(nonNullablePrimitiveValues.length)(null) + } else { + nonNullablePrimitiveValues + } + + val complexValues = Seq( + Seq.tabulate(3)(n => s"arr_${i + n}"), + // Thrift `SET`s are converted to Parquet `LIST`s + Seq(i), + Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap, + Seq.tabulate(3) { n => + (i + n) -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + + Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*) + }) + } + + test("SPARK-10136 list of primitive list") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // This Parquet schema is translated from the following Thrift schema: + // + // struct ListOfPrimitiveList { + // 1: list> f; + // } + val schema = + s"""message ListOfPrimitiveList { + | required group f (LIST) { + | repeated group f_tuple (LIST) { + | repeated int32 f_tuple_tuple; + | } + | } + |} + """.stripMargin + + writeDirect(path, schema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + } + } + } + } + }, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(4) + rc.addInteger(5) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(6) + rc.addInteger(7) + } + } + } + } + } + } + }) + + logParquetSchema(path) + + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(Seq(Seq(0, 1), Seq(2, 3))), + Row(Seq(Seq(4, 5), Seq(6, 7))))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8ec3985e0036..22189477d277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,15 +18,11 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.SharedSQLContext + +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { -class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } - - test("DataFrame.typeCheck()") { - testData.typeCheck() - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala new file mode 100644 index 000000000000..dcbfdca71acb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -0,0 +1,85 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.joins + +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} + +/** + * Test various broadcast join operators with unsafe enabled. + * + * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of + * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered + * without serializing the hashed relation, which does not happen in local mode. + */ +class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { + protected var sqlContext: SQLContext = null + + /** + * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("testing") + val sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + sqlContext.sparkContext.stop() + sqlContext = null + } + + /** + * Test whether the specified broadcast join updates the peak execution memory accumulator. 
+ */ + private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { + val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = df1.join(broadcast(df2), joinExpression, joinType) + val plan = df3.queryExecution.executedPlan + assert(plan.collect { case p: T => p }.size === 1) + plan.executeCollect() + } + } + + test("unsafe broadcast hash join updates peak execution memory") { + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") + } + + test("unsafe broadcast hash outer join updates peak execution memory") { + testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + } + + test("unsafe broadcast left semi join updates peak execution memory") { + testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 71db6a215985..e5fd9e277fc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,12 +17,18 @@ package org.apache.spark.sql.execution.joins +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Projection, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends SparkFunSuite { +class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself private val keyProjection = new Projection { @@ -31,32 +37,101 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) assert(hashed.get(InternalRow(10)) === null) val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) - assert(hashed.get(data(2)) == data2) + assert(hashed.get(data(2)) === data2) + assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows 
= SQLMetrics.createLongMetric(sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) == data(0)) - assert(uniqHashed.getValue(data(1)) == data(1)) - assert(uniqHashed.getValue(data(2)) == data(2)) - assert(uniqHashed.getValue(InternalRow(10)) == null) + assert(uniqHashed.getValue(data(0)) === data(0)) + assert(uniqHashed.getValue(data(1)) === data(1)) + assert(uniqHashed.getValue(data(2)) === data(2)) + assert(uniqHashed.getValue(InternalRow(10)) === null) + assert(numDataRows.value.value === data.length) + } + + test("UnsafeHashedRelation") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") + val toUnsafe = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafe(_).copy()).toArray + + val buildKey = Seq(BoundReference(0, IntegerType, false)) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) + assert(hashed.isInstanceOf[UnsafeHashedRelation]) + + assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(toUnsafe(InternalRow(10))) === null) + + val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) + data2 += unsafeData(2).copy() + assert(hashed.get(unsafeData(2)) === data2) + + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(toUnsafe(InternalRow(10))) === null) + assert(hashed2.get(unsafeData(2)) === data2) + assert(numDataRows.value.value === data.length) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2) + out2.flush() + // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same + // as they are inserted + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) + } + + test("test serialization empty hash map") { + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + val hashed = new UnsafeHashedRelation( + new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + hashed.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + + 
val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val row = toUnsafe(InternalRow(0)) + assert(hashed2.get(row) === null) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 000000000000..4174ee055021 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + private lazy val myUpperCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + private lazy val myLowerCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + private lazy val myTestData = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testInnerJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + ExtractEquiJoinKeys.unapply(join) + } + + def makeBroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = 
{ + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + testInnerJoin( + "inner join, one match per row", + myUpperCaseData, + myLowerCaseData, + () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, 
"C", 3, "c"), + (4, "D", 4, "d") + ) + ) + + { + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala new file mode 100644 index 000000000000..09e0237a7cc5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} + +class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testOuterJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + joinType: JoinType, + condition: => 
Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + // --- Basic outer joins ------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, null, null), + (null, null, 0, 0.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, 7, 7.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala new file mode 100644 index 000000000000..3afd762942bc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testLeftSemiJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using LeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastLeftSemiJoinHash") { + extractJoinParts().foreach 
{ case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 000000000000..efc3227dd60d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. + */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala new file mode 100644 index 000000000000..bbd94d8da2d1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. 
See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.dsl.expressions._ + + +class ExpandNodeSuite extends LocalNodeTest { + + private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) + val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) + val resolvedNode = resolveExpressions(expandNode) + val expectedOutput = { + val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } + val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } + firstHalf ++ secondHalf + } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test("empty") { + testExpand() + } + + test("basic") { + testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala new file mode 100644 index 000000000000..4eadce646d37 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -0,0 +1,45 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.dsl.expressions._ + + +class FilterNodeSuite extends LocalNodeTest { + + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val cond = 'k % 2 === 0 + val inputNode = new DummyNode(kvIntAttributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala new file mode 100644 index 000000000000..5c1bdb088eee --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -0,0 +1,97 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + + +class HashJoinNodeSuite extends LocalNodeTest { + + // Test all combinations of the two dimensions: with/out unsafe and build sides + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + testJoin(unsafeAndCodegen, buildSide) + } + } + + /** + * Test inner hash join with varying degrees of matches. 
+ */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions(new HashJoinNode( + conf, Seq('id1), Seq('id2), buildSide, node1, node2)) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput === expectedOutput) + } + + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 000000000000..c0ad2021b204 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,37 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + + +class IntersectNodeSuite extends LocalNodeTest { + + test("basic") { + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala new file mode 100644 index 000000000000..fb790636a368 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -0,0 +1,41 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + + +class LimitNodeSuite extends LocalNodeTest { + + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testLimit() + } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala new file mode 100644 index 000000000000..0d1ed99eec6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -0,0 +1,73 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + + +class LocalNodeSuite extends LocalNodeTest { + private val data = (1 to 100).map { i => (i, i) }.toArray + + test("basic open, next, fetch, close") { + val node = new DummyNode(kvIntAttributes, data) + assert(!node.isOpen) + node.open() + assert(node.isOpen) + data.foreach { case (k, v) => + assert(node.next()) + // fetch should be idempotent + val fetched = node.fetch() + assert(node.fetch() === fetched) + assert(node.fetch() === fetched) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) + } + assert(!node.next()) + node.close() + assert(!node.isOpen) + } + + test("asIterator") { + val node = new DummyNode(kvIntAttributes, data) + val iter = node.asIterator + node.open() + data.foreach { case (k, v) => + // hasNext should be idempotent + assert(iter.hasNext) + assert(iter.hasNext) + val item = iter.next() + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) + } + intercept[NoSuchElementException] { + iter.next() + } + node.close() + } + + test("collect") { + val node = new DummyNode(kvIntAttributes, data) + node.open() + val collected = node.collect() + assert(collected.size === data.size) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) + node.close() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala new file mode 100644 index 000000000000..098050bcd223 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -0,0 +1,70 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.{IntegerType, StringType} + + +class LocalNodeTest extends SparkFunSuite { + + protected val conf: SQLConf = new SQLConf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) + + /** + * Wrap a function processing two [[LocalNode]]s such that: + * (1) all input rows are automatically converted to unsafe rows + * (2) all output rows are automatically converted back to safe rows + */ + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } + + /** + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. + */ + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala new file mode 100644 index 000000000000..40299d9d5ee3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -0,0 +1,145 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + + +class NestedLoopJoinNodeSuite extends LocalNodeTest { + + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(unsafeAndCodegen, buildSide, joinType) + } + } + } + + /** + * Test outer nested loop joins with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide, + joinType: JoinType): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } + runTest(joinType, someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. 
+ */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala new file mode 100644 index 000000000000..02ecb23d34b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} + + +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case (id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testProject() + } + + test("basic") { + testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala new file mode 100644 index 000000000000..a3e83bbd5145 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + + +class SampleNodeSuite extends LocalNodeTest { + + private def testSample(withReplacement: Boolean): Unit = { + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + } + + testSample(withReplacement = true) + testSample(withReplacement = false) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 000000000000..42ebc7bfcaad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import scala.util.Random + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder + + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + test(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) + } + } + + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala new file mode 100644 index 000000000000..666b0235c061 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -0,0 +1,55 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + + +class UnionNodeSuite extends LocalNodeTest { + + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + + test("empty") { + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) + } + + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i => (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala new file mode 100644 index 000000000000..6afffae161ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -0,0 +1,577 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(sparkContext, "long") + val f = () => { + l += 1L + l.add(1L) + } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. 
val l = sparkContext.accumulator(0L) + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Failed to find boxing in this test") + } + } + + /** + * Call `df.collect()` and verify that the collected metrics are the same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + df.collect() + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition in which we may miss some jobs + // TODO Change it to "=" once we fix the race condition that misses the JobStarted event. + assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + expectedMetrics.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + assert(expectedMetrics === actualMetrics) + } else { + // TODO Remove this "else" once we fix the race condition that misses the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. + logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + } + } + + test("Project metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("Project", Map( + "number of rows" -> 2L))) + ) + } + } + + test("TungstenProject metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) + } + } + + test("Filter metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) + val df = person.filter('age < 25) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Filter", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + } + + test("Aggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // ...
-> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortBasedAggregate metrics") { + // Because SortBasedAggregate may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> + // SortBasedAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) + // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 3L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("TungstenAggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortMergeJoin metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, this + // test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... 
-> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("SortMergeOuterJoin metrics") { + // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) + ) + + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) + ) + } + } + } + + test("BroadcastHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("ShuffledHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... 
-> ShuffledHashOuterJoin(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(df2, $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + + val df4 = df1.join(df2, $"key" === $"key2", "outer") + testSparkPlanMetrics(df4, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 7L))) + ) + } + } + + test("BroadcastHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + } + } + + test("BroadcastNestedLoopJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) + } + } + } + + test("BroadcastLeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... 
-> LeftSemiJoinHash(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("LeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinBNL metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("CartesianProduct metrics") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("CartesianProduct", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 12L, // right is read 6 times + "number of output rows" -> 12L))) + ) + } + } + + test("save metrics") { + withTempPath { file => + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq(2L)) + } + } + +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find if they do some boxing. + * If `method` is specified, only search this method of the class to speed up the searching. + * + * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM4) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. 
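+      // Descriptive note (not in the original patch): returning an empty MethodVisitor below
+      // means the bytecode of non-matching methods is never inspected, so nothing is recorded.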
+      return new MethodVisitor(ASM4) {}
+    }
+
+    new MethodVisitor(ASM4) {
+      override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
+        if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") {
+          if (primitiveBoxingClassName.contains(owner)) {
+            // Find boxing methods, e.g., new java.lang.Long(l) or java.lang.Long.valueOf(l)
+            boxingInvokes.add(s"$owner.$name")
+          }
+        } else {
+          // scalastyle:off classforname
+          val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false,
+            Thread.currentThread.getContextClassLoader)
+          // scalastyle:on classforname
+          val m = MethodIdentifier(classOfMethodOwner, name, desc)
+          if (!visitedMethods.contains(m)) {
+            // Keep track of visited methods to avoid potential infinite cycles
+            visitedMethods += m
+            BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl =>
+              visitedMethods += m
+              cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0)
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+private object BoxingFinder {
+
+  def getClassReader(cls: Class[_]): Option[ClassReader] = {
+    val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+    val resourceStream = cls.getResourceAsStream(className)
+    val baos = new ByteArrayOutputStream(128)
+    // Copy data over, before delegating to ClassReader -
+    // else we can run out of open file handles.
+    Utils.copyStream(resourceStream, baos, true)
+    // ASM4 doesn't support Java 8 class files, which require ASM5. If the class was compiled
+    // for Java 8 (e.g., java.lang.Long when running on a JDK 8 runtime), ClassReader will
+    // throw IllegalArgumentException. Since this is only used in tests, it's safe to skip
+    // such classes.
+    try {
+      Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray)))
+    } catch {
+      case _: IllegalArgumentException => None
+    }
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
new file mode 100644
index 000000000000..7a46c69a056b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.ui + +import java.util.Properties + +import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.scheduler._ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.test.SharedSQLContext + +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + private def createTestDataFrame: DataFrame = { + Seq( + (1, 1), + (2, 2) + ).toDF().filter("_1 > 1") + } + + private def createProperties(executionId: Long): Properties = { + val properties = new Properties() + properties.setProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString) + properties + } + + private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = new StageInfo( + stageId = stageId, + attemptId = attemptId, + // The following fields are not used in tests + name = "", + numTasks = 0, + rddInfos = Nil, + parentIds = Nil, + details = "" + ) + + private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( + taskId = taskId, + attemptNumber = attemptNumber, + // The following fields are not used in tests + index = 0, + launchTime = 0, + executorId = "", + host = "", + taskLocality = null, + speculative = false + ) + + private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { + val metrics = new TaskMetrics + metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) + metrics.updateAccumulators() + metrics + } + + test("basic") { + val listener = new SQLListener(sqlContext) + val executionId = 0 + val df = createTestDataFrame + val accumulatorIds = + SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + // Assume all accumulators are long + var accumulatorValue = 0L + val accumulatorUpdates = accumulatorIds.map { id => + accumulatorValue += 1L + (id, accumulatorValue) + }.toMap + + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + + val executionUIData = listener.executionIdToData(0) + + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq( + createStageInfo(0, 0), + createStageInfo(1, 0) + ), + createProperties(executionId))) + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) + + assert(listener.getExecutionMetrics(0).isEmpty) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 3)) + + // Retrying a stage should reset the metrics + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) + + 
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 1, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + // Ignore the task end for the first attempt + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 5)) + + // Summit a new stage + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 1, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 7)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + + assert(executionUIData.runningJobs === Seq(0)) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs.isEmpty) + + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + } + + test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { + val listener = new SQLListener(sqlContext) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before 
multiple onJobEnd(JobSucceeded)s") { + val listener = new SQLListener(sqlContext) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + listener.onJobStart(SparkListenerJobStart( + jobId = 1, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 1, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.sorted === Seq(0, 1)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before onJobEnd(JobFailed)") { + val listener = new SQLListener(sqlContext) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq.empty, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobFailed(new RuntimeException("Oops")) + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs === Seq(0)) + } + + ignore("no memory leak") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + val sc = new SparkContext(conf) + try { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // Run 100 successful executions and 100 failed executions. + // Each execution only has one job and one stage. 
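+      // Each iteration runs one collect() that succeeds and one foreach() that always throws,
+      // producing one completed and one failed execution per iteration.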
+ for (i <- 0 until 100) { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + df.collect() + try { + df.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException => // This is expected for a failed job + } + } + sc.listenerBus.waitUntilEmpty(10000) + assert(sqlContext.listener.getCompletedExecutions.size <= 50) + assert(sqlContext.listener.getFailedExecutions.size <= 50) + // 50 for successful executions and 50 for failed executions + assert(sqlContext.listener.executionIdToData.size <= 100) + assert(sqlContext.listener.jobIdToExecutionId.size <= 100) + assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + } finally { + sc.stop() + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 69ab1c292d22..5ab9381de4d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,9 +25,13 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ -class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -41,12 +45,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { Some(StringType) } - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { - Class.forName("org.h2.Driver") + Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test // usage of parameters from OPTIONS clause in queries. 
val properties = new Properties() @@ -133,7 +133,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))" + conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() conn.prepareStatement("insert into test.flttypes values (" + "1.0000000000000002220446049250313080847263336181640625, " @@ -151,7 +151,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), |f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP, - |m DOUBLE, n REAL, o DECIMAL(40, 20)) + |m DOUBLE, n REAL, o DECIMAL(38, 18)) """.stripMargin.replaceAll("\n", " ")).executeUpdate() conn.prepareStatement("insert into test.nulltypes values (" + "null, null, null, null, null, null, null, null, null, " @@ -255,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("Basic API") { - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -326,13 +326,13 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(cal.get(Calendar.HOUR) === 11) assert(cal.get(Calendar.MINUTE) === 22) assert(cal.get(Calendar.SECOND) === 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543000) } test("test DATE types") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -340,8 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types in cache") { - val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -349,21 +349,21 @@ class JDBCSuite extends SparkFunSuite with 
BeforeAndAfter { } test("test types for null value") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. - assert(rows(0).getAs[BigDecimal](2) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) - assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20)) - val compareDecimal = sql("SELECT C FROM flttypes where C > C - 1").collect() - assert(compareDecimal(0).getAs[BigDecimal](0) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) + assert(rows(0).getDouble(0) === 1.00000000000000022) + assert(rows(0).getDouble(1) === 1.00000011920928955) + assert(rows(0).getAs[BigDecimal](2) === + new BigDecimal("123456789012345.543215432154321000")) + assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18)) + val result = sql("SELECT C FROM flttypes where C > C - 1").collect() + assert(result(0).getAs[BigDecimal](0) === + new BigDecimal("123456789012345.543215432154321000")) } test("SQL query as table name") { @@ -396,7 +396,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) @@ -407,6 +407,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Default jdbc dialect registration") { assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } @@ -444,4 +445,23 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } + test("DB2Dialect type mapping") { + val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") + assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + assert(h2.getTableExistsQuery(table) == defaultQuery) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index d949ef42267e..e23ee6693133 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -22,11 +22,13 @@ import java.util.Properties import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{SaveMode, Row} +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -36,12 +38,8 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { - Class.forName("org.h2.Driver") + Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() @@ -57,14 +55,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -77,8 +75,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { conn1.close() } - private lazy val sc = ctx.sparkContext - private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) private lazy val schema2 = StructType( @@ -92,49 +88,50 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", 
properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -143,15 +140,15 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala deleted file mode 100644 index fafad67dde3a..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} - -/** - * A test suite that tests various Parquet queries. - */ -class ParquetQuerySuiteBase extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.sql - - test("simple select queries") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) - checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) - } - } - - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) - } - sqlContext.catalog.unregisterTable(Seq("tmp")) - } - - test("overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) - } - sqlContext.catalog.unregisterTable(Seq("tmp")) - } - - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output - - assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } - - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) - } - } - - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } - } - - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), 
data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } - } - - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } - } - - test("SPARK-5309 strings stored using dictionary compression in parquet") { - withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - - checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) - - checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), - List(Row("same", "run_5", 100))) - } - } - - test("SPARK-6917 DecimalType should work with non-native types") { - val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) - val schema = StructType(List(StructField("d", DecimalType(18, 0), false), - StructField("time", TimestampType, false)).toArray) - withTempPath { file => - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) - df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) - checkAnswer(df2, df.collect().toSeq) - } - } -} - -class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} - -class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala deleted file mode 100644 index 171a656f0e01..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ /dev/null @@ -1,280 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.parquet.schema.MessageTypeParser - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types._ - -class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - - /** - * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. - */ - private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = { - test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes( - ScalaReflection.attributesFor[T], isThriftDerived) - val expected = MessageTypeParser.parseMessageType(messageType) - actual.checkContains(expected) - expected.checkContains(actual) - } - } - - testSchema[(Boolean, Int, Long, Float, Double, Array[Byte])]( - "basic types", - """ - |message root { - | required boolean _1; - | required int32 _2; - | required int64 _3; - | required float _4; - | required double _5; - | optional binary _6; - |} - """.stripMargin) - - testSchema[(Byte, Short, Int, Long, java.sql.Date)]( - "logical integral types", - """ - |message root { - | required int32 _1 (INT_8); - | required int32 _2 (INT_16); - | required int32 _3 (INT_32); - | required int64 _4 (INT_64); - | optional int32 _5 (DATE); - |} - """.stripMargin) - - // Currently String is the only supported logical binary type. - testSchema[Tuple1[String]]( - "binary logical types", - """ - |message root { - | optional binary _1 (UTF8); - |} - """.stripMargin) - - testSchema[Tuple1[Seq[Int]]]( - "array", - """ - |message root { - | optional group _1 (LIST) { - | repeated int32 array; - | } - |} - """.stripMargin) - - testSchema[Tuple1[Map[Int, String]]]( - "map", - """ - |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required int32 key; - | optional binary value (UTF8); - | } - | } - |} - """.stripMargin) - - testSchema[Tuple1[Pair[Int, String]]]( - "struct", - """ - |message root { - | optional group _1 { - | required int32 _1; - | optional binary _2 (UTF8); - | } - |} - """.stripMargin) - - testSchema[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( - "deeply nested type", - """ - |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required int32 key; - | optional group value { - | optional binary _1 (UTF8); - | optional group _2 (LIST) { - | repeated group bag { - | optional group array { - | required int32 _1; - | required double _2; - | } - | } - | } - | } - | } - | } - |} - """.stripMargin) - - testSchema[(Option[Int], Map[Int, Option[Double]])]( - "optional types", - """ - |message root { - | optional int32 _1; - | optional group _2 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required int32 key; - | optional double value; - | } - | } - |} - """.stripMargin) - - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", - """ - |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map 
(MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } - | } - | } - |} - """.stripMargin, isThriftDerived = true) - - test("DataType string parser compatibility") { - // This is the generated string from previous versions of the Spark SQL, using the following: - // val schema = StructType(List( - // StructField("c1", IntegerType, false), - // StructField("c2", BinaryType, true))) - val caseClassString = - "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" - - // scalastyle:off - val jsonString = - """ - |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} - """.stripMargin - // scalastyle:on - - val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) - val fromJson = ParquetTypesConverter.convertFromString(jsonString) - - (fromCaseClassString, fromJson).zipped.foreach { (a, b) => - assert(a.name == b.name) - assert(a.dataType === b.dataType) - assert(a.nullable === b.nullable) - } - } - - test("merge with metastore schema") { - // Field type conflict resolution - assertResult( - StructType(Seq( - StructField("lowerCase", StringType), - StructField("UPPERCase", DoubleType, nullable = false)))) { - - ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("lowercase", StringType), - StructField("uppercase", DoubleType, nullable = false))), - - StructType(Seq( - StructField("lowerCase", BinaryType), - StructField("UPPERCase", IntegerType, nullable = true)))) - } - - // MetaStore schema is subset of parquet schema - assertResult( - StructType(Seq( - StructField("UPPERCase", DoubleType, nullable = false)))) { - - ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("uppercase", DoubleType, nullable = false))), - - StructType(Seq( - StructField("lowerCase", BinaryType), - StructField("UPPERCase", IntegerType, nullable = true)))) - } - - // Metastore schema contains additional non-nullable fields. - assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("uppercase", DoubleType, nullable = false), - StructField("lowerCase", BinaryType, nullable = false))), - - StructType(Seq( - StructField("UPPERCase", IntegerType, nullable = true)))) - }.getMessage.contains("detected conflicting schemas")) - - // Conflicting non-nullable field names - intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq(StructField("lower", StringType, nullable = false))), - StructType(Seq(StructField("lowerCase", BinaryType)))) - } - } - - test("merge missing nullable fields from Metastore schema") { - // Standard case: Metastore schema contains additional nullable fields not present - // in the Parquet file schema. 
- assertResult( - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = true)))) { - ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("firstfield", StringType, nullable = true), - StructField("secondfield", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = true))), - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true)))) - } - - // Merge should fail if the Metastore contains any additional fields that are not - // nullable. - assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("firstfield", StringType, nullable = true), - StructField("secondfield", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = false))), - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true)))) - }.getMessage.contains("detected conflicting schemas")) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index a71088430bfd..6fc9febe4970 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,27 +19,30 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DDLException +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - - private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null +class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jt") + try { + caseInsensitiveContext.dropTempTable("jt") + } finally { + super.afterAll() + } } after { @@ -50,7 +53,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -74,7 +77,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -91,7 +94,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -106,7 +109,7 @@ class 
CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -121,7 +124,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -138,7 +141,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -157,7 +160,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -174,7 +177,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -187,7 +190,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -198,7 +201,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala new file mode 100644 index 000000000000..853707c036c9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -0,0 +1,89 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} + + +// Note: META-INF/services under the test directory had to be modified for this suite to work +class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + assert(caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false)))) + } + + test("specify full classname with duplicate formats") { + assert(caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))) + } + + test("should fail to load ORC without HiveContext") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} + + +class FakeSourceOne extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceTwo extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceThree extends RelationProvider with DataSourceRegister { + + def shortName(): String = "gathering quorum" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 5fc53f701299..5f8514e1a241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -44,7 +45,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("doubleType", DoubleType, nullable = false), StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), - StructField("decimalType", DecimalType.Unlimited, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false),
StructField("booleanType", BooleanType, nullable = false), @@ -61,16 +62,19 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sqlContext.sparkContext.parallelize(from to to).map { e => InternalRow(UTF8String.fromString(s"people$e"), e * 2) - } + }.asInstanceOf[RDD[Row]] } } -class DDLTestSuite extends DataSourceTest { +class DDLTestSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE ddlPeople |USING org.apache.spark.sql.sources.DDLScanSource @@ -104,7 +108,7 @@ class DDLTestSuite extends DataSourceTest { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = caseInsensitiveContext.sql("describe ddlPeople") + val attributes = sql("describe ddlPeople") .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 00cc7d5ea580..af04079ec895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.sources -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext +private[sql] abstract class DataSourceTest extends QueryTest { -abstract class DataSourceTest extends QueryTest with BeforeAndAfter { // We want to test some edge cases. 
- protected implicit lazy val caseInsensitiveContext = { - val ctx = new SQLContext(TestSQLContext.sparkContext) + protected lazy val caseInsensitiveContext: SQLContext = { + val ctx = new SQLContext(sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + test(sqlString) { + checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 81b3a0f0c5b3..68ce37c00077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.sources import scala.language.existentials import org.apache.spark.rdd.RDD +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -56,6 +58,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL // Predicate test on integer column def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v + case EqualNullSafe("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v @@ -76,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case StringStartsWith("c", v) => _.startsWith(v) case StringEndsWith("c", v) => _.endsWith(v) case StringContains("c", v) => _.contains(v) + case EqualTo("c", v: String) => _.equals(v) + case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters") + case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s) case _ => (c: String) => true } @@ -95,11 +101,11 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - import caseInsensitiveContext.sql - - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered @@ -235,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 0b7c46c482c8..5b70d258d6ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,22 +19,18 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.{SaveMode, 
AnalysisException, Row} +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - - private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null +class InsertSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ + private var path: File = null - override def beforeAll: Unit = { + override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") sql( s""" @@ -46,10 +42,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) } - override def afterAll: Unit = { - caseInsensitiveContext.dropTempTable("jsonTable") - caseInsensitiveContext.dropTempTable("jt") - Utils.deleteRecursively(path) + override def afterAll(): Unit = { + try { + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } test("Simple INSERT OVERWRITE a JSONRelation") { @@ -110,7 +110,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to less part files. - val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") sql( s""" @@ -122,7 +122,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to more part files. 
- val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") sql( s""" @@ -146,27 +146,23 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { caseInsensitiveContext.dropTempTable("jt2") } - test("INSERT INTO not supported for JSONRelation for now") { - intercept[RuntimeException]{ - sql( - s""" - |INSERT INTO TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - } - } - - test("save directly to the path of a JSON table") { - caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") - .write.mode(SaveMode.Overwrite).json(path.toString) + test("INSERT INTO JSONRelation for now") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) checkAnswer( sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i * 5, s"str$i")) + sql("SELECT a, b FROM jt").collect() ) - caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) + sql( + s""" + |INSERT INTO TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) checkAnswer( sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) + sql("SELECT a, b FROM jt UNION ALL SELECT a, b FROM jt").collect() ) } @@ -183,6 +179,11 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("Caching") { + // write something to the jsonTable + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) // Cached Query Execution caseInsensitiveContext.cacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable")) @@ -205,9 +206,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) - checkAnswer( - sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + assertCached(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. @@ -217,14 +219,15 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) // jsonTable should be recached. assertCached(sql("SELECT * FROM jsonTable")) - // The cached data is the new data. - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - sql("SELECT a * 2, b FROM jt").collect()) - - // Verify uncaching - caseInsensitiveContext.uncacheTable("jsonTable") - assertCached(sql("SELECT * FROM jsonTable"), 0) + // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation +// // The cached data is the new data. 
+// checkAnswer( +// sql("SELECT a, b FROM jsonTable"), +// sql("SELECT a * 2, b FROM jt").collect()) +// +// // Verify uncaching +// caseInsensitiveContext.uncacheTable("jsonTable") +// assertCached(sql("SELECT * FROM jsonTable"), 0) } test("it's not allowed to insert into a relation that is not an InsertableRelation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala new file mode 100644 index 000000000000..c9791879ec74 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class PartitionedWriteSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("write many partitions") { + val path = Utils.createTempDir() + path.delete() + + val df = sqlContext.range(100).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + sqlContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } + + test("write many partitions with repeats") { + val path = Utils.createTempDir() + path.delete() + + val base = sqlContext.range(100) + val df = base.unionAll(base).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + sqlContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 257526feab94..a89c5f8007e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { @@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest { +class PrunedScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): 
Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution + val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { @@ -131,7 +134,7 @@ class PrunedScanSuite extends DataSourceTest { queryExecution) } - if (rawOutput.size != expectedColumns.size) { + if (rawOutput.numFields != expectedColumns.size) { fail(s"Wrong output row. Got $rawOutput\n$queryExecution") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 296b0d6f74a0..27d1cd92fca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -18,17 +18,43 @@ package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { - test("builtin sources") { - assert(ResolvedDataSource.lookupDataSource("jdbc") === - classOf[org.apache.spark.sql.jdbc.DefaultSource]) + test("jdbc") { + assert( + ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("json") === - classOf[org.apache.spark.sql.json.DefaultSource]) + test("json") { + assert( + ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("parquet") === - classOf[org.apache.spark.sql.parquet.DefaultSource]) + test("parquet") { + assert( + ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index b032515a9d28..10d261368993 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,25 +19,21 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - - private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var originalDefaultSource: String = null - - var path: File = null - - var df: DataFrame = null +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ + private var originalDefaultSource: String = null + private var path: File = null + private var df: DataFrame = null override def beforeAll(): Unit = { + super.beforeAll() originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() @@ -49,27 +45,32 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + try { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } finally { + super.afterAll() + } } after { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } - def checkLoad(): Unit = { + def checkLoad(expectedDF: DataFrame = df, tbl: String = "jsonTable"): Unit = { caseInsensitiveContext.conf.setConf( SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(caseInsensitiveContext.read.load(path.toString), df.collect()) + checkAnswer(caseInsensitiveContext.read.load(path.toString), expectedDF.collect()) // Test if we can pick up the data source name passed in load. 
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) - checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + expectedDF.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + expectedDF.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), - sql("SELECT b FROM jsonTable").collect()) + sql(s"SELECT b FROM $tbl").collect()) } test("save with path and load") { @@ -102,7 +103,7 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { test("save and save again") { df.write.json(path.toString) - var message = intercept[RuntimeException] { + val message = intercept[AnalysisException] { df.write.json(path.toString) }.getMessage @@ -118,12 +119,11 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { df.write.mode(SaveMode.Overwrite).json(path.toString) checkLoad() - message = intercept[RuntimeException] { - df.write.mode(SaveMode.Append).json(path.toString) - }.getMessage + // verify the append mode + df.write.mode(SaveMode.Append).json(path.toString) + val df2 = df.unionAll(df) + df2.registerTempTable("jsonTable2") - assert( - message.contains("Append mode is not supported"), - "We should complain that 'Append mode is not supported' for JSON source.") + checkLoad(df2, "jsonTable2") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 48875773224c..12af8068c398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.sources -import java.sql.{Timestamp, Date} - +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String class DefaultSource extends SimpleScanSource @@ -68,13 +66,13 @@ case class AllDataTypesScan( override def schema: StructType = userSpecifiedSchema - override def needConversion: Boolean = false + override def needConversion: Boolean = true override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - InternalRow( - UTF8String.fromString(s"str_$i"), - s"str_$i".getBytes(), + Row( + s"str_$i", + s"str_$i".getBytes(StandardCharsets.UTF_8), i % 2 == 0, i.toByte, i.toShort, @@ -82,24 +80,24 @@ case class AllDataTypesScan( i.toLong, i.toFloat, i.toDouble, - Decimal(new java.math.BigDecimal(i)), - Decimal(new java.math.BigDecimal(i)), - DateUtils.fromJavaDate(new Date(1970, 1, 1)), - DateUtils.fromJavaTimestamp(new Timestamp(20000 + i)), - UTF8String.fromString(s"varchar_$i"), + new java.math.BigDecimal(i), + new java.math.BigDecimal(i), + Date.valueOf("1970-01-01"), + new Timestamp(20000 + i), + s"varchar_$i", Seq(i, i + 1), - Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), + Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> 
i.toString), - Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), + Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), - InternalRow(Seq(DateUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), + Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) } } } -class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext.sql +class TableScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( @@ -114,7 +112,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date(1970, 1, 1), + Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -122,10 +120,11 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) }.toSeq - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTen @@ -203,7 +202,7 @@ class TableScanSuite extends DataSourceTest { StructField("longField_:,<>=+/~^", LongType, true) :: StructField("floatField", FloatType, true) :: StructField("doubleField", DoubleType, true) :: - StructField("decimalField1", DecimalType.Unlimited, true) :: + StructField("decimalField1", DecimalType.USER_DEFAULT, true) :: StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: @@ -281,7 +280,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) + (1 to 10).map(i => Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))).toSeq) test("Caching") { // Cached Query Execution @@ -306,9 +305,10 @@ class TableScanSuite extends DataSourceTest { sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer( - sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + assertCached(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala new file mode 100644 index 000000000000..152c9c8459de --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.{IOException, InputStream} + +import scala.sys.process.BasicIO + +object ProcessTestUtils { + class ProcessOutputCapturer(stream: InputStream, capture: String => Unit) extends Thread { + this.setDaemon(true) + + override def run(): Unit = { + try { + BasicIO.processFully(capture)(stream) + } catch { case _: IOException => + // Ignores the IOException thrown when the process termination, which closes the input + // stream abruptly. + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala new file mode 100644 index 000000000000..520dea7f7dd9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} + +/** + * A collection of sample data used in SQL tests. + */ +private[sql] trait SQLTestData { self => + protected def sqlContext: SQLContext + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.sqlContext + } + + import internalImplicits._ + import SQLTestData._ + + // Note: all test data should be lazy because the SQLContext is not set up yet. 
+ + protected lazy val emptyTestData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("emptyTestData") + df + } + + protected lazy val testData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.registerTempTable("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.registerTempTable("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.registerTempTable("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.registerTempTable("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.registerTempTable("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + BinaryData("12".getBytes, 1) :: + BinaryData("22".getBytes, 5) :: + BinaryData("122".getBytes, 3) :: + BinaryData("121".getBytes, 2) :: + BinaryData("123".getBytes, 4) :: Nil).toDF() + df.registerTempTable("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.registerTempTable("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.registerTempTable("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = sqlContext.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().registerTempTable("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = sqlContext.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().registerTempTable("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("repeatedData") + rdd + } + + 
protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = sqlContext.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.registerTempTable("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.registerTempTable("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + sqlContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().registerTempTable("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.registerTempTable("person") + df + } + + protected lazy val salary: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.registerTempTable("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.registerTempTable("complexData") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(sqlContext != null, "attempted to initialize test data before SQLContext.") + emptyTestData + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + } +} + +/** + * Case classes used in test data. 
+ */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index fa01823e9417..9214569f18e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -18,16 +18,82 @@ package org.apache.spark.sql.test import java.io.File +import java.util.UUID import scala.util.Try +import scala.language.implicitConversions -import org.apache.spark.sql.SQLContext +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.util.Utils -trait SQLTestUtils { - def sqlContext: SQLContext +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def sparkContext = sqlContext.sparkContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = sqlContext.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.sqlContext + + // This must live here to preserve binary compatibility with Spark < 1.5. + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. 
+ * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } - protected def configuration = sqlContext.sparkContext.hadoopConfiguration + /** + * The Hadoop configuration used by the active [[SQLContext]]. + */ + protected def hadoopConfiguration: Configuration = { + sparkContext.hadoopConfiguration + } /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -87,4 +153,80 @@ trait SQLTestUtils { } } } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + sqlContext.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. + */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + sqlContext.sql(s"USE $db") + try f finally sqlContext.sql(s"USE default") + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(sqlContext, plan) + } +} + +private[sql] object SQLTestUtils { + + def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala new file mode 100644 index 000000000000..963d10eed62e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.SQLContext + + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + */ +trait SharedSQLContext extends SQLTestUtils { + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _ctx: TestSQLContext = null + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def sqlContext: SQLContext = _ctx + + /** + * Initialize the [[TestSQLContext]]. + */ + protected override def beforeAll(): Unit = { + if (_ctx == null) { + _ctx = new TestSQLContext + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala new file mode 100644 index 000000000000..10e633f3cde4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{SQLConf, SQLContext} + + +/** + * A special [[SQLContext]] prepared for testing. + */ +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) + } + + // Make sure we set those test specific confs correctly when we create + // the SQLConf as well as when we call clear. 
+ protected[sql] override def createSession(): SQLSession = new this.SQLSession() + + /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ + protected[sql] class SQLSession extends super.SQLSession { + protected[sql] override lazy val conf: SQLConf = new SQLConf { + + clear() + + override def clear(): Unit = { + super.clear() + + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } + } + } + + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def sqlContext: SQLContext = self + } +} + +private[sql] object TestSQLContext { + + /** + * A map used to store all confs that need to be overridden in sql/core unit tests. + */ + val overrideConfs: Map[String, String] = + Map( + // Fewer shuffle partitions to speed up testing. + SQLConf.SHUFFLE_PARTITIONS.key -> "5") +} diff --git a/sql/core/src/test/scripts/gen-avro.sh b/sql/core/src/test/scripts/gen-avro.sh new file mode 100755 index 000000000000..48174b287fd7 --- /dev/null +++ b/sql/core/src/test/scripts/gen-avro.sh @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. +BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/avro/*.avdl`; do + filename=$(basename "$input") + filename="${filename%.*}" + avro-tools idl $input> $BASEDIR/avro/${filename}.avpr + avro-tools compile -string protocol $BASEDIR/avro/${filename}.avpr $BASEDIR/gen-java +done diff --git a/sql/core/src/test/scripts/gen-thrift.sh b/sql/core/src/test/scripts/gen-thrift.sh new file mode 100755 index 000000000000..ada432c68ab9 --- /dev/null +++ b/sql/core/src/test/scripts/gen-thrift.sh @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. 
+BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/thrift/*.thrift`; do + thrift --gen java -out $BASEDIR/gen-java $input +done diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift new file mode 100644 index 000000000000..98bf778aec5d --- /dev/null +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace java org.apache.spark.sql.execution.datasources.parquet.test.thrift + +enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS +} + +struct Nested { + 1: required list nestedIntsColumn; + 2: required string nestedStringColumn; +} + +/** + * This is a test struct for testing parquet-thrift compatibility. + */ +struct ParquetThriftCompat { + 1: required bool boolColumn; + 2: required byte byteColumn; + 3: required i16 shortColumn; + 4: required i32 intColumn; + 5: required i64 longColumn; + 6: required double doubleColumn; + 7: required binary binaryColumn; + 8: required string stringColumn; + 9: required Suit enumColumn + + 10: optional bool maybeBoolColumn; + 11: optional byte maybeByteColumn; + 12: optional i16 maybeShortColumn; + 13: optional i32 maybeIntColumn; + 14: optional i64 maybeLongColumn; + 15: optional double maybeDoubleColumn; + 16: optional binary maybeBinaryColumn; + 17: optional string maybeStringColumn; + 18: optional Suit maybeEnumColumn; + + 19: required list stringsColumn; + 20: required set intSetColumn; + 21: required map intToStringColumn; + 22: required map> complexColumn; +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 73e6ccdb1eaf..f7fe085f34d8 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -60,21 +60,38 @@ ${hive.group} hive-jdbc + + ${hive.group} + hive-service + ${hive.group} hive-beeline + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + com.sun.jersey + jersey-server + org.seleniumhq.selenium selenium-java test - - - io.netty - netty - - + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test
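The pom.xml hunk above (its closing XML tags are collapsed in this flattened rendering) adds the sql/core test-jar as a test-scoped dependency of hive-thriftserver, which is what lets suites outside sql/core reuse the shared test harness introduced in this diff (SharedSQLContext, SQLTestUtils and testImplicits). A minimal sketch of a downstream suite built on that assumption follows; the suite name and the data it checks are illustrative only, not part of the change.

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite: SharedSQLContext (published via the sql/core test-jar) supplies a lazily
// created TestSQLContext plus testImplicits for RDD-to-DataFrame conversions.
class ExampleDownstreamSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("reuse the shared SQLContext from the sql/core test-jar") {
    // sparkContext comes from the shared context trait; checkAnswer comes from QueryTest.
    val df = sparkContext.parallelize(Seq(1, 2, 3)).toDF("i")
    checkAnswer(df, Seq(Row(1), Row(2), Row(3)))
  }
}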
      diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala new file mode 100644 index 000000000000..2228f651e238 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hive.service.server + +import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} + +/** + * Class to upgrade a package-private class to public, and + * implement a `process()` operation consistent with + * the behavior of older Hive versions + * @param serverName name of the hive server + */ +private[apache] class HiveServerServerOptionsProcessor(serverName: String) + extends ServerOptionsProcessor(serverName) { + + def process(args: Array[String]): Boolean = { + // A parse failure automatically triggers a system exit + val response = super.parse(args) + val executor = response.getServerOptionsExecutor() + // return true if the parsed option was to start the service + executor.isInstanceOf[StartOptionExecutor] + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 700d994bb6a8..dd9fef9206d0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive.thriftserver +import java.util.Locale +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -24,7 +27,7 @@ import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} +import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} @@ -32,7 +35,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.util.Utils +import 
org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{Logging, SparkContext} @@ -65,7 +68,7 @@ object HiveThriftServer2 extends Logging { } def main(args: Array[String]) { - val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) } @@ -73,7 +76,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Utils.addShutdownHook { () => + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() uiTab.foreach(_.detach()) } @@ -149,16 +152,26 @@ object HiveThriftServer2 extends Logging { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { server.stop() } - var onlineSessionNum: Int = 0 - val sessionList = new mutable.LinkedHashMap[String, SessionInfo] - val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] - val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) - val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) - var totalRunning = 0 - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + private var onlineSessionNum: Int = 0 + private val sessionList = new mutable.LinkedHashMap[String, SessionInfo] + private val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] + private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) + private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) + private var totalRunning = 0 + + def getOnlineSessionNum: Int = synchronized { onlineSessionNum } + + def getTotalRunning: Int = synchronized { totalRunning } + + def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq } + + def getSession(sessionId: String): Option[SessionInfo] = synchronized { + sessionList.get(sessionId) + } + + def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { for { props <- Option(jobStart.properties) groupId <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -170,15 +183,18 @@ object HiveThriftServer2 extends Logging { } def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = { - val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) - sessionList.put(sessionId, info) - onlineSessionNum += 1 - trimSessionIfNecessary() + synchronized { + val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) + sessionList.put(sessionId, info) + onlineSessionNum += 1 + trimSessionIfNecessary() + } } - def onSessionClosed(sessionId: String): Unit = { + def onSessionClosed(sessionId: String): Unit = synchronized { sessionList(sessionId).finishTimestamp = System.currentTimeMillis onlineSessionNum -= 1 + trimSessionIfNecessary() } def onStatementStart( @@ -186,7 +202,7 @@ object HiveThriftServer2 extends Logging { sessionId: String, statement: String, groupId: String, - userName: String = "UNKNOWN"): Unit = { + userName: String = "UNKNOWN"): Unit = synchronized { val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) info.state = ExecutionState.STARTED executionList.put(id, info) @@ -196,37 +212,41 @@ object HiveThriftServer2 extends Logging { totalRunning += 1 } - def onStatementParsed(id: String, executionPlan: String): Unit = { + def 
onStatementParsed(id: String, executionPlan: String): Unit = synchronized { executionList(id).executePlan = executionPlan executionList(id).state = ExecutionState.COMPILED } def onStatementError(id: String, errorMessage: String, errorTrace: String): Unit = { - executionList(id).finishTimestamp = System.currentTimeMillis - executionList(id).detail = errorMessage - executionList(id).state = ExecutionState.FAILED - totalRunning -= 1 + synchronized { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).detail = errorMessage + executionList(id).state = ExecutionState.FAILED + totalRunning -= 1 + trimExecutionIfNecessary() + } } - def onStatementFinish(id: String): Unit = { + def onStatementFinish(id: String): Unit = synchronized { executionList(id).finishTimestamp = System.currentTimeMillis executionList(id).state = ExecutionState.FINISHED totalRunning -= 1 + trimExecutionIfNecessary() } - private def trimExecutionIfNecessary() = synchronized { + private def trimExecutionIfNecessary() = { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) - executionList.take(toRemove).foreach { s => + executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => executionList.remove(s._1) } } } - private def trimSessionIfNecessary() = synchronized { + private def trimSessionIfNecessary() = { if (sessionList.size > retainedSessions) { val toRemove = math.max(retainedSessions / 10, 1) - sessionList.take(toRemove).foreach { s => + sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => sessionList.remove(s._1) } } @@ -238,9 +258,12 @@ object HiveThriftServer2 extends Logging { private[hive] class HiveThriftServer2(hiveContext: HiveContext) extends HiveServer2 with ReflectedCompositeService { + // state is tracked internally so that the server only attempts to shut down if it successfully + // started, and then once only. 
+ private val started = new AtomicBoolean(false) override def init(hiveConf: HiveConf) { - val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + val sparkSqlCliService = new SparkSQLCLIService(this, hiveContext) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) @@ -256,8 +279,19 @@ private[hive] class HiveThriftServer2(hiveContext: HiveContext) } private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { - val transportMode: String = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.equalsIgnoreCase("http") + val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) + transportMode.toLowerCase(Locale.ENGLISH).equals("http") } + + override def start(): Unit = { + super.start() + started.set(true) + } + + override def stop(): Unit = { + if (started.getAndSet(false)) { + super.stop() + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e8758887ff3a..306f98bcb534 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} import java.util.concurrent.RejectedExecutionException -import java.util.{Map => JMap, UUID} +import java.util.{Arrays, Map => JMap, UUID} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.util.control.NonFatal @@ -32,8 +32,7 @@ import org.apache.hive.service.cli._ import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.hive.shims.Utils import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession @@ -127,13 +126,13 @@ private[hive] class SparkExecuteStatementOperation( def getResultSetSchema: TableSchema = { if (result == null || result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) } else { logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") val schema = result.queryExecution.analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } - new TableSchema(schema) + new TableSchema(schema.asJava) } } @@ -146,7 +145,7 @@ private[hive] class SparkExecuteStatementOperation( } else { val parentSessionState = SessionState.get() val hiveConf = getConfigForOperation() - val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + val sparkServiceUGI = Utils.getUGI() val sessionHive = getCurrentHive() val currentSqlSession = hiveContext.currentSession @@ -160,6 +159,12 @@ private[hive] class SparkExecuteStatementOperation( // User information is part of the metastore client member in Hive hiveContext.setSession(currentSqlSession) + // Always use the latest class 
loader provided by executionHive's state. + val executionHiveClassLoader = + hiveContext.executionHive.state.getConf.getClassLoader + sessionHive.getConf.setClassLoader(executionHiveClassLoader) + parentSessionState.getConf.setClassLoader(executionHiveClassLoader) + Hive.set(sessionHive) SessionState.setCurrentSessionState(parentSessionState) try { @@ -174,7 +179,7 @@ private[hive] class SparkExecuteStatementOperation( } try { - ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) + sparkServiceUGI.doAs(doAsAction) } catch { case e: Exception => setOperationException(new HiveSQLException(e)) @@ -201,7 +206,7 @@ private[hive] class SparkExecuteStatementOperation( } } - private def runInternal(): Unit = { + override def runInternal(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) @@ -293,7 +298,7 @@ private[hive] class SparkExecuteStatementOperation( sqlOperationConf = new HiveConf(sqlOperationConf) // apply overlay query specific settings, if any - getConfOverlay().foreach { case (k, v) => + getConfOverlay().asScala.foreach { case (k, v) => try { sqlOperationConf.verifyAndSet(k, v) } catch { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 039cfa40d26b..b5073961a1c8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import java.io._ -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, Locale} + +import scala.collection.JavaConverters._ -import jline.{ConsoleReader, History} +import jline.console.ConsoleReader +import jline.console.history.FileHistory import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory @@ -38,9 +39,13 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} -private[hive] object SparkSQLCLIDriver { +/** + * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver + * has dropped its support. + */ +private[hive] object SparkSQLCLIDriver extends Logging { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ @@ -96,9 +101,9 @@ private[hive] object SparkSQLCLIDriver { // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf - sessionState.cmdProperties.entrySet().foreach { item => - val key = item.getKey.asInstanceOf[String] - val value = item.getValue.asInstanceOf[String] + sessionState.cmdProperties.entrySet().asScala.foreach { item => + val key = item.getKey.toString + val value = item.getValue.toString // We do not propagate metastore options to the execution copy of hive. 
if (key != "javax.jdo.option.ConnectionURL") { conf.set(key, value) @@ -109,18 +114,11 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - Utils.addShutdownHook { () => SparkSQLEnv.stop() } + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } + val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. - if (sessionState.getHost != null) { - sessionState.connect() - if (sessionState.isRemoteMode) { - prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt - continuedPrompt = "".padTo(prompt.length, ' ') - } - } - - if (!sessionState.isRemoteMode) { + if (!remoteMode) { // Hadoop-20 and above - we need to augment classpath using hiveconf // components. // See also: code in ExecDriver.java @@ -131,6 +129,9 @@ private[hive] object SparkSQLCLIDriver { } conf.setClassLoader(loader) Thread.currentThread().setContextClassLoader(loader) + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } val cli = new SparkSQLCLIDriver @@ -164,36 +165,41 @@ private[hive] object SparkSQLCLIDriver { } } catch { case e: FileNotFoundException => - System.err.println(s"Could not open input file for reading. (${e.getMessage})") + logError(s"Could not open input file for reading. (${e.getMessage})") System.exit(3) } val reader = new ConsoleReader() reader.setBellEnabled(false) + reader.setExpandEvents(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) - CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e)) val historyDirectory = System.getProperty("user.home") try { if (new File(historyDirectory).exists()) { val historyFile = historyDirectory + File.separator + ".hivehistory" - reader.setHistory(new History(new File(historyFile))) + reader.setHistory(new FileHistory(new File(historyFile))) } else { - System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + logWarning("WARNING: Directory for Hive history file: " + historyDirectory + " does not exist. History will not be available during this session.") } } catch { case e: Exception => - System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + logWarning("WARNING: Encountered an error while trying to initialize Hive's " + "history file. History will not be available during this session.") - System.err.println(e.getMessage) + logWarning(e.getMessage) } + // TODO: missing +/* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") clientTransportTSocketField.setAccessible(true) transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] +*/ + transport = null var ret = 0 var prefix = "" @@ -230,6 +236,13 @@ private[hive] object SparkSQLCLIDriver { System.exit(ret) } + + + def isRemoteMode(state: CliSessionState): Boolean = { + // sessionState.isRemoteMode + state.isHiveServerQuery + } + } private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { @@ -239,25 +252,33 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + private val isRemoteMode = { + SparkSQLCLIDriver.isRemoteMode(sessionState) + } + private val conf: Configuration = if (sessionState != null) sessionState.getConf else new Configuration() // Force initializing SparkSQLEnv. 
This is put here but not object SparkSQLCliDriver // because the Hive unit tests do not go through the main() code path. - if (!sessionState.isRemoteMode) { + if (!isRemoteMode) { SparkSQLEnv.init() + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - if (cmd_trimmed.toLowerCase.equals("quit") || - cmd_trimmed.toLowerCase.equals("exit") || - tokens(0).equalsIgnoreCase("source") || + if (cmd_lower.equals("quit") || + cmd_lower.equals("exit") || + tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || - sessionState.isRemoteMode) { + isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) val end = System.currentTimeMillis() @@ -270,6 +291,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { + // scalastyle:off println if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver @@ -295,15 +317,15 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { // Print the column names. - Option(driver.getSchema.getFieldSchemas).map { fields => - out.println(fields.map(_.getName).mkString("\t")) + Option(driver.getSchema.getFieldSchemas).foreach { fields => + out.println(fields.asScala.map(_.getName).mkString("\t")) } } var counter = 0 try { while (!out.checkError() && driver.getResults(res)) { - res.foreach{ l => + res.asScala.foreach { l => counter += 1 out.println(l) } @@ -336,6 +358,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } ret = proc.run(cmd_1).getResponseCode } + // scalastyle:on println } ret } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 41f647d5f8c5..5ad8c54f296d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,36 +21,37 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConverters._ + import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ +import org.apache.hive.service.server.HiveServer2 import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import scala.collection.JavaConversions._ - -private[hive] class SparkSQLCLIService(hiveContext: HiveContext) - extends CLIService 
+private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext) + extends CLIService(hiveServer) with ReflectedCompositeService { override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) - val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, hiveContext) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null - if (ShimLoader.getHadoopShims.isSecurityEnabled) { + if (UserGroupInformation.isSecurityEnabled) { try { HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + sparkServiceUGI = Utils.getUGI() setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => @@ -75,7 +76,7 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService => def initCompositeService(hiveConf: HiveConf) { // Emulating `CompositeService.init(hiveConf)` val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") - serviceList.foreach(_.init(hiveConf)) + serviceList.asScala.foreach(_.init(hiveConf)) // Emulating `AbstractService.init(hiveConf)` invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 77272aecf283..2619286afc14 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{ArrayList => JArrayList, List => JList} +import java.util.{Arrays, ArrayList => JArrayList, List => JList} + +import scala.collection.JavaConverters._ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +29,6 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import scala.collection.JavaConversions._ - private[hive] class SparkSQLDriver( val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver @@ -43,14 +43,14 @@ private[hive] class SparkSQLDriver( private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed logDebug(s"Result Schema: ${analyzed.output}") - if (analyzed.output.size == 0) { - new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + if (analyzed.output.isEmpty) { + new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null) } else { val fieldSchemas = analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } - new Schema(fieldSchemas, null) + new Schema(fieldSchemas.asJava, null) } } @@ -79,7 +79,7 @@ private[hive] class SparkSQLDriver( if (hiveResponse == null) { false } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava) hiveResponse = null true } diff --git 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 79eda1f5123b..bacf6cc458fd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.PrintStream -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext @@ -38,9 +38,14 @@ private[hive] object SparkSQLEnv extends Logging { val sparkConf = new SparkConf(loadDefaults = true) val maybeSerializer = sparkConf.getOption("spark.serializer") val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") + // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of + // the default appName [SparkSQLCLIDriver] in cli or beeline. + val maybeAppName = sparkConf + .getOption("spark.app.name") + .filterNot(_ == classOf[SparkSQLCLIDriver].getName) sparkConf - .setAppName(s"SparkSQL::${Utils.localHostName()}") + .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) .set( "spark.serializer", maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) @@ -59,7 +64,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { - hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + hiveContext.hiveconf.getAllProperties.asScala.toSeq.sorted.foreach { case (k, v) => logDebug(s"HiveConf var: $k=$v") } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 2d5ee6800228..92ac0ec3fca2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -25,14 +25,15 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.SessionHandle import org.apache.hive.service.cli.session.SessionManager import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.hive.service.server.HiveServer2 import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager +private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: HiveContext) + extends SessionManager(hiveServer) with ReflectedCompositeService { private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) @@ -55,12 +56,14 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) protocol: TProtocolVersion, username: String, passwd: String, + ipAddress: String, sessionConf: java.util.Map[String, String], withImpersonation: Boolean, delegationToken: String): SessionHandle = { hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, 
passwd, sessionConf, withImpersonation, delegationToken) + val sessionHandle = + super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation, + delegationToken) val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 10c83d8b27a2..e990bd06011f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -39,14 +39,16 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { val content = - generateBasicStats() ++ -
      <br/> ++
-      <h4>
-        {listener.onlineSessionNum} session(s) are online,
-        running {listener.totalRunning} SQL statement(s)
-      </h4> ++
-      generateSessionStatsTable() ++
-      generateSQLStatsTable()
+      listener.synchronized { // make sure all parts in this page are consistent
+        generateBasicStats() ++
+        <br/> ++
+        <h4>
+          {listener.getOnlineSessionNum} session(s) are online,
+          running {listener.getTotalRunning} SQL statement(s)
+        </h4>
      ++ + generateSessionStatsTable() ++ + generateSQLStatsTable() + } UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } @@ -65,11 +67,11 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(): Seq[Node] = { - val numStatement = listener.executionList.size + val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.executionList.values + val dataRows = listener.getExecutionList def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -136,15 +138,15 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = - listener.sessionList.values + val dataRows = sessionList val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/sql/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) + val sessionLink = "%s/%s/session?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId)
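The ThriftServerPage changes above replace direct reads of the listener's mutable maps with synchronized accessors that return immutable snapshots, and wrap the whole render in `listener.synchronized`. A minimal, self-contained sketch of that pattern follows; the class and method names here are invented for illustration and are not part of the patch.

```scala
import scala.collection.mutable

// Sketch of the locking pattern applied to HiveThriftServer2Listener: mutable state stays
// private, every mutation and every read goes through `synchronized`, and readers only
// ever see immutable snapshots.
class StatsListener {
  private val sessions = new mutable.LinkedHashMap[String, String]()
  private var running = 0

  def onSessionCreated(id: String, user: String): Unit = synchronized {
    sessions.put(id, user)
    running += 1
  }

  def onSessionClosed(id: String): Unit = synchronized {
    sessions.remove(id)
    running -= 1
  }

  // Snapshot getters: the returned Seq is immutable, so a UI thread can iterate it
  // safely while new events keep arriving on other threads.
  def getSessionList: Seq[(String, String)] = synchronized { sessions.toSeq }
  def getTotalRunning: Int = synchronized { running }
}

object RenderExample {
  def render(listener: StatsListener): String = listener.synchronized {
    // Holding the listener's lock for the whole render keeps the two reads consistent
    // with each other (intrinsic locks are reentrant, so the getters still work here),
    // which is what the page-level `listener.synchronized` buys in the patch.
    val sessions = listener.getSessionList
    val running = listener.getTotalRunning
    s"${sessions.size} session(s), $running running"
  }
}
```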
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 3b01afa603ce..af16cb31df18 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -40,21 +40,22 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def render(request: HttpServletRequest): Seq[Node] = { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val sessionStat = listener.sessionList.find(stat => { - stat._1 == parameterId - }).getOrElse(null) - require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") val content = - generateBasicStats() ++ -
      <br/> ++
-      <h4>
-        User {sessionStat._2.userName},
-        IP {sessionStat._2.ip},
-        Session created at {formatDate(sessionStat._2.startTimestamp)},
-        Total run {sessionStat._2.totalExecution} SQL
-      </h4> ++
-      generateSQLStatsTable(sessionStat._2.sessionId)
+      listener.synchronized { // make sure all parts in this page are consistent
+        val sessionStat = listener.getSession(parameterId).getOrElse(null)
+        require(sessionStat != null, "Invalid sessionID[" + parameterId + "]")
+
+        generateBasicStats() ++
+        <br/> ++
+        <h4>
+          User {sessionStat.userName},
+          IP {sessionStat.ip},
+          Session created at {formatDate(sessionStat.startTimestamp)},
+          Total run {sessionStat.totalExecution} SQL
+        </h4>
    ++ + generateSQLStatsTable(sessionStat.sessionId) + } UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } @@ -73,13 +74,13 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(sessionID: String): Seq[Node] = { - val executionList = listener.executionList - .filter(_._2.sessionId == sessionID) + val executionList = listener.getExecutionList + .filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = executionList.values.toSeq.sortBy(_.startTimestamp).reverse + val dataRows = executionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -146,10 +147,11 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { val dataRows = - listener.sessionList.values.toSeq.sortBy(_.startTimestamp).reverse.map ( session => + sessionList.sortBy(_.startTimestamp).reverse.map ( session => Seq( session.userName, session.ip, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 94fd8a6bb60b..4eabeaa6735e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,9 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. 
*/ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging { - override val name = "SQL" + override val name = "JDBC/ODBC Server" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 13b0c5951ddd..e59a14ec00d5 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,17 +18,19 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ +import java.sql.Timestamp +import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter -import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkFunSuite} /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary @@ -37,43 +39,62 @@ import org.apache.spark.util.Utils class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() + val scratchDirPath = Utils.createTempDir() before { - warehousePath.delete() - metastorePath.delete() + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() } after { - warehousePath.delete() - metastorePath.delete() + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() } + /** + * Run a CLI operation and expect all the queries and expected answers to be returned. + * @param timeout maximum time for the commands to complete + * @param extraArgs any extra arguments + * @param errorResponses a sequence of strings whose presence in the stdout of the forked process + * is taken as an immediate error condition. That is: if a line beginning + * with one of these strings is found, fail the test immediately. + * The default value is `Seq("Error:")` + * + * @param queriesAndExpectedAnswers one or more tupes of query + answer + */ def runCliWithin( timeout: FiniteDuration, - extraArgs: Seq[String] = Seq.empty)( + extraArgs: Seq[String] = Seq.empty, + errorResponses: Seq[String] = Seq("Error:"))( queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. 
+ val queriesString = queries.map(_ + "\n").mkString val command = { + val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath """.stripMargin.split("\\s+").toSeq ++ extraArgs } var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. - val queryStream = new ByteArrayInputStream(queries.map(_ + "\n").mkString.getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object def captureOutput(source: String)(line: String): Unit = lock.synchronized { - buffer += s"$source> $line" + // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we + // add a timestamp to provide more diagnosis information. + buffer += s"${new Timestamp(new Date().getTime)} - $source> $line" + // If we haven't found all expected answers and another expected answer comes up... if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { next += 1 @@ -81,23 +102,36 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { if (next == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } + } else { + errorResponses.foreach { r => + if (line.startsWith(r)) { + foundAllExpectedAnswers.tryFailure( + new RuntimeException(s"Failed with error line '$line'")) + } + } } } - // Searching expected output line from both stdout and stderr of the CLI process - val process = (Process(command, None) #< queryStream).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) + val process = new ProcessBuilder(command: _*).start() + + val stdinWriter = new OutputStreamWriter(process.getOutputStream) + stdinWriter.write(queriesString) + stdinWriter.flush() + stdinWriter.close() + + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => - logError( + val message = s""" |======================= |CliSuite failure output |======================= |Spark SQL CLI command line: ${command.mkString(" ")} - | + |Exception: $cause |Executed query $next "${queries(next)}", |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. 
| @@ -105,8 +139,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { |=========================== |End CliSuite failure output |=========================== - """.stripMargin, cause) - throw cause + """.stripMargin + logError(message, cause) + fail(message, cause) } finally { process.destroy() } @@ -137,7 +172,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } test("Single command with --database") { - runCliWithin(1.minute)( + runCliWithin(2.minute)( "CREATE DATABASE hive_test_db;" -> "OK", "USE hive_test_db;" @@ -148,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { -> "Time taken: " ) - runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( "" -> "OK", "" diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 301aa5a6411e..b72249b3bf8c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Promise, future} -import scala.concurrent.ExecutionContext.Implicits.global -import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} import com.google.common.base.Charsets.UTF_8 @@ -41,9 +39,10 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkFunSuite} object TestData { def getTestDataFilePath(name: String): URL = { @@ -378,6 +377,60 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() } } + + test("test add jar") { + withMultipleConnectionJdbcStatement( + { + statement => + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + statement.executeQuery(s"ADD JAR $jarFile") + }, + + { + statement => + val queries = Seq( + "DROP TABLE IF EXISTS smallKV", + "CREATE TABLE smallKV(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE smallKV", + "DROP TABLE IF EXISTS addJar", + """CREATE TABLE addJar(key string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + + queries.foreach(statement.execute) + + statement.executeQuery( + """ + |INSERT INTO TABLE addJar SELECT 'k1' as key FROM smallKV limit 1 + """.stripMargin) + + val actualResult = + statement.executeQuery("SELECT key FROM addJar") + val actualResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (actualResult.next()) { + actualResultBuffer += actualResult.getString(1) + } + actualResult.close() + + val 
expectedResult = + statement.executeQuery("SELECT 'k1'") + val expectedResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (expectedResult.next()) { + expectedResultBuffer += expectedResult.getString(1) + } + expectedResult.close() + + assert(expectedResultBuffer === actualResultBuffer) + + statement.executeQuery("DROP TABLE IF EXISTS addJar") + statement.executeQuery("DROP TABLE IF EXISTS smallKV") + } + ) + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { @@ -417,7 +470,7 @@ object ServerMode extends Enumeration { } abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { - Class.forName(classOf[HiveDriver].getCanonicalName) + Utils.classForName(classOf[HiveDriver].getCanonicalName) private def jdbcUri = if (mode == ServerMode.http) { s"""jdbc:hive2://localhost:$serverPort/ @@ -483,7 +536,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl val tempLog4jConf = Utils.createTempDir().getCanonicalPath Files.write( - """log4j.rootCategory=INFO, console + """log4j.rootCategory=DEBUG, console |log4j.appender.console=org.apache.log4j.ConsoleAppender |log4j.appender.console.target=System.err |log4j.appender.console.layout=org.apache.log4j.PatternLayout @@ -492,7 +545,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl new File(s"$tempLog4jConf/log4j.properties"), UTF_8) - tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + tempLog4jConf } s"""$startScript @@ -508,6 +561,20 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl """.stripMargin.split("\\s+").toSeq } + /** + * String to scan for when looking for the the thrift binary endpoint running. + * This can change across Hive versions. + */ + val THRIFT_BINARY_SERVICE_LIVE = "Starting ThriftBinaryCLIService on port" + + /** + * String to scan for when looking for the the thrift HTTP endpoint running. + * This can change across Hive versions. + */ + val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" + + val SERVER_STARTUP_TIMEOUT = 3.minutes + private def startThriftServer(port: Int, attempt: Int) = { warehousePath = Utils.createTempDir() warehousePath.delete() @@ -528,45 +595,59 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt") - val env = Seq( - // Disables SPARK_TESTING to exclude log4j.properties in test directories. - "SPARK_TESTING" -> "0", - // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be started - // at a time, which is not Jenkins friendly. - "SPARK_PID_DIR" -> pidDir.getCanonicalPath) - - logPath = Process(command, None, env: _*).lines.collectFirst { - case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) - }.getOrElse { - throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + logPath = { + val lines = Utils.executeAndGetOutput( + command = command, + extraEnvironment = Map( + // Disables SPARK_TESTING to exclude log4j.properties in test directories. + "SPARK_TESTING" -> "0", + // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be + // started at a time, which is not Jenkins friendly. 
+ "SPARK_PID_DIR" -> pidDir.getCanonicalPath), + redirectStderr = true) + + lines.split("\n").collectFirst { + case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) + }.getOrElse { + throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + } } val serverStarted = Promise[Unit]() // Ensures that the following "tail" command won't fail. logPath.createNewFile() - logTailingProcess = + val successLines = Seq(THRIFT_BINARY_SERVICE_LIVE, THRIFT_HTTP_SERVICE_LIVE) + + logTailingProcess = { + val command = s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}".split(" ") // Using "-n +0" to make sure all lines in the log file are checked. - Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( - (line: String) => { - diagnosisBuffer += line + val builder = new ProcessBuilder(command: _*) + val captureOutput = (line: String) => diagnosisBuffer.synchronized { + diagnosisBuffer += line - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { + successLines.foreach { r => + if (line.contains(r)) { serverStarted.trySuccess(()) - } else if (line.contains("HiveServer2 is stopped")) { - // This log line appears when the server fails to start and terminates gracefully (e.g. - // because of port contention). - serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2")) } - })) + } + } + + val process = builder.start() + + new ProcessOutputCapturer(process.getInputStream, captureOutput).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput).start() + process + } - Await.result(serverStarted.future, 2.minute) + Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) } private def stopThriftServer(): Unit = { // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. 
- Process(stopScript, None, "SPARK_PID_DIR" -> pidDir.getCanonicalPath).run().exitValue() + Utils.executeAndGetOutput( + command = Seq(stopScript), + extraEnvironment = Map("SPARK_PID_DIR" -> pidDir.getCanonicalPath)) Thread.sleep(3.seconds.toMillis) warehousePath.delete() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 806240e6de45..bf431cd6b026 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite @@ -36,7 +35,6 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ var server: HiveThriftServer2 = _ - var hc: HiveContext = _ val uiPort = 20000 + Random.nextInt(10000) override def mode: ServerMode.Value = ServerMode.binary diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f88e62763ca7..ab309e0a1d36 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf @@ -50,6 +51,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + RuleExecutor.resetTime() } override def afterAll() { @@ -58,6 +60,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ @@ -115,6 +120,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // This test is totally fine except that it includes wrong queries and expects errors, but error // message format in Hive and Spark SQL differ. Should workaround this later. "udf_to_unix_timestamp", + // we can cast dates likes '2015-03-18' to a timestamp and extract the seconds. + // Hive returns null for second('2015-03-18') + "udf_second", + // we can cast dates likes '2015-03-18' to a timestamp and extract the minutes. + // Hive returns null for minute('2015-03-18') + "udf_minute", + // Cant run without local map/reduce. 
"index_auto_update", @@ -221,9 +233,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_case", - // Needs constant object inspectors - "udf_round", - // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive // is src(key STRING, value STRING), and in the reflect.q, it failed in // Integer.valueOf, which expect the first argument passed as STRING type not INT. @@ -254,9 +263,46 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // the answer is sensitive for jdk version "udf_java_method", - // Spark SQL use Long for TimestampType, lose the precision under 100ns + // Spark SQL use Long for TimestampType, lose the precision under 1us "timestamp_1", - "timestamp_2" + "timestamp_2", + "timestamp_udf", + + // Hive returns string from UTC formatted timestamp, spark returns timestamp type + "date_udf", + + // Can't compare the result that have newline in it + "udf_get_json_object", + + // Unlike Hive, we do support log base in (0, 1.0], therefore disable this + "udf7", + + // Trivial changes to DDL output + "compute_stats_empty_table", + "compute_stats_long", + "create_view_translate", + "show_create_table_serde", + "show_tblproperties", + + // Odd changes to output + "merge4", + + // Thift is broken... + "inputddl8", + + // Hive changed ordering of ddl: + "varchar_union1", + + // Parser changes in Hive 1.2 + "input25", + "input26", + + // Uses invalid table name + "innerjoin", + + // classpath problems + "compute_stats.*", + "udf_bitmap_.*" ) /** @@ -389,7 +435,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "date_comparison", "date_join1", "date_serde", - "date_udf", "decimal_1", "decimal_4", "decimal_join", @@ -803,7 +848,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "timestamp_udf", "touch", "transform_ppr1", "transform_ppr2", @@ -815,23 +859,21 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", - "udaf_number_format", "udf2", "udf5", "udf6", - // "udf7", turn this on after we figure out null vs nan vs infinity "udf8", "udf9", "udf_10_trims", "udf_E", "udf_PI", "udf_abs", - // "udf_acos", turn this on after we figure out null vs nan vs infinity + "udf_acos", "udf_add", "udf_array", "udf_array_contains", "udf_ascii", - // "udf_asin", turn this on after we figure out null vs nan vs infinity + "udf_asin", "udf_atan", "udf_avg", "udf_bigint", @@ -895,7 +937,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_lpad", "udf_ltrim", "udf_map", - "udf_minute", "udf_modulo", "udf_month", "udf_named_struct", @@ -919,10 +960,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_repeat", "udf_rlike", "udf_round", - // "udf_round_3", TODO: FIX THIS failed due to cast exception + "udf_round_3", "udf_rpad", "udf_rtrim", - "udf_second", "udf_sign", "udf_sin", "udf_smallint", @@ -949,6 +989,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_trim", "udf_ucase", "udf_unix_timestamp", + "udf_unhex", "udf_upper", "udf_var_pop", "udf_var_samp", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 
934452fe579a..92bb9e6d73af 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * for different tests and there are a few properties needed to let Hive generate golden * files, every `createQueryTest` calls should explicitly set `reset` to `false`. */ -abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with BeforeAndAfter { +class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() @@ -526,8 +526,14 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with | rows between 2 preceding and 2 following); """.stripMargin, reset = false) + // collect_set() output array in an arbitrary order, hence causes different result + // when running this test suite under Java 7 and 8. + // We change the original sql query a little bit for making the test suite passed + // under different JDK createQueryTest("windowing.q -- 20. testSTATs", """ + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( |select p_mfgr,p_name, p_size, |stddev(p_retailprice) over w1 as sdev, |stddev_pop(p_retailprice) over w1 as sdev_pop, @@ -538,6 +544,8 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) createQueryTest("windowing.q -- 21. testDISTs", @@ -751,21 +759,7 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with """.stripMargin, reset = false) } -class HiveWindowFunctionQueryWithoutCodeGenSuite extends HiveWindowFunctionQueryBaseSuite { - var originalCodegenEnabled: Boolean = _ - override def beforeAll(): Unit = { - super.beforeAll() - originalCodegenEnabled = conf.codegenEnabled - sql("set spark.sql.codegen=false") - } - - override def afterAll(): Unit = { - sql(s"set spark.sql.codegen=$originalCodegenEnabled") - super.afterAll() - } -} - -abstract class HiveWindowFunctionQueryFileBaseSuite +class HiveWindowFunctionQueryFileSuite extends HiveCompatibilitySuite with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault @@ -781,11 +775,11 @@ abstract class HiveWindowFunctionQueryFileBaseSuite // The following settings are used for generating golden files with Hive. // We have to use kryo to correctly let Hive serialize plans with window functions. // This is used to generate golden files. - sql("set hive.plan.serialization.format=kryo") + // sql("set hive.plan.serialization.format=kryo") // Explicitly set fs to local fs. - sql(s"set fs.default.name=file://$testTempDir/") + // sql(s"set fs.default.name=file://$testTempDir/") // Ask Hive to run jobs in-process as a single map and reduce task. 
- sql("set mapred.job.tracker=local") + // sql("set mapred.job.tracker=local") } override def afterAll() { @@ -825,21 +819,8 @@ abstract class HiveWindowFunctionQueryFileBaseSuite "windowing_adjust_rowcontainer_sz" ) + // Only run those query tests in the realWhileList (do not try other ignored query files). override def testCases: Seq[(String, File)] = super.testCases.filter { case (name, _) => realWhiteList.contains(name) } } - -class HiveWindowFunctionQueryFileWithoutCodeGenSuite extends HiveWindowFunctionQueryFileBaseSuite { - var originalCodegenEnabled: Boolean = _ - override def beforeAll(): Unit = { - super.beforeAll() - originalCodegenEnabled = conf.codegenEnabled - sql("set spark.sql.codegen=false") - } - - override def afterAll(): Unit = { - sql(s"set spark.sql.codegen=$originalCodegenEnabled") - super.afterAll() - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala deleted file mode 100644 index f458567e5d7e..000000000000 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive - -/** - * Runs the test cases that are included in the hive distribution with sort merge join is true. 
- */ -class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { - override def beforeAll() { - super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) - } - - override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) - super.afterAll() - } - - override def whiteList = Seq( - "auto_join0", - "auto_join1", - "auto_join10", - "auto_join11", - "auto_join12", - "auto_join13", - "auto_join14", - "auto_join14_hadoop20", - "auto_join15", - "auto_join17", - "auto_join18", - "auto_join19", - "auto_join2", - "auto_join20", - "auto_join21", - "auto_join22", - "auto_join23", - "auto_join24", - "auto_join25", - "auto_join26", - "auto_join27", - "auto_join28", - "auto_join3", - "auto_join30", - "auto_join31", - "auto_join32", - "auto_join4", - "auto_join5", - "auto_join6", - "auto_join7", - "auto_join8", - "auto_join9", - "auto_join_filters", - "auto_join_nulls", - "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", - "correlationoptimizer1", - "correlationoptimizer10", - "correlationoptimizer11", - "correlationoptimizer13", - "correlationoptimizer14", - "correlationoptimizer15", - "correlationoptimizer2", - "correlationoptimizer3", - "correlationoptimizer4", - "correlationoptimizer6", - "correlationoptimizer7", - "correlationoptimizer8", - "correlationoptimizer9", - "join0", - "join1", - "join10", - "join11", - "join12", - "join13", - "join14", - "join14_hadoop20", - "join15", - "join16", - "join17", - "join18", - "join19", - "join2", - "join20", - "join21", - "join22", - "join23", - "join24", - "join25", - "join26", - "join27", - "join28", - "join29", - "join3", - "join30", - "join31", - "join32", - "join32_lessSize", - "join33", - "join34", - "join35", - "join36", - "join37", - "join38", - "join39", - "join4", - "join40", - "join41", - "join5", - "join6", - "join7", - "join8", - "join9", - "join_1to1", - "join_array", - "join_casesensitive", - "join_empty", - "join_filters", - "join_hive_626", - "join_map_ppr", - "join_nulls", - "join_nullsafe", - "join_rc", - "join_reorder2", - "join_reorder3", - "join_reorder4", - "join_star" - ) -} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a17546d70624..ac67fe5f47be 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -36,6 +36,11 @@ + + + com.twitter + parquet-hadoop-bundle + org.apache.spark spark-core_${scala.binary.version} @@ -53,32 +58,42 @@ spark-sql_${scala.binary.version} ${project.version} + - org.codehaus.jackson - jackson-mapper-asl + ${hive.group} + hive-exec + ${hive.group} - hive-serde + hive-metastore + org.apache.avro @@ -91,6 +106,55 @@ avro-mapred ${avro.mapred.classifier} + + commons-httpclient + commons-httpclient + + + org.apache.calcite + calcite-avatica + + + org.apache.calcite + calcite-core + + + org.apache.httpcomponents + httpclient + + + org.codehaus.jackson + jackson-mapper-asl + + + + commons-codec + commons-codec + + + joda-time + joda-time + + + org.jodd + jodd-core + + + com.google.code.findbugs + jsr305 + + + org.datanucleus + datanucleus-core 
+ + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + org.scalacheck scalacheck_${scala.binary.version} @@ -133,7 +197,6 @@ - src/test/scala compatibility/src/test/scala diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..4a774fbf1fdf --- /dev/null +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.orc.DefaultSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cf05c6c98965..d37ba5ddc2d8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -20,32 +20,36 @@ package org.apache.spark.sql.hive import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp +import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions +import scala.concurrent.duration._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} +import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -65,12 +69,12 @@ private[hive] class HiveQLDialect extends ParserDialect { * * @since 1.0.0 */ -class HiveContext(sc: SparkContext) extends SQLContext(sc) { +class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { self => import HiveContext._ - println("create HiveContext") + logDebug("create HiveContext") /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe @@ -107,8 +111,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * this does not necessarily need to be the same version of Hive that is used internally by * Spark SQL for execution. 
*/ - protected[hive] def hiveMetastoreVersion: String = - getConf(HIVE_METASTORE_VERSION, hiveExecutionVersion) + protected[hive] def hiveMetastoreVersion: String = getConf(HIVE_METASTORE_VERSION) /** * The location of the jars that should be used to instantiate the HiveMetastoreClient. This @@ -163,6 +166,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } SessionState.setCurrentSessionState(executionHive.state) + /** + * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. + * - allow SQL11 keywords to be used as identifiers + */ + private[sql] def defaultOverrides() = { + setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") + } + + defaultOverrides() + /** * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. * The version of the Hive client that is used here must match the metastore that is configured @@ -175,8 +188,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options // into the isolated client loader val metadataConf = new HiveConf() + + val defaultWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("default warehouse location is " + defaultWarehouseLocation) + // `configure` goes second to override other settings. - val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure + val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure val isolatedLoader = if (hiveMetastoreJars == "builtin") { if (hiveExecutionVersion != hiveMetastoreVersion) { @@ -184,7 +201,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. " + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") } // We recursively find all jars in the class loader chain, @@ -217,7 +234,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) + IsolatedClientLoader.forVersion( + version = hiveMetastoreVersion, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else { // Convert to files and expand any directories. 
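The same import swap recurs throughout this patch: the implicit scala.collection.JavaConversions wildcard is replaced by scala.collection.JavaConverters, so every Java/Scala boundary becomes an explicit .asScala or .asJava call (as in metadataConf.asScala above). A small self-contained illustration of the pattern:

    import scala.collection.JavaConverters._

    // A Java map of the kind returned by Hadoop/Hive configuration APIs...
    val javaMap = new java.util.HashMap[String, String]()
    javaMap.put("hive.metastore.uris", "thrift://example:9083")   // illustrative value

    // ...converted explicitly: .asScala yields a Scala view, .toMap an immutable copy.
    val scalaMap: Map[String, String] = javaMap.asScala.toMap

    // And the other direction: .asJava exposes a Scala Seq through java.util.List.
    val javaList: java.util.List[String] = Seq("a", "b", "c").asJava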
val jars = @@ -251,6 +272,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } protected[sql] override def parseSql(sql: String): LogicalPlan = { + var state = SessionState.get() + if (state == null) { + SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState) + } super.parseSql(substitutor.substitute(hiveconf, sql)) } @@ -266,11 +291,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - catalog.refreshTable(catalog.client.currentDatabase, tableName) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - catalog.invalidateTable(catalog.client.currentDatabase, tableName) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + catalog.invalidateTable(tableIdent) } /** @@ -284,7 +311,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ @Experimental def analyze(tableName: String) { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { case relation: MetastoreRelation => @@ -296,10 +324,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Can we use fs.getContentSummary in future? // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use // countFileSize to count the table size. + val stagingDir = metadataHive.getConf(HiveConf.ConfVars.STAGINGDIR.varname, + HiveConf.ConfVars.STAGINGDIR.defaultStrVal) + def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDir) { - fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + fs.listStatus(path) + .map { status => + if (!status.getPath().getName().startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + } + .sum } else { fileStatus.getLen } @@ -359,7 +398,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { hiveconf.set(key, value) } - private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { setConf(entry.key, entry.stringConverter(value)) } @@ -371,7 +410,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new OverrideFunctionRegistry(new HiveFunctionRegistry(FunctionRegistry.builtin)) + new HiveFunctionRegistry(FunctionRegistry.builtin) /* An analyzer that uses the Hive metastore. */ @transient @@ -381,13 +420,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: ResolveHiveWindowFunction :: - sources.PreInsertCastAndRename :: + PreInsertCastAndRename :: Nil override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) + PreWriteCheck(catalog) ) } @@ -396,7 +435,58 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } /** Overridden by child classes that need to set configuration before the client init. 
*/ - protected def configure(): Map[String, String] = Map.empty + protected def configure(): Map[String, String] = { + // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch + // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- + // compatibility when users are trying to connecting to a Hive metastore of lower version, + // because these options are expected to be integral values in lower versions of Hive. + // + // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according + // to their output time units. + Seq( + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, + ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, + ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS + ).map { case (confVar, unit) => + confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString + }.toMap + } protected[hive] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { @@ -442,16 
+532,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveCommandStrategy(self), HiveDDLStrategy, DDLStrategy, - TakeOrdered, - ParquetOperations, + TakeOrderedAndProject, InMemoryScans, - ParquetConversion, // Must be before HiveTableScans HiveTableScans, DataSinks, Scripts, HashAggregation, + Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin @@ -514,19 +603,27 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { private[hive] object HiveContext { /** The version of hive used internally by Spark SQL. */ - val hiveExecutionVersion: String = "0.13.1" + val hiveExecutionVersion: String = "1.2.1" + + val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version", + defaultValue = Some(hiveExecutionVersion), + doc = "Version of the Hive metastore. Available options are " + + s"0.12.0 through $hiveExecutionVersion.") - val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", defaultValue = Some("builtin"), - doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" + - " property can be one of three options: " + - "1. \"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " + - "-Phive is enabled. When this option is chosen, " + - "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " + - "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." + - "3. A classpath in the standard format for both Hive and Hadoop.") - + doc = s""" + | Location of the jars that should be used to instantiate the HiveMetastoreClient. + | This property can be one of three options: " + | 1. "builtin" + | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly jar when + | -Phive is enabled. When this option is chosen, + | spark.sql.hive.metastore.version must be either + | ${hiveExecutionVersion} or not defined. + | 2. "maven" + | Use Hive jars of specified version downloaded from Maven repositories. + | 3. A classpath in the standard format for both Hive and Hadoop. + """.stripMargin) val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", defaultValue = Some(true), doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + @@ -565,17 +662,18 @@ private[hive] object HiveContext { /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ def newTemporaryConfiguration(): Map[String, String] = { val tempDir = Utils.createTempDir() - val localMetastore = new File(tempDir, "metastore").getAbsolutePath + val localMetastore = new File(tempDir, "metastore") val propMap: HashMap[String, String] = HashMap() // We have to mask all properties in hive-site.xml that relates to metastore data source // as we used a local metastore here. 
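To make the configure() override above concrete: HiveConf.getTimeVar accepts the time-suffixed values that Hive 0.14+ uses for its time ConfVars (for example a default such as "600s") and returns the amount in the requested unit, so the derived setting is a plain numeric string that older metastores can still parse as an integer. A minimal sketch, assuming the Hive 1.2.1 classes this module now builds against:

    import java.util.concurrent.TimeUnit
    import org.apache.hadoop.hive.conf.HiveConf
    import org.apache.hadoop.hive.conf.HiveConf.ConfVars

    val hiveconf = new HiveConf()
    // A time-suffixed value such as "600s" comes back as a plain number of seconds...
    val seconds = hiveconf.getTimeVar(ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT, TimeUnit.SECONDS)
    // ...and its string form is what gets handed to the isolated metastore client.
    val legacyValue: String = seconds.toString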
HiveConf.ConfVars.values().foreach { confvar => if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { - propMap.put(confvar.varname, confvar.defaultVal) + propMap.put(confvar.varname, confvar.getDefaultExpr()) } } - propMap.put("javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$localMetastore;create=true") + propMap.put(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, localMetastore.toURI.toString) + propMap.put(HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, + s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") propMap.toMap diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index d4f1ae8ee01d..cfe2bb05ad89 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} @@ -24,15 +26,13 @@ import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * 1. The Underlying data type in catalyst and in Hive * In catalyst: @@ -45,15 +45,14 @@ import scala.collection.JavaConversions._ * long / scala.Long * short / scala.Short * byte / scala.Byte - * org.apache.spark.sql.types.Decimal + * [[org.apache.spark.sql.types.Decimal]] * Array[Byte] * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: scala.collection.immutable.Map - * List: scala.collection.immutable.Seq - * Struct: - * org.apache.spark.sql.catalyst.expression.Row + * Map: [[org.apache.spark.sql.types.MapData]] + * List: [[org.apache.spark.sql.types.ArrayData]] + * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. 
* @@ -178,7 +177,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -194,8 +193,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -218,6 +217,20 @@ private[hive] trait HiveInspectors { // Hive seems to return this for struct types? case c: Class[_] if c == classOf[java.lang.Object] => NullType + + // java list type unsupported + case c: Class[_] if c == classOf[java.util.List[_]] => + throw new AnalysisException( + "List type in java is unsupported because " + + "JVM type erasure makes spark fail to catch a component type in List<>") + + // java map type unsupported + case c: Class[_] if c == classOf[java.util.Map[_, _]] => + throw new AnalysisException( + "Map type in java is unsupported because " + + "JVM type erasure makes spark fail to catch key and value types in Map<>") + + case c => throw new AnalysisException(s"Unsupported java type $c") } /** @@ -252,7 +265,7 @@ private[hive] trait HiveInspectors { poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => val t = poi.getWritableConstantValue - t.getSeconds * 10000000L + t.getNanos / 100L + t.getSeconds * 1000000L + t.getNanos / 1000L case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantDoubleObjectInspector => @@ -273,16 +286,19 @@ private[hive] trait HiveInspectors { System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp case poi: WritableConstantDateObjectInspector => - DateUtils.fromJavaDate(poi.getWritableConstantValue.get()) + DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data - mi.getWritableConstantValue.map { case (k, v) => - (unwrap(k, mi.getMapKeyObjectInspector), - unwrap(v, mi.getMapValueObjectInspector)) - }.toMap + val keyValues = mi.getWritableConstantValue.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray + ArrayBasedMapData(keys, values) case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - li.getWritableConstantValue.map(unwrap(_, 
li.getListElementObjectInspector)).toSeq + val values = li.getWritableConstantValue.asScala + .map(unwrap(_, li.getListElementObjectInspector)) + .toArray + new GenericArrayData(values) // if the value is null, we don't care about the object inspector type case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector @@ -313,32 +329,37 @@ private[hive] trait HiveInspectors { System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => - DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) - case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) + DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) + case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) case x: TimestampObjectInspector if x.preferWritable() => val t = x.getPrimitiveWritableObject(data) - t.getSeconds * 10000000L + t.getNanos / 100 + t.getSeconds * 1000000L + t.getNanos / 1000L case ti: TimestampObjectInspector => - DateUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) + DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) } case li: ListObjectInspector => Option(li.getList(data)) - .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) + .map { l => + val values = l.asScala.map(unwrap(_, li.getListElementObjectInspector)).toArray + new GenericArrayData(values) + } .orNull case mi: MapObjectInspector => - Option(mi.getMap(data)).map( - _.map { - case (k, v) => - (unwrap(k, mi.getMapKeyObjectInspector), - unwrap(v, mi.getMapValueObjectInspector)) - }.toMap).orNull + val map = mi.getMap(data) + if (map == null) { + null + } else { + val keyValues = map.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray + ArrayBasedMapData(keys, values) + } // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray) + InternalRow.fromSeq(allRefs.asScala.map( + r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } @@ -346,28 +367,52 @@ private[hive] trait HiveInspectors { * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. 
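Both timestamp branches above encode the same representation change: Catalyst's TimestampType is a Long counting microseconds since the epoch (hence the "precision under 1us" note in the compatibility suite), so Hive's (seconds, nanosecond-of-second) pair becomes seconds * 1,000,000 plus nanos / 1,000. A standalone version of the arithmetic:

    // Convert Hive's (epoch seconds, nanosecond-of-second) pair to Catalyst microseconds;
    // anything finer than 1 microsecond is truncated.
    def toCatalystMicros(epochSeconds: Long, nanosOfSecond: Int): Long =
      epochSeconds * 1000000L + nanosOfSecond / 1000L

    // e.g. 10 seconds + 123456789 ns -> 10123456 us (the trailing 789 ns are dropped)
    val micros = toCatalystMicros(10L, 123456789)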
*/ - protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => (o: Any) => - val s = o.asInstanceOf[UTF8String].toString - new HiveVarchar(s, s.size) + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) + } else { + null + } case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => + if (o != null) { + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + } else { + null + } case _: JavaDateObjectInspector => - (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + } else { + null + } case _: JavaTimestampObjectInspector => - (o: Any) => DateUtils.toJavaTimestamp(o.asInstanceOf[Long]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + } else { + null + } case soi: StandardStructObjectInspector => - val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + val schema = dataType.asInstanceOf[StructType] + val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map { + case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType) + } (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } struct } else { @@ -376,21 +421,34 @@ private[hive] trait HiveInspectors { } case loi: ListObjectInspector => - val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null + val elementType = dataType.asInstanceOf[ArrayType].elementType + val wrapper = wrapperFor(loi.getListElementObjectInspector, elementType) + (o: Any) => { + if (o != null) { + val array = o.asInstanceOf[ArrayData] + val values = new java.util.ArrayList[Any](array.numElements()) + array.foreach(elementType, (_, e) => { + values.add(wrapper(e)) + }) + values + } else { + null + } + } case moi: MapObjectInspector => - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map + val mt = dataType.asInstanceOf[MapType] + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector, mt.keyType) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector, mt.valueType) - val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) - val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) (o: Any) => { if (o != null) { - mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => - keyWrapper(key) -> valueWrapper(value) + val map = o.asInstanceOf[MapData] + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + map.foreach(mt.keyType, mt.valueType, (k, v) => { + jmap.put(keyWrapper(k), valueWrapper(v)) }) + jmap } else { null } @@ -440,7 +498,7 @@ private[hive] trait HiveInspectors { * * NOTICE: the complex data type requires recursive wrapping. 
*/ - def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match { + def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match { case x: ConstantObjectInspector => x.getWritableConstantValue case _ if a == null => null case x: PrimitiveObjectInspector => x match { @@ -468,49 +526,59 @@ private[hive] trait HiveInspectors { case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) - case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) + case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int]) case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) - case _: TimestampObjectInspector => DateUtils.toJavaTimestamp(a.asInstanceOf[Long]) + case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long]) } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] // 1. create the pojo (most likely) object val result = x.create() var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { // 2. set the property for the pojo + val tpe = structType(i).dataType x.setStructFieldData( result, fieldRefs.get(i), - wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](fieldRefs.length) + val result = new java.util.ArrayList[AnyRef](fieldRefs.size) var i = 0 - while (i < fieldRefs.length) { - result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + while (i < fieldRefs.size) { + val tpe = structType(i).dataType + result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: ListObjectInspector => val list = new java.util.ArrayList[Object] - a.asInstanceOf[Seq[_]].foreach { - v => list.add(wrap(v, x.getListElementObjectInspector)) - } + val tpe = dataType.asInstanceOf[ArrayType].elementType + a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => { + list.add(wrap(e, x.getListElementObjectInspector, tpe)) + }) list case x: MapObjectInspector => + val keyType = dataType.asInstanceOf[MapType].keyType + val valueType = dataType.asInstanceOf[MapType].valueType + val map = a.asInstanceOf[MapData] + // Some UDFs seem to assume we pass in a HashMap. 
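The map branch that follows (and its wrapperFor counterpart earlier in this file) has one common shape: wrappers for the keys and values are derived from the Catalyst MapType up front, and the result is a mutable java.util.HashMap because, as the comment above notes, some Hive UDFs expect that concrete class. A simplified, Spark-independent sketch of that shape, with hypothetical wrapper functions:

    // Per-element wrappers are chosen once, then applied while copying into a HashMap.
    def wrapMap[K, V](m: Map[K, V], wrapKey: K => Any, wrapValue: V => Any): java.util.HashMap[Any, Any] = {
      val jmap = new java.util.HashMap[Any, Any](m.size)
      m.foreach { case (k, v) => jmap.put(wrapKey(k), wrapValue(v)) }
      jmap
    }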
- val hashMap = new java.util.HashMap[AnyRef, AnyRef]() - hashMap.putAll(a.asInstanceOf[Map[_, _]].map { - case (k, v) => - wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) + val hashMap = new java.util.HashMap[Any, Any](map.numElements()) + + map.foreach(keyType, valueType, (k, v) => { + hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), + wrap(v, x.getMapValueObjectInspector, valueType)) }) hashMap @@ -519,22 +587,24 @@ private[hive] trait HiveInspectors { def wrap( row: InternalRow, inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) i += 1 } cache } def wrap( - row: Seq[Any], - inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) i += 1 } cache @@ -611,7 +681,9 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + value.asInstanceOf[ArrayData].foreach(dt, (_, e) => { + list.add(wrap(e, listObjectInspector, dt)) + }) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -620,11 +692,14 @@ private[hive] trait HiveInspectors { if (value == null) { ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null) } else { - val map = new java.util.HashMap[Object, Object]() - value.asInstanceOf[Map[_, _]].foreach (entry => { - map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + val map = value.asInstanceOf[MapData] + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + + map.foreach(keyType, valueType, (k, v) => { + jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType)) }) - ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) + + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap) } // We will enumerate all of the possible constant expressions, throw exception if we missed case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") @@ -637,10 +712,10 @@ private[hive] trait HiveInspectors { def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { + StructType(s.getAllStructFieldRefs.asScala.map(f => types.StructField( f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) + )) case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) case m: MapObjectInspector => MapType( @@ -781,7 +856,7 @@ private[hive] trait HiveInspectors { if (value == null) { null } else { - new hiveIo.TimestampWritable(DateUtils.toJavaTimestamp(value.asInstanceOf[Long])) + new hiveIo.TimestampWritable(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long])) } private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = @@ -799,9 +874,6 @@ 
private[hive] trait HiveInspectors { private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo( - HiveShim.UNLIMITED_DECIMAL_PRECISION, - HiveShim.UNLIMITED_DECIMAL_SCALE) } def toTypeInfo: TypeInfo = dt match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f35ae96ee0b5..0a5569b0a444 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConverters._ +import scala.collection.mutable + import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ @@ -30,18 +33,66 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} + +private[hive] case class HiveSerDe( + inputFormat: Option[String] = None, + outputFormat: Option[String] = None, + serde: Option[String] = None) + +private[hive] object HiveSerDe { + /** + * Get the Hive SerDe information from the data source abbreviation string or classname. + * + * @param source Currently the source abbreviation can be one of the following: + * SequenceFile, RCFile, ORC, PARQUET, and case insensitive. 
+ * @param hiveConf Hive Conf + * @return HiveSerDe associated with the specified source + */ + def sourceToSerDe(source: String, hiveConf: HiveConf): Option[HiveSerDe] = { + val serdeMap = Map( + "sequencefile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + + "rcfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), + serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))), + + "orc" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")), + + "parquet" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) + + val key = source.toLowerCase match { + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s + } -/* Implicit conversions */ -import scala.collection.JavaConversions._ + serdeMap.get(key) + } +} private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -114,7 +165,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. // Since we also cache ParquetRelations converted from Hive Parquet tables and @@ -123,10 +174,13 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). - invalidateTable(databaseName, tableName) + invalidateTable(tableIdent) } - def invalidateTable(databaseName: String, tableName: String): Unit = { + def invalidateTable(tableIdent: TableIdentifier): Unit = { + val databaseName = tableIdent.database.getOrElse(client.currentDatabase) + val tableName = tableIdent.table + cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) } @@ -136,6 +190,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * Creates a data source table (a table created with USING clause) in Hive's metastore. * Returns true when the table has been created. Otherwise, false. */ + // TODO: Remove this in SPARK-10104. 
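Hypothetical usage of the sourceToSerDe helper above, written as if from inside the org.apache.spark.sql.hive package (HiveSerDe is private[hive]); the lookup is case-insensitive and also recognizes the fully qualified parquet/orc provider names:

    import org.apache.hadoop.hive.conf.HiveConf

    val hiveConf = new HiveConf()
    HiveSerDe.sourceToSerDe("PARQUET", hiveConf) match {
      case Some(HiveSerDe(input, output, serde)) =>
        println(s"input=$input, output=$output, serde=$serde")
      case None =>
        println("no Hive-compatible SerDe; the table is stored in Spark SQL specific format")
    }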
def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], @@ -143,16 +198,36 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = processDatabaseAndTableName(client.currentDatabase, tableName) - val tableProperties = new scala.collection.mutable.HashMap[String, String] + createDataSourceTable( + new SqlParser().parseTableIdentifier(tableName), + userSpecifiedSchema, + partitionColumns, + provider, + options, + isExternal) + } + + def createDataSourceTable( + tableIdent: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { + val (dbName, tblName) = { + val database = tableIdent.database.getOrElse(client.currentDatabase) + processDatabaseAndTableName(database, tableIdent.table) + } + + val tableProperties = new mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) // Saves optional user specified schema. Serialized JSON schema string may be too long to be // stored into a single metastore SerDe property. In this case, we split the JSON string and // store each part as a separate SerDe property. - if (userSpecifiedSchema.isDefined) { + userSpecifiedSchema.foreach { schema => val threshold = conf.schemaStringLengthThreshold - val schemaJsonString = userSpecifiedSchema.get.json + val schemaJsonString = schema.json // Split the JSON string. val parts = schemaJsonString.grouped(threshold).toSeq tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) @@ -174,9 +249,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. So, we are not expecting partition columns and we will discover // partitions when we load the table. However, if there are specified partition columns, - // we simplily ignore them and provide a warning message.. + // we simply ignore them and provide a warning message. logWarning( - s"The schema and partitions of table $tableName will be inferred when it is loaded. " + + s"The schema and partitions of table $tableIdent will be inferred when it is loaded. 
" + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } Seq.empty[HiveColumn] @@ -190,7 +265,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive ManagedTable } - client.createTable( + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) + val dataSource = ResolvedDataSource( + hive, userSpecifiedSchema, partitionColumns, provider, options) + + def newSparkSQLSpecificMetastoreTable(): HiveTable = { HiveTable( specifiedDatabase = Option(dbName), name = tblName, @@ -198,14 +277,114 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionColumns = metastorePartitionColumns, tableType = tableType, properties = tableProperties.toMap, - serdeProperties = options)) + serdeProperties = options) + } + + def newHiveCompatibleMetastoreTable(relation: HadoopFsRelation, serde: HiveSerDe): HiveTable = { + def schemaToHiveColumn(schema: StructType): Seq[HiveColumn] = { + schema.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + } + } + + val partitionColumns = schemaToHiveColumn(relation.partitionColumns) + val dataColumns = schemaToHiveColumn(relation.schema).filterNot(partitionColumns.contains) + + HiveTable( + specifiedDatabase = Option(dbName), + name = tblName, + schema = dataColumns, + partitionColumns = partitionColumns, + tableType = tableType, + properties = tableProperties.toMap, + serdeProperties = options, + location = Some(relation.paths.head), + viewText = None, // TODO We need to place the SQL string here. + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde) + } + + // TODO: Support persisting partitioned data source relations in Hive compatible format + val hiveTable = (maybeSerDe, dataSource.relation) match { + case (Some(serde), relation: HadoopFsRelation) + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + // Hive ParquetSerDe doesn't support decimal type until 1.2.0. + val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) + val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) + + val hiveParquetSupportsDecimal = client.version match { + case org.apache.spark.sql.hive.client.hive.v1_2 => true + case _ => false + } + + if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { + // If Hive version is below 1.2.0, we cannot save Hive compatible schema to + // metastore when the file format is Parquet and the schema has DecimalType. + logWarning { + "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + + "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + + s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." + } + newSparkSQLSpecificMetastoreTable() + } else { + logInfo { + "Persisting data source relation with a single input path into Hive metastore in " + + s"Hive compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) + } + + case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => + logWarning { + "Persisting partitioned data source relation into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. 
Input path(s): " + + relation.paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), relation: HadoopFsRelation) => + logWarning { + "Persisting data source relation with multiple input paths into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + + relation.paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), _) => + logWarning { + s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + + "Persisting it into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." + } + newSparkSQLSpecificMetastoreTable() + + case _ => + logWarning { + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + "Persisting data source relation into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." + } + newSparkSQLSpecificMetastoreTable() + } + + client.createTable(hiveTable) } def hiveDefaultTableFilePath(tableName: String): String = { + hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + } + + def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) + val database = tableIdent.database.getOrElse(client.currentDatabase) + new Path( - new Path(client.getDatabase(client.currentDatabase).location), - tableName.toLowerCase).toString + new Path(client.getDatabase(database).location), + tableIdent.table.toLowerCase).toString } def tableExists(tableIdentifier: Seq[String]): Boolean = { @@ -254,12 +433,12 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the - // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + // evil case insensitivity issue, which is reconciled within `ParquetRelation`. val parquetOptions = Map( - ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) @@ -270,14 +449,14 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. 
val useCached = parquetRelation.paths.toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[sources.Partition]) + PartitionSpec(StructType(Nil), Array.empty[datasources.Partition]) } if (useCached) { @@ -300,9 +479,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) - val partitions = metastoreRelation.hiveQlPartitions.map { p => + // We're converting the entire table into ParquetRelation, so predicates to Hive metastore + // are empty. + val partitions = metastoreRelation.getHiveQlPartitions().map { p => val location = p.getLocation - val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) ParquetPartition(values, location) @@ -313,7 +494,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new ParquetRelation2( + new ParquetRelation( paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created @@ -326,7 +507,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, None) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new ParquetRelation2(paths.toArray, None, None, parquetOptions)(hive)) + new ParquetRelation(paths.toArray, None, None, parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } @@ -366,12 +547,10 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. - * - * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right. */ object ParquetConversions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - if (!plan.resolved) { + if (!plan.resolved || plan.analyzed) { return plan } @@ -382,7 +561,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Inserting into partitioned table is not supported in Parquet data source (yet). if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) @@ -393,16 +571,13 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Inserting into partitioned table is not supported in Parquet data source (yet). 
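The partition handling just above uses a catalyst idiom worth spelling out: every raw metastore partition value arrives as a string and is coerced to the partition column's type by building Cast(Literal(raw), dataType) and evaluating it eagerly. Standalone, assuming the same catalyst classes this file already imports:

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.IntegerType

    // "2015" is what the metastore stores for an integer partition column like year=2015.
    val typedPartitionValue: Any = Cast(Literal("2015"), IntegerType).eval(null)   // Int 2015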
if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) (relation, parquetRelation, attributedRewrites) // Read path - case p @ PhysicalOperation(_, _, relation: MetastoreRelation) - if hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && + case relation: MetastoreRelation if hive.convertMetastoreParquet && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) @@ -447,7 +622,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p case p @ CreateTableAsSelect(table, child, allowExisting) => - val schema = if (table.schema.size > 0) { + val schema = if (table.schema.nonEmpty) { table.schema } else { child.output.map { @@ -470,7 +645,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( - desc.name, + TableIdentifier(desc.name), hive.conf.defaultDataSourceName, temporary = false, Array.empty[String], @@ -576,7 +751,7 @@ private[hive] case class InsertIntoHiveTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = Seq.empty val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -592,10 +767,8 @@ private[hive] case class InsertIntoHiveTable( private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) - (@transient sqlContext: SQLContext) - extends LeafNode with MultiInstanceRelation { - - self: Product => + (@transient private val sqlContext: SQLContext) + extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => @@ -625,48 +798,23 @@ private[hive] case class MetastoreRelation val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) tTable.setPartitionKeys( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.location.foreach(sd.setLocation) table.inputFormat.foreach(sd.setInputFormat) table.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) table.serde.foreach(serdeInfo.setSerializationLib) - val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - - new Table(tTable) - } - - @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p => - val tPartition = new org.apache.hadoop.hive.metastore.api.Partition - tPartition.setDbName(databaseName) - tPartition.setTableName(tableName) - 
tPartition.setValues(p.values) - - val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() - tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) - - sd.setLocation(p.storage.location) - sd.setInputFormat(p.storage.inputFormat) - sd.setOutputFormat(p.storage.outputFormat) - - val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) - serdeInfo.setSerializationLib(p.storage.serde) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) - new Partition(hiveQlTable, tPartition) + new Table(tTable) } @transient override lazy val statistics: Statistics = Statistics( @@ -689,6 +837,44 @@ private[hive] case class MetastoreRelation } ) + // When metastore partition pruning is turned off, we cache the list of all partitions to + // mimic the behavior of Spark < 1.5 + lazy val allPartitions = table.getAllPartitions + + def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { + val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) { + table.getPartitions(predicates) + } else { + allPartitions + } + + rawPartitions.map { p => + val tPartition = new org.apache.hadoop.hive.metastore.api.Partition + tPartition.setDbName(databaseName) + tPartition.setTableName(tableName) + tPartition.setValues(p.values.asJava) + + val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() + tPartition.setSd(sd) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) + + sd.setLocation(p.storage.location) + sd.setInputFormat(p.storage.inputFormat) + sd.setOutputFormat(p.storage.outputFormat) + + val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo + sd.setSerdeInfo(serdeInfo) + serdeInfo.setSerializationLib(p.storage.serde) + + val serdeParameters = new java.util.HashMap[String, String]() + serdeInfo.setParameters(serdeParameters) + table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + + new Partition(hiveQlTable, tPartition) + } + } + /** Only compare database and tablename, not alias. */ override def sameResult(plan: LogicalPlan): Boolean = { plan match { @@ -731,6 +917,18 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. 
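// Illustrative sketch only (not part of the patch): the decision made by
// getHiveQlPartitions above, reduced to plain Scala with hypothetical types.
// `PartitionStub`, `fetchFiltered` and `fetchAll` stand in for HivePartition
// and the ClientInterface calls; when metastore partition pruning is enabled
// the predicates are pushed to the metastore, otherwise the cached full
// listing (the pre-1.5 behaviour) is used.
object PartitionPruningSketch {
  case class PartitionStub(values: Seq[String], location: String)

  def selectPartitions(
      pruningEnabled: Boolean,                    // sqlContext.conf.metastorePartitionPruning
      fetchFiltered: () => Seq[PartitionStub],    // push predicates down to the metastore
      fetchAll: () => Seq[PartitionStub]          // cached full partition listing
  ): Seq[PartitionStub] = {
    if (pruningEnabled) fetchFiltered() else fetchAll()
  }

  def main(args: Array[String]): Unit = {
    val all = Seq(
      PartitionStub(Seq("2015-01-01"), "/warehouse/t/ds=2015-01-01"),
      PartitionStub(Seq("2015-01-02"), "/warehouse/t/ds=2015-01-02"))
    // With pruning on, only the metastore-filtered subset comes back.
    val pruned = selectPartitions(
      pruningEnabled = true,
      fetchFiltered = () => all.take(1),
      fetchAll = () => all)
    println(pruned.map(_.location).mkString(", "))
  }
}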
*/ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + override def inputFiles: Array[String] = { + val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + if (partLocations.nonEmpty) { + partLocations + } else { + Array( + table.location.getOrElse( + sys.error(s"Could not get the location of ${table.qualifiedName}."))) + } + } + + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ca4b80b51b23..d5cd7e98b526 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -18,8 +18,13 @@ package org.apache.spark.sql.hive import java.sql.Date +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.ql.{ErrorMsg, Context} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} @@ -28,7 +33,9 @@ import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -36,17 +43,14 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler -/* Implicit conversions */ -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer - /** * Used when we need to start parsing the AST before deciding that we are going to pass the command * back for Hive to execute natively. Will be replaced with a native command that contains the @@ -73,12 +77,13 @@ private[hive] case class CreateTableAsSelect( } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. 
*/ -private[hive] object HiveQl { +private[hive] object HiveQl extends Logging { protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", + "TOK_ALTERTABLE", "TOK_ALTERTABLE_ADDCOLS", "TOK_ALTERTABLE_ADDPARTS", "TOK_ALTERTABLE_ALTERPARTS", @@ -93,6 +98,7 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", + "TOK_ALTERVIEW", "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", @@ -186,7 +192,7 @@ private[hive] object HiveQl { .map(ast => Option(ast).map(_.transform(rule)).orNull)) } catch { case e: Exception => - println(dumpTree(n)) + logError(dumpTree(n).toString) throw e } } @@ -195,7 +201,7 @@ private[hive] object HiveQl { * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. */ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.toSeq).getOrElse(Nil) + Option(s).map(_.asScala).getOrElse(Nil) /** * Returns this ASTNode with the text changed to `newText`. @@ -210,7 +216,7 @@ private[hive] object HiveQl { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren) + n.addChildren(newChildren.asJava) n } @@ -247,7 +253,7 @@ private[hive] object HiveQl { * Otherwise, there will be Null pointer exception, * when retrieving properties form HiveConf. */ - val hContext = new Context(hiveConf) + val hContext = new Context(SessionState.get().getConf()) val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext)) hContext.clear() node @@ -256,8 +262,8 @@ private[hive] object HiveQl { /** * Returns the HiveConf */ - private[this] def hiveConf(): HiveConf = { - val ss = SessionState.get() // SessionState is lazy initializaion, it can be null here + private[this] def hiveConf: HiveConf = { + val ss = SessionState.get() // SessionState is lazy initialization, it can be null here if (ss == null) { new HiveConf() } else { @@ -316,11 +322,11 @@ private[hive] object HiveQl { assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") val tableOps = tree.getChildren val colList = - tableOps + tableOps.asScala .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") .getOrElse(sys.error("No columnList!")).getChildren - colList.map(nodeToAttribute) + colList.asScala.map(nodeToAttribute) } /** Extractor for matching Hive's AST Tokens. 
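// Illustrative sketch only: the JavaConversions -> JavaConverters migration
// that runs through HiveQl.scala and the rest of this patch, i.e. explicit
// .asScala / .asJava calls at each boundary instead of the implicit
// conversions pulled in by `import scala.collection.JavaConversions._`.
object ConvertersSketch {
  import scala.collection.JavaConverters._

  def main(args: Array[String]): Unit = {
    val javaList: java.util.List[String] = java.util.Arrays.asList("a", "b", "c")

    // Explicit, visible conversion at the call site (what the patch standardises on).
    val scalaSeq: Seq[String] = javaList.asScala
    println(scalaSeq.map(_.toUpperCase).mkString(","))

    // And back again when a Java API (e.g. the Hive metastore client) wants a java.util.List.
    val backToJava: java.util.List[String] = scalaSeq.map(_ * 2).asJava
    println(backToJava)
  }
}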
*/ @@ -330,7 +336,7 @@ private[hive] object HiveQl { case t: ASTNode => CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, - Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) case _ => None } } @@ -376,7 +382,7 @@ private[hive] object HiveQl { DecimalType(precision.getText.toInt, scale.getText.toInt) case Token("TOK_DECIMAL", precision :: Nil) => DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited + case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -415,16 +421,11 @@ private[hive] object HiveQl { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"_c$i")() - } - } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => (None, tableOnly) case Seq(databaseName, table) => (Some(databaseName), table) } @@ -433,7 +434,9 @@ private[hive] object HiveQl { } protected def extractTableIdent(tableNameParts: Node): Seq[String] = { - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -583,12 +586,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TABLESKEWED", // Skewed by "TOK_TABLEROWFORMAT", "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive. - "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile - "TOK_TBLTEXTFILE", // Stored as TextFile - "TOK_TBLRCFILE", // Stored as RCFile - "TOK_TBLORCFILE", // Stored as ORC File - "TOK_TBLPARQUETFILE", // Stored as PARQUET + "TOK_FILEFORMAT_GENERIC", "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat "TOK_STORAGEHANDLER", // Storage handler "TOK_TABLELOCATION", @@ -611,45 +609,25 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C serde = None, viewText = None) - // default storage type abbriviation (e.g. RCFile, ORC, PARQUET etc.) + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) 
val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbriviation - tableDesc = if ("SequenceFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - } else if ("RCFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), - serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))) - } else if ("ORC".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } else if ("PARQUET".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), - serde = - Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } else { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + } + + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) children.collect { case list @ Token("TOK_TABCOLLIST", _) => val cols = BaseSemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( - schema = cols.map { field => + schema = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -661,7 +639,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val cols = BaseSemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( - partitionColumns = cols.map { field => + partitionColumns = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -697,7 +675,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case _ => assert(false) } tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams) + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) location = EximUtil.relativeToAbsolutePath(hiveConf, location) @@ -709,39 +687,66 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() BaseSemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy(serdeProperties = 
tableDesc.serdeProperties ++ serdeParams) + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) } case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - throw new SemanticException( - "Unrecognized file format in STORED AS clause:${child.getText}") + child.getText().toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } - case Token("TOK_TBLRCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } - case Token("TOK_TBLORCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } - case Token("TOK_TBLPARQUETFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + case "avro" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + } + + case _ => + throw new SemanticException( + s"Unrecognized file format in STORED AS clause: ${child.getText}") } case Token("TOK_TABLESERIALIZER", @@ -757,7 +762,7 @@ 
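// Illustrative sketch only: the TOK_FILEFORMAT_GENERIC handling above
// expressed as a lookup table. `FormatSpec` is a hypothetical stand-in for
// the (inputFormat, outputFormat, serde) triple the real code copies into
// tableDesc; only a few of the formats are reproduced here.
object StoredAsSketch {
  case class FormatSpec(inputFormat: String, outputFormat: String, serde: Option[String] = None)

  private val formats: Map[String, FormatSpec] = Map(
    "orc" -> FormatSpec(
      "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat",
      "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat",
      Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")),
    "parquet" -> FormatSpec(
      "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
      "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
      Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")),
    "textfile" -> FormatSpec(
      "org.apache.hadoop.mapred.TextInputFormat",
      "org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat"))

  // Unknown formats fail fast, mirroring the SemanticException thrown by the patch.
  def resolve(storedAs: String): FormatSpec =
    formats.getOrElse(
      storedAs.toLowerCase(java.util.Locale.ENGLISH),
      throw new IllegalArgumentException(s"Unrecognized file format in STORED AS clause: $storedAs"))

  def main(args: Array[String]): Unit = {
    println(resolve("ORC"))
    println(resolve("TextFile"))
  }
}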
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLEPROPERTIES", list :: Nil) => tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", _) => + case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), @@ -846,7 +851,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.toSeq + val Seq(whereExpr) = whereNode.getChildren.asScala Filter(nodeToExpr(whereExpr), relations) }.getOrElse(relations) @@ -855,7 +860,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Script transformations are expressed as a select clause with a single expression of type // TOK_TRANSFORM - val transformation = select.getChildren.head match { + val transformation = select.getChildren.iterator().next() match { case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: @@ -880,26 +885,27 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { + : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, "", Nil) + (rowFormat, None, Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, serdeClass, Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (name, value) + (BaseSemanticAnalyzer.unescapeSQLString(name), + BaseSemanticAnalyzer.unescapeSQLString(value)) } - (Nil, serdeClass, serdeProps) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, "", Nil) + case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) @@ -923,10 +929,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withLateralView = lateralViewClause.map { lv => val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head + Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - val alias = - getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -942,7 +948,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. 
val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq) + select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => @@ -971,7 +977,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Handle HAVING clause. val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.toSeq match { case Seq(hexpr) => nodeToExpr(hexpr) } + val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } // Note that we added a cast to boolean. If the expression itself is already boolean, // the optimizer will get rid of the unnecessary cast. Filter(Cast(havingExpr, BooleanType), withProject) @@ -981,32 +987,42 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withDistinct = if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - // Handle ORDER BY, SORT BY, DISTRIBETU BY, and CLUSTER BY clause. + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withDistinct) + Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) case (None, Some(perPartitionOrdering), None, None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withDistinct) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), + false, withDistinct) case (None, None, Some(partitionExprs), None) => - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, Some(clusterExprs)) => - Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), + false, + RepartitionByExpression( + clusterExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, None) => withDistinct case _ => sys.error("Unsupported set of ordering / distribution clauses.") } val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.head)) + limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) .map(Limit(_, withSort)) .getOrElse(withSort) // Collect all window specifications defined in the WINDOW clause. 
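// Illustrative sketch only: the ORDER BY / SORT BY / DISTRIBUTE BY / CLUSTER BY
// dispatch rewritten above, reduced to a decision table over which clauses are
// present. The strings stand in for the Sort / RepartitionByExpression plans
// the real code builds; CLUSTER BY behaves like DISTRIBUTE BY plus an
// ascending per-partition sort.
object OrderingClausesSketch {
  def plan(orderBy: Boolean, sortBy: Boolean, distributeBy: Boolean, clusterBy: Boolean): String =
    (orderBy, sortBy, distributeBy, clusterBy) match {
      case (true, false, false, false)  => "Sort(global = true)"
      case (false, true, false, false)  => "Sort(global = false)"
      case (false, false, true, false)  => "RepartitionByExpression"
      case (false, true, true, false)   => "Sort(global = false) over RepartitionByExpression"
      case (false, false, false, true)  => "Sort(global = false, asc) over RepartitionByExpression"
      case (false, false, false, false) => "child plan unchanged"
      case _ => sys.error("Unsupported set of ordering / distribution clauses.")
    }

  def main(args: Array[String]): Unit = {
    println(plan(orderBy = true, sortBy = false, distributeBy = false, clusterBy = false))
    println(plan(orderBy = false, sortBy = false, distributeBy = false, clusterBy = true))
  }
}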
- val windowDefinitions = windowClause.map(_.getChildren.toSeq.collect { + val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { case Token("TOK_WINDOWDEF", Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => windowName -> nodesToWindowSpecification(spec) @@ -1043,10 +1059,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) - case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT + case Token("TOK_UNIONALL", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") } val allJoinTokens = "(TOK_.*JOIN)".r @@ -1060,7 +1077,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -1089,7 +1107,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val tableIdent = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -1136,7 +1156,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) - val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr)) + val joinExpressions = + tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => @@ -1161,7 +1182,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C joinType = joinType.remove(joinType.length - 1)) } - val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i)))) + val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) // Unique join is not really the same as an outer join so we must group together results where // the joinExpressions are the same, taking the First of each value is only okay because the @@ -1226,7 +1247,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. 
case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1246,7 +1267,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1257,7 +1278,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName}:" + + s"\n ${dumpTree(a).toString} ") } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { @@ -1280,7 +1302,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_HINTLIST", _) => None case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName }:" + + s"\n ${dumpTree(a).toString } ") } protected val escapedIdentifier = "`([^`]+)`".r @@ -1327,11 +1350,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* Attribute References */ case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => - UnresolvedAttribute(cleanIdentifier(name)) + UnresolvedAttribute.quoted(cleanIdentifier(name)) case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { - case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) case other => UnresolvedExtractValue(other, Literal(attr)) } @@ -1375,7 +1398,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.Unlimited) + Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => @@ -1470,9 +1493,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. 
+ case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) /* Literals */ case Token("TOK_NULL", Nil) => Literal.create(null, NullType) @@ -1523,6 +1549,30 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : @@ -1558,18 +1608,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = getClauses( Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.toSeq.asInstanceOf[Seq[ASTNode]]) + partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.map(nodeToExpr), - orderByExpr.getChildren.map(nodeToSortOrder)) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), + orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.map(nodeToExpr), Nil) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.map(nodeToSortOrder)) + (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.map(nodeToExpr) + val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) (expressions, expressions.map(SortOrder(_, Ascending))) case _ => throw new NotImplementedError( @@ -1607,7 +1657,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.toList match { + frame.getChildren.asScala.toList match { 
case precedingNode :: followingNode :: Nil => SpecifiedWindowFrame( frameType, @@ -1645,7 +1695,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - (HiveGenericUdtf( + (HiveGenericUDTF( new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)), attributes) @@ -1669,7 +1719,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case other => sys.error(s"Non ASTNode encountered: $other") } - Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) + Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) builder } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index d08c59415165..004805f3aed0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,19 +20,21 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import org.apache.avro.Schema + +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} import org.apache.hadoop.hive.serde2.ColumnProjectionUtils -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector import org.apache.hadoop.io.Writable @@ -70,7 +72,7 @@ private[hive] object HiveShim { */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { if (ids != null && ids.nonEmpty) { - ColumnProjectionUtils.appendReadColumns(conf, ids) + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) } if (names != null && names.nonEmpty) { appendReadColumnNames(conf, names) @@ -81,10 +83,19 @@ private[hive] object HiveShim { * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that * is needed to initialize before serialization. */ - def prepareWritable(w: Writable): Writable = { + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { w match { case w: AvroGenericRecordWritable => w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. 
+ if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } case _ => } w diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 452b7f0bcc74..d38ad9127327 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,22 +17,14 @@ package org.apache.spark.sql.hive -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -import org.apache.spark.sql.catalyst.expressions.{InternalRow, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} -import org.apache.spark.sql.types.StringType private[hive] trait HiveStrategies { @@ -41,136 +33,6 @@ private[hive] trait HiveStrategies { val hiveContext: HiveContext - /** - * :: Experimental :: - * Finds table scans that would use the Hive SerDe and replaces them with our own native parquet - * table scan operator. - * - * TODO: Much of this logic is duplicated in HiveTableScan. Ideally we would do some refactoring - * but since this is after the code freeze for 1.1 all logic is here to minimize disruption. - * - * Other issues: - * - Much of this logic assumes case insensitive resolution. - */ - @Experimental - object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: DataFrame) { - def lowerCase: DataFrame = DataFrame(s.sqlContext, s.logicalPlan) - - def addPartitioningAttributes(attrs: Seq[Attribute]): DataFrame = { - // Don't add the partitioning key if its already present in the data. 
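// Illustrative sketch only, referring back to the prepareWritable change in
// HiveShim.scala above: look up the Avro schema literal in the SerDe
// properties and parse it with Avro's Schema.Parser, which is what guards
// against the Hive 1.1 NPE when AvroGenericRecordWritable has no file schema
// yet. Assumes org.apache.avro:avro is on the classpath; the property key is
// spelled out literally here instead of going through AvroSerdeUtils.
object AvroSchemaSketch {
  import org.apache.avro.Schema

  def fileSchemaFromSerdeProps(serDeProps: Seq[(String, String)]): Option[Schema] =
    serDeProps
      .find(_._1 == "avro.schema.literal")
      .map { case (_, literal) => new Schema.Parser().parse(literal) }

  def main(args: Array[String]): Unit = {
    val props = Seq("avro.schema.literal" ->
      """{"type":"record","name":"Rec","fields":[{"name":"id","type":"long"}]}""")
    println(fileSchemaFromSerdeProps(props).map(_.getFullName)) // Some(Rec)
  }
}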
- if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { - s - } else { - DataFrame( - s.sqlContext, - s.logicalPlan transform { - case p: ParquetRelation => p.copy(partitioningAttributes = attrs) - }) - } - } - } - - implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { - def fakeOutput(newOutput: Seq[Attribute]): OutputFaker = - OutputFaker( - originalPlan.output.map(a => - newOutput.find(a.name.toLowerCase == _.name.toLowerCase) - .getOrElse( - sys.error(s"Can't find attribute $a to fake in set ${newOutput.mkString(",")}"))), - originalPlan) - } - - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) - if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet && - !hiveContext.conf.parquetUseDataSourceApi => - - // Filter out all predicates that only deal with partition keys - val partitionsKeys = AttributeSet(relation.partitionKeys) - val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.subsetOf(partitionsKeys) - } - - // We are going to throw the predicates and projection back at the whole optimization - // sequence so lets unresolve all the attributes, allowing them to be rebound to the - // matching parquet attributes. - val unresolvedOtherPredicates = Column(otherPredicates.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true))) - - val unresolvedProjection: Seq[Column] = projectList.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).map(Column(_)) - - try { - if (relation.hiveQlTable.isPartitioned) { - val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true)) - // Translate the predicate so that it automatically casts the input values to the - // correct data types during evaluation. - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId) - val key = relation.partitionKeys(idx) - Cast(BoundReference(idx, StringType, nullable = true), key.dataType) - } - - val inputData = new GenericMutableRow(relation.partitionKeys.size) - val pruningCondition = - if (codegenEnabled) { - GeneratePredicate.generate(castedPredicate) - } else { - InterpretedPredicate.create(castedPredicate) - } - - val partitions = relation.hiveQlPartitions.filter { part => - val partitionValues = part.getValues - var i = 0 - while (i < partitionValues.size()) { - inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) - i += 1 - } - pruningCondition(inputData) - } - - val partitionLocations = partitions.map(_.getLocation) - - if (partitionLocations.isEmpty) { - PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil - } else { - hiveContext - .read.parquet(partitionLocations: _*) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } - - } else { - hiveContext - .read.parquet(relation.hiveQlTable.getDataLocation.toString) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } - } catch { - // parquetFile will throw an exception when there is no data. - // TODO: Remove this hack for Spark 1.3. 
- case iae: java.lang.IllegalArgumentException - if iae.getMessage.contains("Can not create a Path from an empty string") => - PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil - } - case _ => Nil - } - } - object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) => @@ -212,7 +74,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil case _ => Nil } @@ -221,14 +83,16 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateTableUsing( - tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => - ExecutedCommand( + tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => + val cmd = CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) + ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, partitionCols, mode, opts, query) => + case CreateTableUsingAsSelect( + tableIdent, provider, false, partitionCols, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, partitionCols, mode, opts, query) + CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 439f39bafc92..e35468a624c3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -29,11 +29,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.{Logging} +import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -52,10 +54,10 @@ private[hive] sealed trait TableReader { */ private[hive] class HadoopTableReader( - @transient attributes: Seq[Attribute], - @transient relation: MetastoreRelation, - @transient sc: HiveContext, - @transient hiveExtraConf: HiveConf) + @transient private val attributes: Seq[Attribute], + @transient private val relation: MetastoreRelation, + @transient private val sc: HiveContext, + hiveExtraConf: HiveConf) extends TableReader with Logging { // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". 
@@ -76,9 +78,7 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, - Class.forName( - relation.tableDesc.getSerdeClassName, true, Utils.getContextOrSparkClassLoader) - .asInstanceOf[Class[Deserializer]], + Utils.classForName(relation.tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -356,16 +356,16 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setLong(ordinal, DateUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) + row.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setInt(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) + row.setInt(ordinal, DateTimeUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 0a1d761a52f8..3811c152a7ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -21,6 +21,7 @@ import java.io.PrintStream import java.util.{Map => JMap} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} +import org.apache.spark.sql.catalyst.expressions.Expression private[hive] case class HiveDatabase( name: String, @@ -73,6 +74,9 @@ private[hive] case class HiveTable( def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) + def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = + client.getPartitionsByFilter(this, predicates) + // Hive does not support backticks when passing names to the client. def qualifiedName: String = s"$database.$name" } @@ -83,6 +87,13 @@ private[hive] case class HiveTable( * shared classes. */ private[hive] trait ClientInterface { + + /** Returns the Hive Version of this client. */ + def version: HiveVersion + + /** Returns the configuration for the given key in the current session. */ + def getConf(key: String, defaultValue: String): String + /** * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will * result in one string. @@ -132,6 +143,9 @@ private[hive] trait ClientInterface { /** Returns all partitions for the given table. */ def getAllPartitions(hTable: HiveTable): Seq[HivePartition] + /** Returns partitions filtered by predicates for the given table. */ + def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] + /** Loads a static partition into an existing table. 
*/ def loadPartition( loadPath: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 42c2d4c98ffb..4d1e3ed9198e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -17,28 +17,28 @@ package org.apache.spark.sql.hive.client -import java.io.{BufferedReader, InputStreamReader, File, PrintStream} -import java.net.URI -import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} +import java.io.{File, PrintStream} +import java.util.{Map => JMap} +import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.metastore.api.Database import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} import org.apache.hadoop.hive.metastore.{TableType => HTableType} -import org.apache.hadoop.hive.metastore.api -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.ql.metadata import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.{Driver, metadata} +import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} +import org.apache.hadoop.util.VersionInfo import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException - +import org.apache.spark.util.{CircularBuffer, Utils} /** * A class that wraps the HiveClient and converts its responses to externally visible classes. @@ -58,44 +58,90 @@ import org.apache.spark.sql.execution.QueryExecutionException * this ClientWrapper. */ private[hive] class ClientWrapper( - version: HiveVersion, + override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader) extends ClientInterface with Logging { - // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - private val outputBuffer = new java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](10240) - def write(i: Int): Unit = { - buffer(pos) = i - pos = (pos + 1) % buffer.size + overrideHadoopShims() + + // !! HACK ALERT !! + // + // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking + // major version number gathered from Hadoop jar files: + // + // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with + // security. + // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. + // + // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It + // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but + // `Hadoop23Shims` is chosen because the major version number here is 2. + // + // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` + // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. 
If Hadoop version information is + // not available, we decide whether to override the shims or not by checking for existence of a + // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. + private def overrideHadoopShims(): Unit = { + val hadoopVersion = VersionInfo.getVersion + val VersionPattern = """(\d+)\.(\d+).*""".r + + hadoopVersion match { + case null => + logError("Failed to inspect Hadoop version") + + // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. + val probeMethod = "getPathWithoutSchemeAndAuthority" + if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { + logInfo( + s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + + s"we are probably using Hadoop 1.x or 2.0.x") + loadHadoop20SShims() + } + + case VersionPattern(majorVersion, minorVersion) => + logInfo(s"Inspected Hadoop version: $hadoopVersion") + + // Loads Hadoop20SShims for 1.x and 2.0.x versions + val (major, minor) = (majorVersion.toInt, minorVersion.toInt) + if (major < 2 || (major == 2 && minor == 0)) { + loadHadoop20SShims() + } } - override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator + // Logs the actual loaded Hadoop shims class + val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName + logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") + } - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while(line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() - } - stringBuilder.toString() + private def loadHadoop20SShims(): Unit = { + val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" + logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") + + try { + val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") + // scalastyle:off classforname + val shimsClass = Class.forName(hadoop20SShimsClassName) + // scalastyle:on classforname + val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) + shimsField.setAccessible(true) + shimsField.set(null, shims) + } catch { case cause: Throwable => + throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) } } + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. + private val outputBuffer = new CircularBuffer() + private val shim = version match { case hive.v12 => new Shim_v0_12() case hive.v13 => new Shim_v0_13() case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() } // Create an internal session state for this ClientWrapper. @@ -114,7 +160,11 @@ private[hive] class ClientWrapper( // this action explicit. initialConf.setClassLoader(initClassLoader) config.foreach { case (k, v) => - logDebug(s"Hive Config: $k=$v") + if (k.toLowerCase.contains("password")) { + logDebug(s"Hive Config: $k=xxx") + } else { + logDebug(s"Hive Config: $k=$v") + } initialConf.set(k, v) } val newState = new SessionState(initialConf) @@ -134,14 +184,68 @@ private[hive] class ClientWrapper( /** Returns the configuration for the current session. 
*/ def conf: HiveConf = SessionState.get().getConf + override def getConf(key: String, defaultValue: String): String = { + conf.get(key, defaultValue) + } + // TODO: should be a def?s // When we create this val client, the HiveConf of it (conf) is the one associated with state. - private val client = Hive.get(conf) + @GuardedBy("this") + private var client = Hive.get(conf) + + // We use hive's conf for compatibility. + private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) + private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) + + /** + * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. + */ + private def retryLocked[A](f: => A): A = synchronized { + // Hive sometimes retries internally, so set a deadline to avoid compounding delays. + val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong + var numTries = 0 + var caughtException: Exception = null + do { + numTries += 1 + try { + return f + } catch { + case e: Exception if causedByThrift(e) => + caughtException = e + logWarning( + "HiveClientWrapper got thrift exception, destroying client and retrying " + + s"(${retryLimit - numTries} tries remaining)", e) + Thread.sleep(retryDelayMillis) + try { + client = Hive.get(state.getConf, true) + } catch { + case e: Exception if causedByThrift(e) => + logWarning("Failed to refresh hive client, will retry.", e) + } + } + } while (numTries <= retryLimit && System.nanoTime < deadline) + if (System.nanoTime > deadline) { + logWarning("Deadline exceeded") + } + throw caughtException + } + + private def causedByThrift(e: Throwable): Boolean = { + var target = e + while (target != null) { + val msg = target.getMessage() + if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { + return true + } + target = target.getCause() + } + false + } /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ - private def withHiveState[A](f: => A): A = synchronized { + private def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader // Set the thread local metastore client to the client associated with this ClientWrapper. 
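
The retry logic introduced here is worth spelling out: `retryLocked` retries `f` only when the failure looks Thrift-related (detected by walking the exception's cause chain for Thrift exception names in the messages), sleeps between attempts, and also bounds the total time with a deadline because Hive may already retry internally. Below is a minimal standalone sketch of that pattern using only the standard library; the object name, limits, and helper signatures are illustrative, not the actual Spark API.

```scala
import scala.annotation.tailrec

object MetastoreRetrySketch {
  // Walks the cause chain looking for Thrift-style exception messages,
  // mirroring the intent of causedByThrift() in the diff above.
  @tailrec
  def causedByThrift(t: Throwable): Boolean =
    if (t == null) {
      false
    } else {
      val msg = Option(t.getMessage).getOrElse("")
      if (msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) true
      else causedByThrift(t.getCause)
    }

  // Retries `f` with a fixed delay, giving up once the retry budget or the
  // overall deadline is exhausted; non-Thrift failures propagate immediately.
  def retryLocked[A](retryLimit: Int, retryDelayMillis: Long)(f: => A): A = synchronized {
    val deadline = System.nanoTime + retryLimit * retryDelayMillis * 1000000L
    var numTries = 0
    var caught: Exception = null
    do {
      numTries += 1
      try {
        return f
      } catch {
        case e: Exception if causedByThrift(e) =>
          caught = e
          Thread.sleep(retryDelayMillis)
      }
    } while (numTries <= retryLimit && System.nanoTime < deadline)
    throw caught
  }
}
```
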
Hive.set(client) @@ -201,10 +305,11 @@ private[hive] class ClientWrapper( HiveTable( name = h.getTableName, specifiedDatabase = Option(h.getDbName), - schema = h.getCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - partitionColumns = h.getPartCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - properties = h.getParameters.toMap, - serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap, + schema = h.getCols.asScala.map(f => HiveColumn(f.getName, f.getType, f.getComment)), + partitionColumns = h.getPartCols.asScala.map(f => + HiveColumn(f.getName, f.getType, f.getComment)), + properties = h.getParameters.asScala.toMap, + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap, tableType = h.getTableType match { case HTableType.MANAGED_TABLE => ManagedTable case HTableType.EXTERNAL_TABLE => ExternalTable @@ -221,18 +326,18 @@ private[hive] class ClientWrapper( } private def toInputFormat(name: String) = - Class.forName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] + Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] private def toOutputFormat(name: String) = - Class.forName(name) + Utils.classForName(name) .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] private def toQlTable(table: HiveTable): metadata.Table = { val qlTable = new metadata.Table(table.database, table.name) - qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) qlTable.setPartCols( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } @@ -262,13 +367,13 @@ private[hive] class ClientWrapper( private def toHivePartition(partition: metadata.Partition): HivePartition = { val apiPartition = partition.getTPartition HivePartition( - values = Option(apiPartition.getValues).map(_.toSeq).getOrElse(Seq.empty), + values = Option(apiPartition.getValues).map(_.asScala).getOrElse(Seq.empty), storage = HiveStorageDescriptor( location = apiPartition.getSd.getLocation, inputFormat = apiPartition.getSd.getInputFormat, outputFormat = apiPartition.getSd.getOutputFormat, serde = apiPartition.getSd.getSerdeInfo.getSerializationLib, - serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.toMap)) + serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) } override def getPartitionOption( @@ -285,8 +390,15 @@ private[hive] class ClientWrapper( shim.getAllPartitions(client, qlTable).map(toHivePartition) } + override def getPartitionsByFilter( + hTable: HiveTable, + predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { + val qlTable = toQlTable(hTable) + shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) + } + override def listTables(dbName: String): Seq[String] = withHiveState { - client.getAllTables(dbName) + client.getAllTables(dbName).asScala } /** @@ -329,7 +441,9 @@ private[hive] class ClientWrapper( case _ => if (state.out != null) { + // scalastyle:off println state.out.println(tokens(0) + " " + cmd_1) + // scalastyle:on println } Seq(proc.run(cmd_1).getResponseCode.toString) } @@ -401,17 +515,17 @@ private[hive] class 
ClientWrapper( } def reset(): Unit = withHiveState { - client.getAllTables("default").foreach { t => + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) - client.getIndexes("default", t, 255).foreach { index => - client.dropIndex("default", t, index.getIndexName, true) + client.getIndexes("default", t, 255).asScala.foreach { index => + shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) } } - client.getAllDatabases.filterNot(_ == "default").foreach { db => + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 5ae2dbb50d86..48bbb21e6c1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -17,19 +17,25 @@ package org.apache.spark.sql.hive.client -import java.lang.{Boolean => JBoolean, Integer => JInteger} +import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{Method, Modifier} import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, IntegralType} /** * A shim that defines the interface between ClientWrapper and the underlying Hive library used to @@ -60,10 +66,14 @@ private[client] sealed abstract class Shim { def getAllPartitions(hive: Hive, table: Table): Seq[Partition] + def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition] + def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor def getDriverResults(driver: Driver): Seq[String] + def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + def loadPartition( hive: Hive, loadPath: Path, @@ -91,6 +101,8 @@ private[client] sealed abstract class Shim { holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit + def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit + protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = { val method = findMethod(klass, name, args: _*) require(Modifier.isStatic(method.getModifiers()), @@ -104,7 +116,7 @@ private[client] sealed abstract class Shim { } -private[client] class Shim_v0_12 extends Shim { +private[client] class Shim_v0_12 extends Shim with Logging { private lazy val startMethod = findStaticMethod( @@ -163,6 +175,14 @@ private[client] class Shim_v0_12 extends Shim { JInteger.TYPE, JBoolean.TYPE, JBoolean.TYPE) + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + 
"dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE) override def setCurrentSessionState(state: SessionState): Unit = { // Starting from Hive 0.13, setCurrentSessionState will internally override @@ -181,7 +201,18 @@ private[client] class Shim_v0_12 extends Shim { setDataLocationMethod.invoke(table, new URI(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq + + override def getPartitionsByFilter( + hive: Hive, + table: Table, + predicates: Seq[Expression]): Seq[Partition] = { + // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12. + // See HIVE-4888. + logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " + + "Please use Hive 0.13 or higher.") + getAllPartitions(hive, table) + } override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] @@ -189,7 +220,11 @@ private[client] class Shim_v0_12 extends Shim { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[String]() getDriverResultsMethod.invoke(driver, res) - res.toSeq + res.asScala + } + + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 } override def loadPartition( @@ -227,6 +262,10 @@ private[client] class Shim_v0_12 extends Shim { numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) } + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + } + } private[client] class Shim_v0_13 extends Shim_v0_12 { @@ -246,6 +285,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { classOf[Hive], "getAllPartitionsOf", classOf[Table]) + private lazy val getPartitionsByFilterMethod = + findMethod( + classOf[Hive], + "getPartitionsByFilter", + classOf[Table], + classOf[String]) private lazy val getCommandProcessorMethod = findStaticMethod( classOf[CommandProcessorFactory], @@ -265,7 +310,52 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { setDataLocationMethod.invoke(table, new Path(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq + + /** + * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. + * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...". + * + * Unsupported predicates are skipped. + */ + def convertFilters(table: Table, filters: Seq[Expression]): String = { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
+ val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + filters.collect { + case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + s"${a.name} ${op.symbol} $v" + case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + s"$v ${op.symbol} ${a.name}" + case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + if !varcharKeys.contains(a.name) => + s"""${a.name} ${op.symbol} "$v"""" + case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + if !varcharKeys.contains(a.name) => + s""""$v" ${op.symbol} ${a.name}""" + }.mkString(" and ") + } + + override def getPartitionsByFilter( + hive: Hive, + table: Table, + predicates: Seq[Expression]): Seq[Partition] = { + + // Hive getPartitionsByFilter() takes a string that represents partition + // predicates like "str_key=\"value\" and int_key=1 ..." + val filter = convertFilters(table, predicates) + val partitions = + if (filter.isEmpty) { + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] + } else { + logDebug(s"Hive metastore filter is '$filter'.") + getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] + } + + partitions.asScala.toSeq + } override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] @@ -273,7 +363,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[Object]() getDriverResultsMethod.invoke(driver, res) - res.map { r => + res.asScala.map { r => r match { case s: String => s case a: Array[Object] => a(0).asInstanceOf[String] @@ -321,6 +411,12 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { JBoolean.TYPE, JBoolean.TYPE, JBoolean.TYPE) + private lazy val getTimeVarMethod = + findMethod( + classOf[HiveConf], + "getTimeVar", + classOf[HiveConf.ConfVars], + classOf[TimeUnit]) override def loadPartition( hive: Hive, @@ -333,7 +429,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - JBoolean.TRUE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE) } override def loadTable( @@ -343,7 +439,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { replace: Boolean, holdDDLTime: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, - JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -359,4 +455,71 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) } + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + getTimeVarMethod.invoke( + conf, + HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, + TimeUnit.MILLISECONDS).asInstanceOf[Long] + } + + protected def isSrcLocal(path: Path, conf: HiveConf): Boolean = { + val localFs = FileSystem.getLocal(conf) + val pathFs = FileSystem.get(path.toUri(), conf) + localFs.getUri() == pathFs.getUri() + } + +} + +private[client] class 
Shim_v1_0 extends Shim_v0_14 { + +} + +private[client] class Shim_v1_1 extends Shim_v1_0 { + + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + } + +} + +private[client] class Shim_v1_2 extends Shim_v1_1 { + + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + 0L: JLong) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 0934ad503467..1fe4cba9571f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -41,19 +41,31 @@ private[hive] object IsolatedClientLoader { */ def forVersion( version: String, - config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized { + config: Map[String, String] = Map.empty, + ivyPath: Option[String] = None, + sharedPrefixes: Seq[String] = Seq.empty, + barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion)) - new IsolatedClientLoader(hiveVersion(version), files, config) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, + downloadVersion(resolvedVersion, ivyPath)) + new IsolatedClientLoader( + version = hiveVersion(version), + execJars = files, + config = config, + sharedPrefixes = sharedPrefixes, + barrierPrefixes = barrierPrefixes) } def hiveVersion(version: String): HiveVersion = version match { case "12" | "0.12" | "0.12.0" => hive.v12 case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 case "14" | "0.14" | "0.14.0" => hive.v14 + case "1.0" | "1.0.0" => hive.v1_0 + case "1.1" | "1.1.0" => hive.v1_1 + case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } - private def downloadVersion(version: HiveVersion): Seq[URL] = { + private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ @@ -64,7 +76,7 @@ private[hive] object IsolatedClientLoader { SparkSubmitUtils.resolveMavenCoordinates( hiveArtifacts.mkString(","), Some("http://www.datanucleus.org/downloads/maven2"), - None, + ivyPath, exclusions = version.exclusions) } val allFiles = classpath.split(",").map(new File(_)).toSet @@ -72,7 +84,7 @@ private[hive] object IsolatedClientLoader { // TODO: Remove copy logic. 
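
The `convertFilters` helper above translates simple Catalyst comparisons into the filter string Hive's `getPartitionsByFilter()` accepts, skipping varchar partition keys that Hive cannot push down. A simplified standalone model of that translation is sketched below; the `Comparison` type is a toy stand-in for Catalyst expressions and is not part of the real code.

```scala
object MetastoreFilterSketch {
  // A toy stand-in for the Catalyst comparisons convertFilters() handles:
  // a partition column compared against an integral or string literal.
  sealed trait Filter
  case class Comparison(column: String, symbol: String, value: Any) extends Filter

  // Builds the filter string Hive's getPartitionsByFilter() expects, e.g.
  // ds = "2015-07-01" and hr > 10. Varchar partition keys are skipped, as in
  // the diff above, because Hive cannot evaluate pushed-down comparisons on them.
  def convertFilters(filters: Seq[Filter], varcharKeys: Set[String]): String =
    filters.collect {
      case Comparison(col, sym, v: Int)  => s"$col $sym $v"
      case Comparison(col, sym, v: Long) => s"$col $sym $v"
      case Comparison(col, sym, v: String) if !varcharKeys.contains(col) =>
        s"""$col $sym "$v""""
    }.mkString(" and ")

  def main(args: Array[String]): Unit = {
    val filter = convertFilters(
      Seq(Comparison("ds", "=", "2015-07-01"), Comparison("hr", ">", 10)),
      varcharKeys = Set.empty)
    println(filter) // ds = "2015-07-01" and hr > 10
  }
}
```
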
val tempDir = Utils.createTempDir(namePrefix = s"hive-${version}") allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir)) - tempDir.listFiles().map(_.toURL) + tempDir.listFiles().map(_.toURI.toURL) } private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[URL]] @@ -119,8 +131,9 @@ private[hive] class IsolatedClientLoader( name.contains("slf4j") || name.contains("log4j") || name.startsWith("org.apache.spark.") || + (name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")) || name.startsWith("scala.") || - name.startsWith("com.google") || + (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 27a3d8f5896c..b1b8439efa01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -25,20 +25,43 @@ package object client { val exclusions: Seq[String] = Nil) // scalastyle:off - private[client] object hive { + private[hive] object hive { case object v12 extends HiveVersion("0.12.0") case object v13 extends HiveVersion("0.13.1") // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT which does not exist in // maven central anymore, so override those with a version that exists. // - // org.pentaho:pentaho-aggdesigner-algorithm is also nowhere to be found, so exclude - // it explicitly. If it's needed by the metastore client, users will have to dig it - // out of somewhere and use configuration to point Spark at the correct jars. + // The other excluded dependencies are also nowhere to be found, so exclude them explicitly. If + // they're needed by the metastore client, users will have to dig them out of somewhere and use + // configuration to point Spark at the correct jars. case object v14 extends HiveVersion("0.14.0", - Seq("org.apache.calcite:calcite-core:1.3.0-incubating", + extraDeps = Seq("org.apache.calcite:calcite-core:1.3.0-incubating", "org.apache.calcite:calcite-avatica:1.3.0-incubating"), - Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + exclusions = Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v1_0 extends HiveVersion("1.0.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + // The curator dependency was added to the exclusions here because it seems to confuse the ivy + // library. org.apache.curator:curator is a pom dependency but ivy tries to find the jar for it, + // and fails. 
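
The updated isolation predicate shown above decides, per class name, whether the isolated Hive classloader delegates to the application classloader (Spark, Hadoop minus Hive, Scala, logging, plus configurable prefixes) or loads its own copy. A standalone sketch of that decision function, following the prefixes in the diff; the object and method names here are illustrative.

```scala
object SharedClassSketch {
  // Decides whether a class should be loaded by the application classloader
  // (shared with Spark) instead of the isolated Hive classloader; the prefix
  // list follows the diff above and is illustrative.
  def isSharedClass(name: String, sharedPrefixes: Seq[String] = Nil): Boolean =
    name.contains("slf4j") ||
      name.contains("log4j") ||
      name.startsWith("org.apache.spark.") ||
      (name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")) ||
      name.startsWith("scala.") ||
      (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) ||
      name.startsWith("java.lang.") ||
      name.startsWith("java.net") ||
      sharedPrefixes.exists(name.startsWith)

  def main(args: Array[String]): Unit = {
    println(isSharedClass("org.apache.hadoop.conf.Configuration"))    // true: shared
    println(isSharedClass("org.apache.hadoop.hive.ql.metadata.Hive")) // false: isolated
  }
}
```
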
+ case object v1_1 extends HiveVersion("1.1.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + case object v1_2 extends HiveVersion("1.2.1", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 0e4a2427a9c1..8422287e177e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes} +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} /** * Create table and insert the query result into it. @@ -42,11 +40,13 @@ case class CreateTableAsSelect( def database: String = tableDesc.database def tableName: String = tableDesc.name - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def children: Seq[LogicalPlan] = Seq(query) + + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { - import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextInputFormat @@ -89,10 +89,10 @@ case class CreateTableAsSelect( hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } - Seq.empty[InternalRow] + Seq.empty[Row] } override def argString: String = { - s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString + s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index a89381000ad5..441b6b6033e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} +import 
org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.MetastoreRelation +import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". @@ -35,12 +35,12 @@ case class DescribeHiveTableCommand( override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala results ++= columns.map(field => (field.getName, field.getType, field.getComment)) if (partitionColumns.nonEmpty) { val partColumnInfo = @@ -48,7 +48,7 @@ case class DescribeHiveTableCommand( results ++= partColumnInfo ++ Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ partColumnInfo } @@ -57,7 +57,7 @@ case class DescribeHiveTableCommand( } results.map { case (name, dataType, comment) => - InternalRow(name, dataType, comment) + Row(name, dataType, comment) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 87f8e3f7fcfc..41b645b2c9c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, InternalRow} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{Row, SQLContext} private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { @@ -29,6 +29,6 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand { override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext): Seq[InternalRow] = - sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(InternalRow(_)) + override def run(sqlContext: SQLContext): Seq[Row] = + sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 1f5e4af2e474..806d2b9b0b7d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import 
org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.Object import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ @@ -43,7 +44,7 @@ private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, - partitionPruningPred: Option[Expression])( + partitionPruningPred: Seq[Expression])( @transient val context: HiveContext) extends LeafNode { @@ -55,7 +56,7 @@ case class HiveTableScan( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private[this] val boundPruningPred = partitionPruningPred.map { pred => + private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -97,7 +98,7 @@ case class HiveTableScan( .asInstanceOf[StructObjectInspector] val columnTypeNames = structOI - .getAllStructFieldRefs + .getAllStructFieldRefs.asScala .map(_.getFieldObjectInspector) .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) .mkString(",") @@ -117,13 +118,12 @@ case class HiveTableScan( case None => partitions case Some(shouldKeep) => partitions.filter { part => val dataTypes = relation.partitionKeys.map(_.dataType) - val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { - castFromString(value, dataType) - } + val castedValues = part.getValues.asScala.zip(dataTypes) + .map { case (value, dataType) => castFromString(value, dataType) } // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. 
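
`HiveTableScan` now takes a `Seq[Expression]` of pruning predicates and folds them with `reduceLeftOption(And)` before evaluating the bound predicate against each partition's key values. Below is a toy model of that fold-then-filter step, with plain functions standing in for bound Catalyst predicates; the types and names are illustrative only.

```scala
object PartitionPruningSketch {
  // A partition is modeled by its partition-column values; a pruning predicate
  // by a function over them (the real code evaluates bound Catalyst expressions).
  type Partition = Map[String, String]
  type Predicate = Partition => Boolean

  // Mirrors partitionPruningPred.reduceLeftOption(And): no predicates keeps everything.
  def combine(preds: Seq[Predicate]): Option[Predicate] =
    preds.reduceLeftOption((l: Predicate, r: Predicate) => (p: Partition) => l(p) && r(p))

  def prune(partitions: Seq[Partition], preds: Seq[Predicate]): Seq[Partition] =
    combine(preds) match {
      case None             => partitions
      case Some(shouldKeep) => partitions.filter(shouldKeep)
    }

  def main(args: Array[String]): Unit = {
    val parts = Seq(Map("ds" -> "2015-07-01"), Map("ds" -> "2015-07-02"))
    val keepSecond: Predicate = _.get("ds").contains("2015-07-02")
    println(prune(parts, Seq(keepSecond))) // List(Map(ds -> 2015-07-02))
  }
}
```
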
- val row = new GenericRow(castedValues.toArray) + val row = InternalRow.fromSeq(castedValues) shouldKeep.eval(row).asInstanceOf[Boolean] } } @@ -132,7 +132,8 @@ case class HiveTableScan( protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 05f425f2b65f..0c700bdb370a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,25 +19,26 @@ package org.apache.spark.sql.hive.execution import java.util +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ +import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkException, TaskContext} - -import scala.collection.JavaConversions._ import org.apache.spark.util.SerializableJobConf private[hive] @@ -59,9 +60,9 @@ case class InsertIntoHiveTable( serializer } - def output: Seq[Attribute] = child.output + def output: Seq[Attribute] = Seq.empty - def saveAsHiveFile( + private def saveAsHiveFile( rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, @@ -92,8 +93,10 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val wrappers = fieldOIs.map(wrapperFor) + val fieldOIs = standardOI.getAllStructFieldRefs.asScala + .map(_.getFieldObjectInspector).toArray + val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray + val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} val outputData = new Array[Any](fieldOIs.length) writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) @@ -101,7 +104,7 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, 
dataTypes(i))) i += 1 } @@ -121,12 +124,12 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) @@ -175,6 +178,19 @@ case class InsertIntoHiveTable( val jobConf = new JobConf(sc.hiveconf) val jobConfSer = new SerializableJobConf(jobConf) + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) @@ -195,7 +211,7 @@ case class InsertIntoHiveTable( // loadPartition call orders directories created on the iteration order of the this map val orderedPartitionSpec = new util.LinkedHashMap[String, String]() - table.hiveQlTable.getPartCols().foreach { entry => + table.hiveQlTable.getPartCols.asScala.foreach { entry => orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } @@ -223,7 +239,7 @@ case class InsertIntoHiveTable( val oldPart = catalog.client.getPartitionOption( catalog.client.getTable(table.databaseName, table.tableName), - partitionSpec) + partitionSpec.asJava) if (oldPart.isEmpty || !ifNotExists) { catalog.client.loadPartition( @@ -251,13 +267,12 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. 
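
The new guard above warns when speculative execution is combined with an output committer whose class name contains "Direct", because such committers write straight to the final location and duplicate speculative attempts can then lose or corrupt data (SPARK-10063). A minimal sketch of that check follows; the committer class name used in the example is hypothetical.

```scala
object DirectCommitterCheckSketch {
  // Returns the warning to log when speculation is enabled together with a
  // "Direct"-style output committer, mirroring the guard in the diff above.
  def directCommitterWarning(
      speculationEnabled: Boolean,
      outputCommitterClass: String): Option[String] = {
    if (speculationEnabled && outputCommitterClass.contains("Direct")) {
      Some(s"$outputCommitterClass may write data directly to the final location; " +
        "with speculation enabled this can cause data loss (see SPARK-10063).")
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    // "com.example.DirectOutputCommitter" is a made-up class name for illustration.
    println(directCommitterWarning(speculationEnabled = true,
      outputCommitterClass = "com.example.DirectOutputCommitter"))
    println(directCommitterWarning(speculationEnabled = true,
      outputCommitterClass = "org.apache.hadoop.mapred.FileOutputCommitter")) // None
  }
}
```
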
- Seq.empty[InternalRow] + Seq.empty[Row] } - override def executeCollect(): Array[Row] = - sideEffectResult.toArray + override def executeCollect(): Array[Row] = sideEffectResult.toArray protected override def doExecute(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(sideEffectResult, 1) + sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 9d8872aa47d1..32bddbaeaeaf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,25 +17,28 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} -import java.lang.ProcessBuilder.Redirect +import java.io._ import java.util.Properties +import javax.annotation.Nullable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.io.Writable import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.{Logging, TaskContext} /** * Transforms the input by forking and running the specified script. @@ -50,73 +53,99 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) + ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { override def otherCopyArgs: Seq[HiveContext] = sc :: Nil protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd) - // redirectError(Redirect.INHERIT) would consume the error output from buffer and - // then print it to stderr (inherit the target from the current Scala process). - // If without this there would be 2 issues: - // 1) The error msg generated by the script process would be hidden. - // 2) If the error msg is too big to chock up the buffer, the input logic would be hung - builder.redirectError(Redirect.INHERIT) + val builder = new ProcessBuilder(cmd.asJava) + val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream - val reader = new BufferedReader(new InputStreamReader(inputStream)) + val errorStream = proc.getErrorStream + + // In order to avoid deadlocks, we need to consume the error output of the child process. 
+ // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() - val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) + val outputProjection = new InterpretedProjection(input, child.output) - val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { - var cacheRow: InternalRow = null + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. + val writerThread = new ScriptTransformationWriterThread( + inputIterator, + input.map(_.dataType), + outputProjection, + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get() + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + ioschema.initOutputSerDe(output).getOrElse((null, null)) + } + + val reader = new BufferedReader(new InputStreamReader(inputStream)) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null - var eof: Boolean = false + val scriptOutputStream = new DataInputStream(inputStream) + var scriptOutputWritable: Writable = null + val reusedWritableObject: Writable = if (null != outputSerde) { + outputSerde.getSerializedClass().newInstance + } else { + null + } + val mutableRow = new SpecificMutableRow(output.map(_.dataType)) override def hasNext: Boolean = { if (outputSerde == null) { if (curLine == null) { curLine = reader.readLine() - curLine != null + if (curLine == null) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } else { true } + } else if (scriptOutputWritable == null) { + scriptOutputWritable = reusedWritableObject + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } } else { - !eof - } - } - - def deserialize(): InternalRow = { - if (cacheRow != null) return cacheRow - - val mutableRow = new SpecificMutableRow(output.map(_.dataType)) - try { - val dataInputStream = new DataInputStream(inputStream) - val writable = outputSerde.getSerializedClass().newInstance - writable.readFields(dataInputStream) - - val raw = outputSerde.deserialize(writable) - val dataList = outputSoi.getStructFieldsDataAsList(raw) - val fieldList = outputSoi.getAllStructFieldRefs() - - var i = 0 - dataList.foreach( element => { - if (element == null) { - mutableRow.setNullAt(i) - } else { - mutableRow(i) = unwrap(element, fieldList(i).getFieldObjectInspector) - } - i += 1 - }) - return mutableRow - } catch { - case e: EOFException => - eof = true - return null + true } } @@ -124,58 +153,128 @@ case class ScriptTransformation( if (!hasNext) { throw new NoSuchElementException } - if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() if 
(!ioschema.schemaLess) { - new GenericRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .map(CatalystTypeConverters.convertToCatalyst)) } else { - new GenericRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) } } else { - val ret = deserialize() - if (!eof) { - cacheRow = null - cacheRow = deserialize() + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + val fieldList = outputSoi.getAllStructFieldRefs() + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) + } else { + mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector) + } + i += 1 } - ret + mutableRow } } } - val (inputSerde, inputSoi) = ioschema.initInputSerDe(input) - val dataOutputStream = new DataOutputStream(outputStream) - val outputProjection = new InterpretedProjection(input, child.output) + writerThread.start() - // Put the write(output to the pipeline) into a single thread - // and keep the collector as remain in the main thread. - // otherwise it will causes deadlock if the data size greater than - // the pipeline / buffer capacity. - new Thread(new Runnable() { - override def run(): Unit = { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + outputIterator + } + + child.execute().mapPartitions { iter => + if (iter.hasNext) { + processIterator(iter) + } else { + // If the input iterator has no rows then do not launch the external script. + Iterator.empty + } + } + } +} + +private class ScriptTransformationWriterThread( + iter: Iterator[InternalRow], + inputSchema: Seq[DataType], + outputProjection: Projection, + @Nullable inputSerde: AbstractSerDe, + @Nullable inputSoi: ObjectInspector, + ioschema: HiveScriptIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext + ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { + + setDaemon(true) + + @volatile private var _exception: Throwable = null + + /** Contains the exception thrown while writing the parent iterator to the external process. 
*/ + def exception: Option[Throwable] = Option(_exception) + + override def run(): Unit = Utils.logUncaughtExceptions { + TaskContext.setTaskContext(taskContext) + + val dataOutputStream = new DataOutputStream(outputStream) + + // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so + // let's use a variable to record whether the `finally` block was hit due to an exception + var threwException: Boolean = true + val len = inputSchema.length + try { + iter.map(outputProjection).foreach { row => + if (inputSerde == null) { + val data = if (len == 0) { + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 } + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() } - outputStream.close() + outputStream.write(data.getBytes("utf-8")) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) } - }).start() - - iterator + } + outputStream.close() + threwException = false + } catch { + case NonFatal(e) => + // An error occurred while writing input, so kill the child process. According to the + // Javadoc this call will not throw an exception: + _exception = e + proc.destroy() + throw e + } finally { + try { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + } + } catch { + case NonFatal(exceptionFromFinallyBlock) => + if (!threwException) { + throw exceptionFromFinallyBlock + } else { + log.error("Exception in finally block", exceptionFromFinallyBlock) + } + } } } } @@ -187,93 +286,65 @@ private[hive] case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], outputRowFormat: Seq[(String, String)], - inputSerdeClass: String, - outputSerdeClass: String, + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { - val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n")) + private val defaultFormat = Map( + ("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n") + ) val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { - val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) - (serde, initInputSoi(serde, columns, columnTypes)) - } - - def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = { - val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps) - (serde, initOutputputSoi(serde)) - } - - def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - - val columns = attrs.map { - case aref: AttributeReference => aref.name - case e: NamedExpression => e.name - case _ => null + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = 
parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) + .asInstanceOf[ObjectInspector] + (serde, objectInspector) } + } - val columnTypes = attrs.map { - case aref: AttributeReference => aref.dataType - case e: NamedExpression => e.dataType - case _ => null + def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) + val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) } + } + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { + val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") + val columnTypes = attrs.map(_.dataType) (columns, columnTypes) } - def initSerDe(serdeClassName: String, columns: Seq[String], - columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { + private def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { - val serde: AbstractSerDe = if (serdeClassName != "") { - val trimed_class = serdeClassName.split("'")(1) - Utils.classForName(trimed_class) - .newInstance.asInstanceOf[AbstractSerDe] - } else { - null - } + val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] - if (serde != null) { - val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) - propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - val properties = new Properties() - properties.putAll(propsMap) - serde.initialize(null, properties) - } + val properties = new Properties() + properties.putAll(propsMap.asJava) + serde.initialize(null, properties) serde } - - def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType]) - : ObjectInspector = { - - if (inputSerde != null) { - val fieldObjectInspectors = columnTypes.map(toInspector(_)) - ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) - .asInstanceOf[ObjectInspector] - } else { - null - } - } - - def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { - if (outputSerde != null) { - outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] - } else { - null - } - } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index aad58bfa2e6e..d1699dd53681 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.hive.execution -import 
org.apache.spark.sql.AnalysisException +import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser} import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -39,9 +41,9 @@ import org.apache.spark.util.Utils private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.asInstanceOf[HiveContext].analyze(tableName) - Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -53,7 +55,7 @@ case class DropTable( tableName: String, ifExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { @@ -70,7 +72,7 @@ case class DropTable( hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(Seq(tableName)) - Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -83,12 +85,12 @@ case class AddJar(path: String) extends RunnableCommand { schema.toAttributes } - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val currentClassLoader = Utils.getContextOrSparkClassLoader // Add jar to current context - val jarURL = new java.io.File(path).toURL + val jarURL = new java.io.File(path).toURI.toURL val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader) Thread.currentThread.setContextClassLoader(newClassLoader) // We need to explicitly set the class loader associated with the conf in executionHive's @@ -105,36 +107,52 @@ case class AddJar(path: String) extends RunnableCommand { // Add jar to executors hiveContext.sparkContext.addJar(path) - Seq(InternalRow(0)) + Seq(Row(0)) } } private[hive] case class AddFile(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD FILE $path") hiveContext.sparkContext.addFile(path) - Seq.empty[InternalRow] + Seq.empty[Row] } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). 
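The `AddJar` change in this hunk replaces `new java.io.File(path).toURL` with `new java.io.File(path).toURI.toURL`. A minimal sketch of why that matters, using only the JDK: `File.toURL` is deprecated because it does not escape characters that are illegal in URLs, while converting through a `URI` first does.

```
import java.io.File

// Paths containing spaces (or other characters that need escaping) show the
// difference; the deprecated direct conversion is only shown in a comment.
val jar = new File("/tmp/my jars/example.jar")
// jar.toURL        would yield file:/tmp/my jars/example.jar (space unescaped)
println(jar.toURI.toURL) // file:/tmp/my%20jars/example.jar
```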
private[hive] case class CreateMetastoreDataSource( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], allowExisting: Boolean, managedIfNoPath: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (hiveContext.catalog.tableExists(tableIdent.toSeq)) { if (allowExisting) { - return Seq.empty[InternalRow] + return Seq.empty[Row] } else { throw new AnalysisException(s"Table $tableName already exists.") } @@ -144,46 +162,62 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, userSpecifiedSchema, Array.empty[String], provider, optionsWithPath, isExternal) - Seq.empty[InternalRow] + Seq.empty[Row] } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSourceAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. 
Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } var existingSchema = None: Option[StructType] - if (sqlContext.catalog.tableExists(Seq(tableName))) { + if (sqlContext.catalog.tableExists(tableIdent.toSeq)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -194,13 +228,13 @@ case class CreateMetastoreDataSourceAsSelect( s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.") case SaveMode.Ignore => // Since the table already exists and the save mode is Ignore, we will just return. - return Seq.empty[InternalRow] + return Seq.empty[Row] case SaveMode.Append => // Check if the specified data source match the data source of the existing table. val resolved = ResolvedDataSource( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { + EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent.toSeq)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => if (l.relation != createdRelation.relation) { val errorDescription = @@ -249,7 +283,7 @@ case class CreateMetastoreDataSourceAsSelect( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, Some(resolved.relation.schema), partitionColumns, provider, @@ -258,7 +292,7 @@ case class CreateMetastoreDataSourceAsSelect( } // Refresh the cache of the table in the catalog. 
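Both `CreateMetastoreDataSource` and `CreateMetastoreDataSourceAsSelect` now reject table and database names the Hive metastore cannot store. A rough, self-contained approximation of that check (the real implementation is Hive's `MetaStoreUtils.validateName`; the regex below is only illustrative):

```
// Illustrative stand-in for MetaStoreUtils.validateName: accept only names
// made of ASCII letters, digits and underscores.
def isValidMetastoreName(name: String): Boolean =
  name.matches("[a-zA-Z0-9_]+")

assert(isValidMetastoreName("my_table_1"))
assert(!isValidMetastoreName("my-table"))   // '-' is rejected
assert(!isValidMetastoreName("db.table"))   // qualified names must be split first
```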
- hiveContext.refreshTable(tableName) - Seq.empty[InternalRow] + hiveContext.catalog.refreshTable(tableIdent) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala similarity index 80% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4986b1ea9d90..cad02373e5ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} @@ -33,10 +33,11 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim._ @@ -59,47 +60,68 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } } } - override def registerFunction(name: String, builder: FunctionBuilder): Unit = - throw new UnsupportedOperationException -} + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = underlying.registerFunction(name, info, builder) -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { + /* List all of the registered function names. 
*/ + override def listFunction(): Seq[String] = { + (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted + } - type UDFType = UDF + /* Get the class of the registered function by specified name. */ + override def lookupFunction(name: String): Option[ExpressionInfo] = { + underlying.lookupFunction(name).orElse( + Try { + val info = FunctionRegistry.getFunctionInfo(name) + val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) + if (annotation != null) { + Some(new ExpressionInfo( + info.getFunctionClass.getCanonicalName, + annotation.name(), + annotation.value(), + annotation.extended())) + } else { + None + } + }.getOrElse(None)) + } +} + +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with CodegenFallback with Logging { override def deterministic: Boolean = isUDFDeterministic override def nullable: Boolean = true @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[UDF]() @transient - protected lazy val method = - function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + private lazy val method = + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava) @transient - protected lazy val arguments = children.map(toInspector).toArray + private lazy val arguments = children.map(toInspector).toArray @transient - protected lazy val isUDFDeterministic = { + private lazy val isUDFDeterministic = { val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -108,26 +130,28 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre // Create parameter converters @transient - protected lazy val conversionHelper = new ConversionHelper(method, arguments) + private lazy val conversionHelper = new ConversionHelper(method, arguments) - @transient - lazy val dataType = javaClassToDataType(method.getReturnType) + val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( method.getGenericReturnType(), ObjectInspectorOptions.JAVA) @transient - protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - override def isThreadSafe: Boolean = false + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray // TODO: Finish input output types. 
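The import change from `scala.collection.JavaConversions._` to `scala.collection.JavaConverters._` shows up throughout this file (for example in `FunctionRegistry.getFunctionNames.asScala` above). A small sketch of the idiom, assuming nothing beyond the standard library: conversions between Java and Scala collections become explicit `.asScala` / `.asJava` calls instead of silent implicits.

```
import scala.collection.JavaConverters._

val javaNames = new java.util.ArrayList[String]()
javaNames.add("upper")
javaNames.add("lower")

// Explicit conversion in: Java list -> Scala buffer -> immutable List
val scalaNames: List[String] = javaNames.asScala.toList

// Explicit conversion out: Scala Seq -> java.util.List
val backToJava: java.util.List[String] = scalaNames.asJava

assert(scalaNames.sorted == List("lower", "upper"))
assert(backToJava.size() == 2)
```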
override def eval(input: InternalRow): Any = { - unwrap( - FunctionRegistry.invoke(method, function, conversionHelper - .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), - returnInspector) + val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes) + val ret = FunctionRegistry.invoke( + method, + function, + conversionHelper.convertIfNecessary(inputs : _*): _*) + unwrap(ret, returnInspector) } override def toString: String = { @@ -136,52 +160,51 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre } // Adapter from Catalyst ExpressionResult to Hive DeferredObject -private[hive] class DeferredObjectAdapter(oi: ObjectInspector) +private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) extends DeferredObject with HiveInspectors { + private var func: () => Any = _ def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrap(func(), oi) + override def get(): AnyRef = wrap(func(), oi, dataType) } -private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { - type UDFType = GenericUDF +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with CodegenFallback with Logging { + + override def nullable: Boolean = true override def deterministic: Boolean = isUDFDeterministic - override def nullable: Boolean = true + override def foldable: Boolean = + isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[GenericUDF]() @transient - protected lazy val argumentInspectors = children.map(toInspector) + private lazy val argumentInspectors = children.map(toInspector) @transient - protected lazy val returnInspector = { + private lazy val returnInspector = { function.initializeAndFoldConstants(argumentInspectors.toArray) } @transient - protected lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - (udfType != null && udfType.deterministic()) + private lazy val isUDFDeterministic = { + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() } - override def foldable: Boolean = - isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] - @transient - protected lazy val deferedObjects = - argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] + private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + new DeferredObjectAdapter(inspect, child.dataType) + }.toArray[DeferredObject] lazy val dataType: DataType = inspectorToDataType(returnInspector) - override def isThreadSafe: Boolean = false - override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. @@ -224,7 +247,7 @@ private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { // Get the class of this function. // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. 
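`DeferredObjectAdapter` now also carries the Catalyst data type of its argument so that `wrap` can convert values precisely. Independently of that change, the adapter illustrates Hive's deferred-argument contract: each UDF argument is handed over as a thunk that is only forced when `get()` is called. A minimal, self-contained sketch of that pattern (the names are illustrative, not Hive's API):

```
// A deferred argument: the value is computed only if the UDF asks for it.
class DeferredValue[A] {
  private var thunk: () => A = _
  def set(f: () => A): Unit = { thunk = f }
  def get(): A = thunk()
}

val arg = new DeferredValue[Int]
var evaluated = false
arg.set { () => evaluated = true; 42 }

assert(!evaluated)        // setting the thunk does not evaluate it
assert(arg.get() == 42)   // evaluation happens on demand
assert(evaluated)
```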
- val functionClass = windowFunctionInfo.getfInfo().getFunctionClass + val functionClass = windowFunctionInfo.getFunctionClass() val newChildren = // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit // input parameters and requires implicit parameters, which @@ -304,7 +327,7 @@ private[hive] case class HiveWindowFunction( pivotResult: Boolean, isUDAFBridgeRequired: Boolean, children: Seq[Expression]) extends WindowFunction - with HiveInspectors { + with HiveInspectors with Unevaluable { // Hive window functions are based on GenericUDAFResolver2. type UDFType = GenericUDAFResolver2 @@ -333,7 +356,7 @@ private[hive] case class HiveWindowFunction( evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } - def dataType: DataType = + override def dataType: DataType = if (!pivotResult) { inspectorToDataType(returnInspector) } else { @@ -347,10 +370,7 @@ private[hive] case class HiveWindowFunction( } } - def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def nullable: Boolean = true @transient lazy val inputProjection = new InterpretedProjection(children) @@ -360,6 +380,9 @@ private[hive] case class HiveWindowFunction( // Output buffer. private var outputBuffer: Any = _ + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def init(): Unit = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -374,8 +397,13 @@ private[hive] case class HiveWindowFunction( } override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length)) + wrap( + inputProjection(input), + inputInspectors, + new Array[AnyRef](children.length), + inputDataTypes) } + // Add input parameters for a single row. override def update(input: AnyRef): Unit = { evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) @@ -398,10 +426,10 @@ private[hive] case class HiveWindowFunction( // if pivotResult is false, we will get a single value for all rows in the frame. outputBuffer } else { - // if pivotResult is true, we will get a Seq having the same size with the size + // if pivotResult is true, we will get a ArrayData having the same size with the size // of the window frame. At here, we will return the result at the position of // index in the output buffer. 
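`HiveWindowFunction` drives the wrapped Hive evaluator through an init, update-per-row, read-the-result cycle (`evaluator.init`, `evaluator.iterate`, then the output buffer). A stripped-down sketch of that lifecycle with a trivial sum evaluator standing in for the Hive one:

```
// A trivial evaluator with the same shape: reset, feed rows, read the result.
class SumEvaluator {
  private var buffer = 0.0
  def init(): Unit = buffer = 0.0
  def update(value: Double): Unit = buffer += value
  def evaluate(): Double = buffer
}

val evaluator = new SumEvaluator
evaluator.init()
Seq(1.0, 2.0, 3.0).foreach(evaluator.update)
assert(evaluator.evaluate() == 6.0)
```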
- outputBuffer.asInstanceOf[Seq[Any]].get(index) + outputBuffer.asInstanceOf[ArrayData].get(index, dataType) } } @@ -409,13 +437,13 @@ private[hive] case class HiveWindowFunction( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def newInstance: WindowFunction = + override def newInstance(): WindowFunction = new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } -private[hive] case class HiveGenericUdaf( +private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = AbstractGenericUDAFResolver @@ -441,13 +469,13 @@ private[hive] case class HiveGenericUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( +private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = UDAF @@ -474,12 +502,12 @@ private[hive] case class HiveUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) } /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a - * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow + * [[Generator]]. Note that the semantics of Generators do not allow * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning * dependent operations like calls to `close()` before producing output will not operate the same as * in Hive. However, in practice this should not affect compatibility for most sane UDTFs @@ -488,10 +516,10 @@ private[hive] case class HiveUdaf( * Operators that require maintaining state in between input rows should instead be implemented as * user defined aggregations, which have clean semantics even in a partitioned execution. */ -private[hive] case class HiveGenericUdtf( +private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Generator with HiveInspectors { + extends Generator with HiveInspectors with CodegenFallback { @transient protected lazy val function: GenericUDTF = { @@ -512,16 +540,19 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val collector = new UDTFCollector - lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { field => (inspectorToDataType(field.getFieldObjectInspector), true) } + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. 
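As the `pivotResult` branch above shows, a Hive window function either produces one value for the whole frame or a collection with one entry per row that must be indexed into. A tiny sketch of that dispatch, with a plain `Seq` standing in for `ArrayData`:

```
def windowResult(outputBuffer: Any, pivotResult: Boolean, index: Int): Any =
  if (!pivotResult) {
    outputBuffer                                 // one value for the whole frame
  } else {
    outputBuffer.asInstanceOf[Seq[Any]](index)   // per-row value looked up by index
  }

assert(windowResult(42, pivotResult = false, index = 3) == 42)
assert(windowResult(Seq(10, 20, 30), pivotResult = true, index = 1) == 20)
```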
val inputProjection = new InterpretedProjection(children) - function.process(wrap(inputProjection(input), inputInspectors, udtInput)) + function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) collector.collectRows() } @@ -553,12 +584,12 @@ private[hive] case class HiveGenericUdtf( } } -private[hive] case class HiveUdafFunction( +private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], - base: AggregateExpression, + base: AggregateExpression1, isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction + extends AggregateFunction1 with HiveInspectors { def this() = this(null, null, null) @@ -590,9 +621,12 @@ private[hive] case class HiveUdafFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) + @transient + private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray + def update(input: InternalRow): Unit = { val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached)) + function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8b928861fcc7..c8d6b718045a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -32,9 +32,9 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableJobConf @@ -44,7 +44,7 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc) extends Logging with SparkHadoopMapRedUtil @@ -94,7 +94,9 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { + writer + } def close() { // Seems the boolean value passed into close does not matter. 
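`SparkHiveDynamicPartitionWriterContainer`, updated further down in this file, derives each row's output directory from its trailing dynamic-partition columns. A hedged sketch of that path construction (`escapePath` stands in for Hive's `FileUtils.escapePathName`, and the default partition name is the one Hive substitutes for null or empty values):

```
// Build "/col=value" segments for the dynamic partition columns of one row.
def partitionPath(
    dynamicPartColNames: Seq[String],
    partValues: Seq[Any],
    defaultPartName: String,
    escapePath: String => String): String = {
  dynamicPartColNames.zip(partValues).map { case (col, rawVal) =>
    val str = if (rawVal == null) null else String.valueOf(rawVal)
    val colString =
      if (str == null || str.isEmpty) defaultPartName else escapePath(str)
    s"/$col=$colString"
  }.mkString
}

assert(
  partitionPath(Seq("year", "month"), Seq(2015, null), "__HIVE_DEFAULT_PARTITION__", identity)
    == "/year=2015/month=__HIVE_DEFAULT_PARTITION__")
```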
@@ -119,7 +121,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -160,7 +162,7 @@ private[spark] object SparkHiveDynamicPartitionWriterContainer { } private[spark] class SparkHiveDynamicPartitionWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String]) extends SparkHiveWriterContainer(jobConf, fileSinkConf) { @@ -168,7 +170,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( - ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ @@ -191,34 +193,35 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // Better solution is to add a step similar to what Hive FileSinkOperator.jobCloseOp does: // calling something like Utilities.mvFileToFinalPath to cleanup the output directory and then // load it with loadDynamicPartitions/loadPartition/loadTable. - val oldMarker = jobConf.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) + val oldMarker = conf.value.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) super.commitJob() - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: InternalRow, schema: StructType) + : FileSinkOperator.RecordWriter = { def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { - case DateType => DateUtils.toString(raw.toInt) + case DateType => DateTimeUtils.dateToString(raw.toInt) case _: DecimalType => BigDecimal(raw).toString() case _ => raw } } - val dynamicPartPath = dynamicPartColNames - .zip(row.toSeq.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$col=$colString" - }.mkString + val nonDynamicPartLen = row.numFields - dynamicPartColNames.length + val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => + val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) + val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$colName=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 1e51173a1988..0f9a1a6ef3b2 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -24,42 +24,83 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType -private[orc] object OrcFileOperator extends Logging{ - def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { +private[orc] object OrcFileOperator extends Logging { + /** + * Retrieves a ORC file reader from a given path. The path can point to either a directory or a + * single ORC file. If it points to an directory, it picks any non-empty ORC file within that + * directory. + * + * The reader returned by this method is mainly used for two purposes: + * + * 1. Retrieving file metadata (schema and compression codecs, etc.) + * 2. Read the actual file content (in this case, the given path should point to the target file) + * + * @note As recorded by SPARK-8501, ORC writes an empty schema (struct<> + logInfo( + s"ORC file $path has empty schema, it probably contains no rows. " + + "Trying to read another ORC file to figure out the schema.") + false + case _ => true + } + } + val conf = config.getOrElse(new Configuration) - val fspath = new Path(pathStr) - val fs = fspath.getFileSystem(conf) - val orcFiles = listOrcFiles(pathStr, conf) + val fs = { + val hdfsPath = new Path(basePath) + hdfsPath.getFileSystem(conf) + } - // TODO Need to consider all files when schema evolution is taken into account. - OrcFile.createReader(fs, orcFiles.head) + listOrcFiles(basePath, conf).iterator.map { path => + path -> OrcFile.createReader(fs, path) + }.collectFirst { + case (path, reader) if isWithNonEmptySchema(path, reader) => reader + } } def readSchema(path: String, conf: Option[Configuration]): StructType = { - val reader = getFileReader(path, conf) + val reader = getFileReader(path, conf).getOrElse { + throw new AnalysisException( + s"Failed to discover schema from ORC files stored in $path. 
" + + "Probably there are either no ORC files or only empty ORC files.") + } val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $path, got Hive schema string: $schema") HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } - def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = { - getFileReader(path, conf).getObjectInspector.asInstanceOf[StructObjectInspector] + def getObjectInspector( + path: String, conf: Option[Configuration]): Option[StructObjectInspector] = { + getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector]) } def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) - val path = origPath.makeQualified(fs) + val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) .filterNot(_.isDir) .map(_.getPath) .filterNot(_.getName.startsWith("_")) .filterNot(_.getName.startsWith(".")) - if (paths == null || paths.size == 0) { + if (paths == null || paths.isEmpty) { throw new IllegalArgumentException( s"orcFileOperator: path $path does not have valid orc files matching the pattern") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 250e73a4dba9..b3d9f7f71a27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.serde2.io.DateWritable @@ -33,18 +33,18 @@ import org.apache.spark.sql.sources._ private[orc] object OrcFilters extends Logging { def createFilter(expr: Array[Filter]): Option[SearchArgument] = { expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgument.FACTORY.newBuilder() + val builder = SearchArgumentFactory.newBuilder() buildSearchArgument(conjunction, builder).map(_.build()) } } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgument.FACTORY.newBuilder() + def newBuilder = SearchArgumentFactory.newBuilder() - def isSearchableLiteral(value: Any) = value match { + def isSearchableLiteral(value: Any): Boolean = value match { // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. 
- case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar | - _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true + case _: String | _: Long | _: Double | _: Byte | _: Short | _: Integer | _: Float => true + case _: DateWritable | _: HiveDecimal | _: HiveChar | _: HiveVarchar => true case _ => false } @@ -107,6 +107,11 @@ private[orc] object OrcFilters extends Logging { .filter(isSearchableLiteral) .map(builder.equals(attribute, _)) + case EqualNullSafe(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.nullSafeEquals(attribute, _)) + case LessThan(attribute, value) => Option(value) .filter(isSearchableLiteral) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index dbce39f21d27..d1f30e188eaf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -19,34 +19,38 @@ package org.apache.spark.sql.hive.orc import java.util.Properties +import scala.collection.JavaConverters._ + import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit, OrcStruct} +import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, StructTypeInfo} import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{Logging} import org.apache.spark.util.SerializableConfiguration -/* Implicit conversions */ -import scala.collection.JavaConversions._ +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "orc" -private[sql] class DefaultSource extends HadoopFsRelationProvider { - def createRelation( + override def createRelation( sqlContext: SQLContext, paths: Array[String], dataSchema: Option[StructType], @@ -74,7 +78,8 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - serde.initialize(context.getConfiguration, table) + 
val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + serde.initialize(configuration, table) serde } @@ -84,19 +89,10 @@ private[orc] class OrcOutputWriter( TypeInfoUtils.getTypeInfoFromTypeString( HiveMetastoreTypes.toMetastoreType(dataSchema)) - TypeInfoUtils - .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) - .asInstanceOf[StructObjectInspector] + OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo]) + .asInstanceOf[SettableStructObjectInspector] } - // Used to hold temporary `Writable` fields of the next row to be written. - private val reusableOutputBuffer = new Array[Any](dataSchema.length) - - // Used to convert Catalyst values into Hadoop `Writable`s. - private val wrappers = structOI.getAllStructFieldRefs.map { ref => - wrapperFor(ref.getFieldObjectInspector) - }.toArray - // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this // flag to decide whether `OrcRecordWriter.close()` needs to be called. private var recordWriterInstantiated = false @@ -104,9 +100,11 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = context.getConfiguration - val partition = context.getTaskAttemptID.getTaskID.getId - val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val partition = taskAttemptId.getTaskID.getId + val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), @@ -116,16 +114,34 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = { + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + private def wrapOrcStruct( + struct: OrcStruct, + oi: SettableStructObjectInspector, + row: InternalRow): Unit = { + val fieldRefs = oi.getAllStructFieldRefs var i = 0 - while (i < row.length) { - reusableOutputBuffer(i) = wrappers(i)(row(i)) + while (i < fieldRefs.size) { + oi.setStructFieldData( + struct, + fieldRefs.get(i), + wrap( + row.get(i, dataSchema(i).dataType), + fieldRefs.get(i).getFieldObjectInspector, + dataSchema(i).dataType)) i += 1 } + } + + val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + wrapOrcStruct(cachedOrcStruct, structOI, row) recordWriter.write( NullWritable.get(), - serializer.serialize(reusableOutputBuffer, structOI)) + serializer.serialize(cachedOrcStruct, structOI)) } override def close(): Unit = { @@ -135,7 +151,6 @@ private[orc] class OrcOutputWriter( } } -@DeveloperApi private[sql] class OrcRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], @@ -189,10 +204,20 @@ private[sql] class OrcRelation( filters: Array[Filter], inputPaths: Array[FileStatus]): RDD[Row] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute() + OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] } override def prepareJobForWrite(job: Job): OutputWriterFactory = { + SparkHadoopUtil.get.getConfigurationFromJobContext(job) 
match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + new OutputWriterFactory { override def newInstance( path: String, @@ -223,40 +248,48 @@ private[orc] case class OrcTableScan( HiveShim.appendReadColumns(conf, sortedIds, sortedNames) } - // Transform all given raw `Writable`s into `Row`s. + // Transform all given raw `Writable`s into `InternalRow`s. private def fillObject( path: String, conf: Configuration, iterator: Iterator[Writable], nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow): Iterator[InternalRow] = { val deserializer = new OrcSerde - val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) - val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { - case (attr, ordinal) => - soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal - }.unzip - val unwrappers = fieldRefs.map(unwrapperFor) - // Map each tuple to a row object - iterator.map { value => - val raw = deserializer.deserialize(value) - var i = 0 - while (i < fieldRefs.length) { - val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + val maybeStructOI = OrcFileOperator.getObjectInspector(path, Some(conf)) + + // SPARK-8501: ORC writes an empty schema ("struct<>") to an ORC file if the file contains zero + // rows, and thus couldn't give a proper ObjectInspector. In this case we just return an empty + // partition since we know that this file is empty. + maybeStructOI.map { soi => + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name) -> ordinal + }.unzip + val unwrappers = fieldRefs.map(unwrapperFor) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 } - i += 1 + mutableRow: InternalRow } - mutableRow: Row + }.getOrElse { + Iterator.empty } } - def execute(): RDD[Row] = { + def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { @@ -269,9 +302,11 @@ private[orc] case class OrcTableScan( // Sets requested columns addColumnIds(attributes, relation, conf) - if (inputPaths.nonEmpty) { - FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) + if (inputPaths.isEmpty) { + // the input path probably be pruned, return an empty RDD. 
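`fillObject` above now treats a missing object inspector (an ORC file with zero rows) as an empty partition rather than a failure, by mapping over the `Option` and falling back to `Iterator.empty`. A small sketch of that shape, using simplified types in place of inspectors and `Writable`s:

```
// Deserialize raw records only when a schema is available; otherwise yield no rows.
def rowsFor(
    maybeSchema: Option[Seq[String]],
    rawRows: Iterator[Map[String, Any]]): Iterator[Seq[Any]] = {
  maybeSchema.map { schema =>
    rawRows.map(record => schema.map(col => record.getOrElse(col, null)))
  }.getOrElse(Iterator.empty)
}

assert(rowsFor(None, Iterator(Map("a" -> 1))).isEmpty)
assert(rowsFor(Some(Seq("a", "b")), Iterator(Map("a" -> 1))).toList == List(List(1, null)))
```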
+ return sqlContext.sparkContext.emptyRDD[InternalRow] } + FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) val inputFormatClass = classOf[OrcInputFormat] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index f901bd817150..be335a47dcab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,42 +20,41 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} -import org.apache.hadoop.hive.conf.HiveConf +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.language.implicitConversions + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry -import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} -import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.hadoop.hive.serde2.avro.AvroSerDe -import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.{SQLContext, SQLConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.SQLConf -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} -import scala.collection.mutable -import scala.language.implicitConversions - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( new SparkContext( - System.getProperty("spark.sql.test.master", "local[2]"), + System.getProperty("spark.sql.test.master", "local[32]"), "TestSQLContext", new SparkConf() .set("spark.sql.test", "") - .set( - "spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe"))) + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) + +trait TestHiveSingleton { + protected val sqlContext: SQLContext = TestHive + protected val hiveContext: TestHiveContext = TestHive +} /** * A locally running test instance of Spark's Hive execution engine. 
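The new `TestHiveSingleton` trait gives every suite access to one shared test context instead of each suite constructing its own. A minimal sketch of that sharing pattern, with a hypothetical `TestContext` standing in for `TestHiveContext`:

```
class TestContext(val name: String)

// One lazily created context, shared by every suite that mixes in the trait.
object SharedTestContext {
  lazy val instance = new TestContext("shared")
}

trait TestContextSingleton {
  protected def ctx: TestContext = SharedTestContext.instance
}

class SomeSuite extends TestContextSingleton {
  def run(): Unit = assert(ctx.name == "shared")
}

new SomeSuite().run()
```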
@@ -80,13 +79,25 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveconf.set("hive.plan.serialization.format", "javaXML") - lazy val warehousePath = Utils.createTempDir() + lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") + + lazy val scratchDirPath = { + val dir = Utils.createTempDir(namePrefix = "scratch-") + dir.delete() + dir + } private lazy val temporaryConfig = newTemporaryConfiguration() /** Sets up the system initially or after a RESET command */ - protected override def configure(): Map[String, String] = - temporaryConfig ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) + protected override def configure(): Map[String, String] = { + super.configure() ++ temporaryConfig ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", + ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" + ) + } val testTempDir = Utils.createTempDir() @@ -105,18 +116,28 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution(plan) + // Make sure we set those test specific confs correctly when we create + // the SQLConf as well as when we call clear. override protected[sql] def createSession(): SQLSession = { new this.SQLSession() } protected[hive] class SQLSession extends super.SQLSession { - /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5) // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. 
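`scratchDirPath` above reserves a unique temporary path and then deletes the directory again, presumably so that Hive itself can create the scratch directory it is configured with. A short stand-alone sketch of that create-then-delete idiom, using only the JDK:

```
import java.nio.file.Files

val scratchDirPath = {
  val dir = Files.createTempDirectory("scratch-").toFile
  dir.delete()   // keep the unique name, but let the consumer create the directory
  dir
}

assert(!scratchDirPath.exists())   // only the path is chosen; nothing is on disk yet
println(scratchDirPath.getAbsolutePath)
```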
// The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + + clear() + + override def clear(): Unit = { + super.clear() + + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } } } @@ -144,7 +165,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) + ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) @@ -239,7 +260,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { }), TestTable("src_thrift", () => { import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.hive.serde2.thrift.test.Complex import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol @@ -248,7 +268,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' |WITH SERDEPROPERTIES( - | 'serialization.class'='${classOf[Complex].getName}', + | 'serialization.class'='org.apache.spark.sql.hive.test.Complex', | 'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS @@ -267,10 +287,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), TestTable("episodes", s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |STORED AS avro |TBLPROPERTIES ( | 'avro.schema.literal'='{ | "type": "record", @@ -303,10 +320,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { TestTable("episodes_part", s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) |PARTITIONED BY (doctor_pt INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |STORED AS avro |TBLPROPERTIES ( | 'avro.schema.literal'='{ | "type": "record", @@ -364,7 +378,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) SELECT title, air_date, doctor FROM episodes """.cmd - ) + ), + TestTable("src_json", + s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) @@ -391,7 +409,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * Records the UDFs present when the server starts, so we can delete ones that are created by * tests. 
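The test `SQLConf` above overrides `clear()` so that the test-only settings in `TestHiveContext.overrideConfs` survive any `conf.clear()` a test performs. A compact sketch of that re-apply-on-clear pattern:

```
import scala.collection.mutable

class TestConf(overrides: Map[String, String]) {
  private val settings = mutable.Map.empty[String, String]
  clear()   // apply the overrides as soon as the conf is created

  def set(key: String, value: String): Unit = settings(key) = value
  def get(key: String): Option[String] = settings.get(key)

  def clear(): Unit = {
    settings.clear()
    overrides.foreach { case (k, v) => set(k, v) }   // restore test-only settings
  }
}

val conf = new TestConf(Map("spark.sql.shuffle.partitions" -> "5"))
conf.set("some.other.key", "x")
conf.clear()
assert(conf.get("spark.sql.shuffle.partitions") == Some("5"))
assert(conf.get("some.other.key").isEmpty)
```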
*/ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** * Resets the test instance by deleting any tables that have been created. @@ -400,7 +418,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. - org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } @@ -410,9 +428,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => - FunctionRegistry.unregisterTemporaryUDF(udfName) - } + FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). + foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } // Some tests corrupt this value on purpose, which breaks the RESET call below. hiveconf.set("fs.default.name", new File(".").toURI.toString) @@ -432,6 +449,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } + defaultOverrides() runSqlHive("USE default") @@ -447,3 +465,15 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } } + +private[hive] object TestHiveContext { + + /** + * A map used to store all confs that need to be overridden in sql/hive unit tests. + */ + val overrideConfs: Map[String, String] = + Map( + // Fewer shuffle partitions to speed up testing. + SQLConf.SHUFFLE_PARTITIONS.key -> "5" + ) +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java similarity index 61% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index c4828c471764..b4bf9eef8fca 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package test.org.apache.spark.sql.hive; +package org.apache.spark.sql.hive; import java.io.IOException; import java.util.ArrayList; @@ -29,8 +29,10 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; -import org.apache.spark.sql.hive.HiveContext; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { private transient JavaSparkContext sc; @@ -38,7 +40,7 @@ public class JavaDataFrameSuite { DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -50,24 +52,26 @@ public void setUp() throws IOException { hc = TestHive$.MODULE$; sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } - df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df = hc.read().json(sc.parallelize(jsonObjects)); df.registerTempTable("window_table"); } @After public void tearDown() throws IOException { // Clean up tables. - hc.sql("DROP TABLE IF EXISTS window_table"); + if (hc != null) { + hc.sql("DROP TABLE IF EXISTS window_table"); + } } @Test public void saveTableAndQueryIt() { checkAnswer( - df.select(functions.avg("key").over( + df.select(avg("key").over( Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), hc.sql("SELECT avg(key) " + "OVER (PARTITION BY value " + @@ -75,4 +79,26 @@ public void saveTableAndQueryIt() { " ROWS BETWEEN 1 preceding and 1 following) " + "FROM window_table").collectAsList()); } + + @Test + public void testUDAF() { + DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value")); + UserDefinedAggregateFunction udaf = new MyDoubleSum(); + UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf); + // Create Columns for the UDAF. For now, callUDF does not take an argument to specific if + // we want to use distinct aggregation. + DataFrame aggregatedDF = + df.groupBy() + .agg( + udaf.distinct(col("value")), + udaf.apply(col("value")), + registeredUDAF.apply(col("value")), + callUDF("mydoublesum", col("value"))); + + List expectedResult = new ArrayList<>(); + expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); + checkAnswer( + aggregatedDF, + expectedResult); + } } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java similarity index 88% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 64d1ce92931e..c8d272794d10 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package test.org.apache.spark.sql.hive; +package org.apache.spark.sql.hive; import java.io.File; import java.io.IOException; @@ -37,7 +37,6 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; -import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -54,7 +53,7 @@ public class JavaMetastoreDataSourcesSuite { FileSystem fs; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -78,7 +77,7 @@ public void setUp() throws IOException { fs.delete(hiveManagedPath, true); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -90,13 +89,15 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. - sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); - sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + if (sqlContext != null) { + sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } } @Test public void saveExternalTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -119,7 +120,7 @@ public void saveExternalTableAndQueryIt() { @Test public void saveExternalTableWithSchemaAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -131,7 +132,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { sqlContext.sql("SELECT * FROM javaSavedTable"), df.collectAsList()); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = @@ -147,7 +148,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); df.write() .format("org.apache.spark.sql.json") .mode(SaveMode.Append) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java new file mode 100644 index 000000000000..5a167edd8959 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
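// A hedged sketch of the save-then-query flow that the JavaMetastoreDataSourcesSuite
// changes above exercise: write a DataFrame out as JSON with an explicit "path" option,
// register it as a table, and read it back through SQL. The path literal, table name,
// and class name here are illustrative, not taken from the suite.
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SaveMode;

public class ExternalTableSketch {
  public static DataFrame saveAndReload(SQLContext sqlContext, DataFrame df) {
    Map<String, String> options = new HashMap<>();
    options.put("path", "/tmp/javaSavedTable_json");  // illustrative location
    df.write()
      .format("org.apache.spark.sql.json")
      .mode(SaveMode.Append)
      .options(options)
      .saveAsTable("javaSavedTable");
    return sqlContext.sql("SELECT * FROM javaSavedTable");
  }
}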
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a + * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum + * of the average value of input values and 100.0. + */ +public class MyDoubleAvg extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleAvg() { + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + + // The buffer has two values, bufferSum for storing the current sum and + // bufferCount for storing the number of non-null input values that have been contribuetd + // to the current sum. + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); + bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType dataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. + buffer.update(0, null); + // The initial value of the count is 0. + buffer.update(1, 0L); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. + if (!input.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer and set the bufferCount to 1. + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + buffer.update(1, 1L); + } else { + // Otherwise, update the bufferSum and increment bufferCount. + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + buffer.update(1, buffer.getLong(1) + 1L); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's sum value is not null. 
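// A minimal plain-Java sketch of what the MyDoubleAvg buffer logic above computes,
// stripped of the MutableAggregationBuffer machinery: accumulate a nullable sum and a
// count, then return avg + 100.0 (or null if no non-null input was ever seen), as the
// class javadoc describes.
public class SpecialAvgSketch {
  public static void main(String[] args) {
    Double sum = null;   // like bufferSum, starts as null
    long count = 0L;     // like bufferCount, starts at 0
    for (double v : new double[] {1.0, 2.0, 3.0}) {
      sum = (sum == null) ? v : sum + v;
      count += 1L;
    }
    Double result;
    if (sum == null) {
      result = null;                  // no rows contributed
    } else {
      result = sum / count + 100.0;   // the "special average"
    }
    System.out.println(result);  // prints: 102.0
  }
}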
+ if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. + buffer1.update(0, buffer2.getDouble(0)); + buffer1.update(1, buffer2.getLong(1)); + } else { + // Otherwise, we update the bufferSum and bufferCount. + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + // If the bufferSum is still null, we return null because this function has not got + // any input row. + return null; + } else { + // Otherwise, we calculate the special average value. + return buffer.getDouble(0) / buffer.getLong(1) + 100.0; + } + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java new file mode 100644 index 000000000000..c3b7768e71bf --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.Row; + +/** + * An example {@link UserDefinedAggregateFunction} to calculate the sum of a + * {@link org.apache.spark.sql.types.DoubleType} column. 
+ */ +public class MyDoubleSum extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleSum() { + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType dataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. + buffer.update(0, null); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer. + buffer.update(0, input.getDouble(0)); + } else { + // Otherwise, we add the input value to the buffer value. + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's value is not null. + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. + buffer1.update(0, buffer2.getDouble(0)); + } else { + // Otherwise, we add the input buffer's value (buffer1) to the mutable + // buffer's value (buffer2). + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + // If the buffer value is still null, we return null. + return null; + } else { + // Otherwise, the intermediate sum is the final result. 
+ return buffer.getDouble(0); + } + } +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToIntIntMap.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToIntIntMap.java new file mode 100644 index 000000000000..b3e8bcbbd822 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToIntIntMap.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.HashMap; +import java.util.Map; + +public class UDFToIntIntMap extends UDF { + public Map evaluate(Object o) { + return new HashMap() { + { + put(1, 1); + put(2, 1); + put(3, 1); + } + }; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java new file mode 100644 index 000000000000..67576a72f198 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
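// A hedged sketch of how a helper UDF such as UDFToIntIntMap above might be exercised
// through HiveQL. The temporary-function name and the use of the standard "src" test
// table are assumptions for illustration, not taken from this patch.
import org.apache.spark.sql.hive.HiveContext;

public class RegisterHelperUdfSketch {
  public static void run(HiveContext hc) {
    hc.sql("CREATE TEMPORARY FUNCTION testIntIntMap AS "
        + "'org.apache.spark.sql.hive.execution.UDFToIntIntMap'");
    // The UDF ignores its argument and returns a fixed map of int to int.
    hc.sql("SELECT testIntIntMap(key) FROM src LIMIT 1").show();
  }
}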
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.Arrays; +import java.util.List; + +public class UDFToListInt extends UDF { + public List evaluate(Object o) { + return Arrays.asList(1, 2, 3); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java new file mode 100644 index 000000000000..f02395cbba88 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.Arrays; +import java.util.List; + +public class UDFToListString extends UDF { + public List evaluate(Object o) { + return Arrays.asList("data1", "data2", "data3"); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToStringIntMap.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToStringIntMap.java new file mode 100644 index 000000000000..9eea5c9a881f --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToStringIntMap.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.HashMap; +import java.util.Map; + +public class UDFToStringIntMap extends UDF { + public Map evaluate(Object o) { + return new HashMap() { + { + put("key1", 1); + put("key2", 2); + put("key3", 3); + } + }; + } +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java new file mode 100644 index 000000000000..e010112bb932 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -0,0 +1,1139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.test; + +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.hadoop.hive.serde2.thrift.test.IntString; +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.EncodingUtils; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; + +/** + * This is a fork of Hive 0.13's org/apache/hadoop/hive/serde2/thrift/test/Complex.java, which + * does not contain union fields that are not supported by Spark SQL. 
+ */ + +@SuppressWarnings({"ALL", "unchecked"}) +public class Complex implements org.apache.thrift.TBase, java.io.Serializable, Cloneable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Complex"); + + private static final org.apache.thrift.protocol.TField AINT_FIELD_DESC = new org.apache.thrift.protocol.TField("aint", org.apache.thrift.protocol.TType.I32, (short)1); + private static final org.apache.thrift.protocol.TField A_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("aString", org.apache.thrift.protocol.TType.STRING, (short)2); + private static final org.apache.thrift.protocol.TField LINT_FIELD_DESC = new org.apache.thrift.protocol.TField("lint", org.apache.thrift.protocol.TType.LIST, (short)3); + private static final org.apache.thrift.protocol.TField L_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lString", org.apache.thrift.protocol.TType.LIST, (short)4); + private static final org.apache.thrift.protocol.TField LINT_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lintString", org.apache.thrift.protocol.TType.LIST, (short)5); + private static final org.apache.thrift.protocol.TField M_STRING_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("mStringString", org.apache.thrift.protocol.TType.MAP, (short)6); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new ComplexStandardSchemeFactory()); + schemes.put(TupleScheme.class, new ComplexTupleSchemeFactory()); + } + + private int aint; // required + private String aString; // required + private List lint; // required + private List lString; // required + private List lintString; // required + private Map mStringString; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + AINT((short)1, "aint"), + A_STRING((short)2, "aString"), + LINT((short)3, "lint"), + L_STRING((short)4, "lString"), + LINT_STRING((short)5, "lintString"), + M_STRING_STRING((short)6, "mStringString"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // AINT + return AINT; + case 2: // A_STRING + return A_STRING; + case 3: // LINT + return LINT; + case 4: // L_STRING + return L_STRING; + case 5: // LINT_STRING + return LINT_STRING; + case 6: // M_STRING_STRING + return M_STRING_STRING; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. 
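// Small usage note for the generated _Fields helpers above: lookups can be done either
// by thrift field id or by field name. Purely illustrative; the wrapper class is not
// part of the patch.
import org.apache.spark.sql.hive.test.Complex;

public class ComplexFieldsLookupSketch {
  public static void main(String[] args) {
    Complex._Fields byId = Complex._Fields.findByThriftId(3);        // LINT
    Complex._Fields byName = Complex._Fields.findByName("aString");  // A_STRING
    System.out.println(byId + " " + byName);
  }
}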
+ */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + private static final int __AINT_ISSET_ID = 0; + private byte __isset_bitfield = 0; + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.AINT, new org.apache.thrift.meta_data.FieldMetaData("aint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.A_STRING, new org.apache.thrift.meta_data.FieldMetaData("aString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.LINT, new org.apache.thrift.meta_data.FieldMetaData("lint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.L_STRING, new org.apache.thrift.meta_data.FieldMetaData("lString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.LINT_STRING, new org.apache.thrift.meta_data.FieldMetaData("lintString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.StructMetaData(org.apache.thrift.protocol.TType.STRUCT, IntString.class)))); + tmpMap.put(_Fields.M_STRING_STRING, new org.apache.thrift.meta_data.FieldMetaData("mStringString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING), + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Complex.class, metaDataMap); + } + + public Complex() { + } + + public Complex( + int aint, + String aString, + List lint, + List lString, + List lintString, + Map mStringString) + { + this(); + this.aint = aint; + setAintIsSet(true); + this.aString = aString; + this.lint = lint; + this.lString = lString; + this.lintString = lintString; + this.mStringString = mStringString; + } + + /** + * Performs a deep copy on other. 
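// A hedged sketch of constructing the Complex test bean through the all-args constructor
// shown above. The IntString(int, String, int) constructor is assumed from Hive's
// serde2 thrift test classes and is not defined in this patch; values are illustrative.
import java.util.Arrays;
import java.util.Collections;

import org.apache.hadoop.hive.serde2.thrift.test.IntString;
import org.apache.spark.sql.hive.test.Complex;

public class ComplexSampleSketch {
  public static Complex sample() {
    return new Complex(
        1,                                        // aint
        "a",                                      // aString
        Arrays.asList(1, 2, 3),                   // lint
        Arrays.asList("x", "y"),                  // lString
        Arrays.asList(new IntString(1, "1", 1)),  // lintString (assumed IntString ctor)
        Collections.singletonMap("k", "v"));      // mStringString
  }
}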
+ */ + public Complex(Complex other) { + __isset_bitfield = other.__isset_bitfield; + this.aint = other.aint; + if (other.isSetAString()) { + this.aString = other.aString; + } + if (other.isSetLint()) { + List __this__lint = new ArrayList(); + for (Integer other_element : other.lint) { + __this__lint.add(other_element); + } + this.lint = __this__lint; + } + if (other.isSetLString()) { + List __this__lString = new ArrayList(); + for (String other_element : other.lString) { + __this__lString.add(other_element); + } + this.lString = __this__lString; + } + if (other.isSetLintString()) { + List __this__lintString = new ArrayList(); + for (IntString other_element : other.lintString) { + __this__lintString.add(new IntString(other_element)); + } + this.lintString = __this__lintString; + } + if (other.isSetMStringString()) { + Map __this__mStringString = new HashMap(); + for (Map.Entry other_element : other.mStringString.entrySet()) { + + String other_element_key = other_element.getKey(); + String other_element_value = other_element.getValue(); + + String __this__mStringString_copy_key = other_element_key; + + String __this__mStringString_copy_value = other_element_value; + + __this__mStringString.put(__this__mStringString_copy_key, __this__mStringString_copy_value); + } + this.mStringString = __this__mStringString; + } + } + + public Complex deepCopy() { + return new Complex(this); + } + + @Override + public void clear() { + setAintIsSet(false); + this.aint = 0; + this.aString = null; + this.lint = null; + this.lString = null; + this.lintString = null; + this.mStringString = null; + } + + public int getAint() { + return this.aint; + } + + public void setAint(int aint) { + this.aint = aint; + setAintIsSet(true); + } + + public void unsetAint() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __AINT_ISSET_ID); + } + + /** Returns true if field aint is set (has been assigned a value) and false otherwise */ + public boolean isSetAint() { + return EncodingUtils.testBit(__isset_bitfield, __AINT_ISSET_ID); + } + + public void setAintIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __AINT_ISSET_ID, value); + } + + public String getAString() { + return this.aString; + } + + public void setAString(String aString) { + this.aString = aString; + } + + public void unsetAString() { + this.aString = null; + } + + /** Returns true if field aString is set (has been assigned a value) and false otherwise */ + public boolean isSetAString() { + return this.aString != null; + } + + public void setAStringIsSet(boolean value) { + if (!value) { + this.aString = null; + } + } + + public int getLintSize() { + return (this.lint == null) ? 0 : this.lint.size(); + } + + public java.util.Iterator getLintIterator() { + return (this.lint == null) ? null : this.lint.iterator(); + } + + public void addToLint(int elem) { + if (this.lint == null) { + this.lint = new ArrayList<>(); + } + this.lint.add(elem); + } + + public List getLint() { + return this.lint; + } + + public void setLint(List lint) { + this.lint = lint; + } + + public void unsetLint() { + this.lint = null; + } + + /** Returns true if field lint is set (has been assigned a value) and false otherwise */ + public boolean isSetLint() { + return this.lint != null; + } + + public void setLintIsSet(boolean value) { + if (!value) { + this.lint = null; + } + } + + public int getLStringSize() { + return (this.lString == null) ? 
0 : this.lString.size(); + } + + public java.util.Iterator getLStringIterator() { + return (this.lString == null) ? null : this.lString.iterator(); + } + + public void addToLString(String elem) { + if (this.lString == null) { + this.lString = new ArrayList(); + } + this.lString.add(elem); + } + + public List getLString() { + return this.lString; + } + + public void setLString(List lString) { + this.lString = lString; + } + + public void unsetLString() { + this.lString = null; + } + + /** Returns true if field lString is set (has been assigned a value) and false otherwise */ + public boolean isSetLString() { + return this.lString != null; + } + + public void setLStringIsSet(boolean value) { + if (!value) { + this.lString = null; + } + } + + public int getLintStringSize() { + return (this.lintString == null) ? 0 : this.lintString.size(); + } + + public java.util.Iterator getLintStringIterator() { + return (this.lintString == null) ? null : this.lintString.iterator(); + } + + public void addToLintString(IntString elem) { + if (this.lintString == null) { + this.lintString = new ArrayList<>(); + } + this.lintString.add(elem); + } + + public List getLintString() { + return this.lintString; + } + + public void setLintString(List lintString) { + this.lintString = lintString; + } + + public void unsetLintString() { + this.lintString = null; + } + + /** Returns true if field lintString is set (has been assigned a value) and false otherwise */ + public boolean isSetLintString() { + return this.lintString != null; + } + + public void setLintStringIsSet(boolean value) { + if (!value) { + this.lintString = null; + } + } + + public int getMStringStringSize() { + return (this.mStringString == null) ? 0 : this.mStringString.size(); + } + + public void putToMStringString(String key, String val) { + if (this.mStringString == null) { + this.mStringString = new HashMap(); + } + this.mStringString.put(key, val); + } + + public Map getMStringString() { + return this.mStringString; + } + + public void setMStringString(Map mStringString) { + this.mStringString = mStringString; + } + + public void unsetMStringString() { + this.mStringString = null; + } + + /** Returns true if field mStringString is set (has been assigned a value) and false otherwise */ + public boolean isSetMStringString() { + return this.mStringString != null; + } + + public void setMStringStringIsSet(boolean value) { + if (!value) { + this.mStringString = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case AINT: + if (value == null) { + unsetAint(); + } else { + setAint((Integer)value); + } + break; + + case A_STRING: + if (value == null) { + unsetAString(); + } else { + setAString((String)value); + } + break; + + case LINT: + if (value == null) { + unsetLint(); + } else { + setLint((List)value); + } + break; + + case L_STRING: + if (value == null) { + unsetLString(); + } else { + setLString((List)value); + } + break; + + case LINT_STRING: + if (value == null) { + unsetLintString(); + } else { + setLintString((List)value); + } + break; + + case M_STRING_STRING: + if (value == null) { + unsetMStringString(); + } else { + setMStringString((Map)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case AINT: + return Integer.valueOf(getAint()); + + case A_STRING: + return getAString(); + + case LINT: + return getLint(); + + case L_STRING: + return getLString(); + + case LINT_STRING: + return getLintString(); + + case M_STRING_STRING: + return 
getMStringString(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case AINT: + return isSetAint(); + case A_STRING: + return isSetAString(); + case LINT: + return isSetLint(); + case L_STRING: + return isSetLString(); + case LINT_STRING: + return isSetLintString(); + case M_STRING_STRING: + return isSetMStringString(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof Complex) + return this.equals((Complex)that); + return false; + } + + public boolean equals(Complex that) { + if (that == null) + return false; + + boolean this_present_aint = true; + boolean that_present_aint = true; + if (this_present_aint || that_present_aint) { + if (!(this_present_aint && that_present_aint)) + return false; + if (this.aint != that.aint) + return false; + } + + boolean this_present_aString = true && this.isSetAString(); + boolean that_present_aString = true && that.isSetAString(); + if (this_present_aString || that_present_aString) { + if (!(this_present_aString && that_present_aString)) + return false; + if (!this.aString.equals(that.aString)) + return false; + } + + boolean this_present_lint = true && this.isSetLint(); + boolean that_present_lint = true && that.isSetLint(); + if (this_present_lint || that_present_lint) { + if (!(this_present_lint && that_present_lint)) + return false; + if (!this.lint.equals(that.lint)) + return false; + } + + boolean this_present_lString = true && this.isSetLString(); + boolean that_present_lString = true && that.isSetLString(); + if (this_present_lString || that_present_lString) { + if (!(this_present_lString && that_present_lString)) + return false; + if (!this.lString.equals(that.lString)) + return false; + } + + boolean this_present_lintString = true && this.isSetLintString(); + boolean that_present_lintString = true && that.isSetLintString(); + if (this_present_lintString || that_present_lintString) { + if (!(this_present_lintString && that_present_lintString)) + return false; + if (!this.lintString.equals(that.lintString)) + return false; + } + + boolean this_present_mStringString = true && this.isSetMStringString(); + boolean that_present_mStringString = true && that.isSetMStringString(); + if (this_present_mStringString || that_present_mStringString) { + if (!(this_present_mStringString && that_present_mStringString)) + return false; + if (!this.mStringString.equals(that.mStringString)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + HashCodeBuilder builder = new HashCodeBuilder(); + + boolean present_aint = true; + builder.append(present_aint); + if (present_aint) + builder.append(aint); + + boolean present_aString = true && (isSetAString()); + builder.append(present_aString); + if (present_aString) + builder.append(aString); + + boolean present_lint = true && (isSetLint()); + builder.append(present_lint); + if (present_lint) + builder.append(lint); + + boolean present_lString = true && (isSetLString()); + builder.append(present_lString); + if (present_lString) + builder.append(lString); + + boolean present_lintString = true && (isSetLintString()); + builder.append(present_lintString); + if (present_lintString) + builder.append(lintString); + + boolean present_mStringString = true 
&& (isSetMStringString()); + builder.append(present_mStringString); + if (present_mStringString) + builder.append(mStringString); + + return builder.toHashCode(); + } + + public int compareTo(Complex other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + Complex typedOther = (Complex)other; + + lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aint, typedOther.aint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetAString()).compareTo(typedOther.isSetAString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aString, typedOther.aString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLint()).compareTo(typedOther.isSetLint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lint, typedOther.lint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLString()).compareTo(typedOther.isSetLString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lString, typedOther.lString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLintString()).compareTo(typedOther.isSetLintString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLintString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lintString, typedOther.lintString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMStringString()).compareTo(typedOther.isSetMStringString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMStringString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.mStringString, typedOther.mStringString); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Complex("); + boolean first = true; + + sb.append("aint:"); + sb.append(this.aint); + first = false; + if (!first) sb.append(", "); + sb.append("aString:"); + if (this.aString == null) { + sb.append("null"); + } else { + sb.append(this.aString); + } + first = false; + if (!first) sb.append(", "); + sb.append("lint:"); + if (this.lint == null) { + sb.append("null"); + } else { + sb.append(this.lint); + } + first = false; + if (!first) sb.append(", "); + sb.append("lString:"); + if (this.lString == null) { + sb.append("null"); + } else { + sb.append(this.lString); + } + first = false; + if (!first) sb.append(", "); + 
sb.append("lintString:"); + if (this.lintString == null) { + sb.append("null"); + } else { + sb.append(this.lintString); + } + first = false; + if (!first) sb.append(", "); + sb.append("mStringString:"); + if (this.mStringString == null) { + sb.append("null"); + } else { + sb.append(this.mStringString); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. + __isset_bitfield = 0; + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class ComplexStandardSchemeFactory implements SchemeFactory { + public ComplexStandardScheme getScheme() { + return new ComplexStandardScheme(); + } + } + + private static class ComplexStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // AINT + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // A_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 3: // LINT + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); + struct.lint = new ArrayList(_list0.size); + for (int _i1 = 0; _i1 < _list0.size; ++_i1) + { + int _elem2; // required + _elem2 = iprot.readI32(); + struct.lint.add(_elem2); + } + iprot.readListEnd(); + } + struct.setLintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 4: // L_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list3 = iprot.readListBegin(); + struct.lString = new ArrayList(_list3.size); + for (int _i4 = 0; _i4 < _list3.size; ++_i4) + { + String _elem5; // required + _elem5 = iprot.readString(); + struct.lString.add(_elem5); + } + iprot.readListEnd(); + } + struct.setLStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 5: // LINT_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list6 = 
iprot.readListBegin(); + struct.lintString = new ArrayList(_list6.size); + for (int _i7 = 0; _i7 < _list6.size; ++_i7) + { + IntString _elem8; // required + _elem8 = new IntString(); + _elem8.read(iprot); + struct.lintString.add(_elem8); + } + iprot.readListEnd(); + } + struct.setLintStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 6: // M_STRING_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map9 = iprot.readMapBegin(); + struct.mStringString = new HashMap(2*_map9.size); + for (int _i10 = 0; _i10 < _map9.size; ++_i10) + { + String _key11; // required + String _val12; // required + _key11 = iprot.readString(); + _val12 = iprot.readString(); + struct.mStringString.put(_key11, _val12); + } + iprot.readMapEnd(); + } + struct.setMStringStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, Complex struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + oprot.writeFieldBegin(AINT_FIELD_DESC); + oprot.writeI32(struct.aint); + oprot.writeFieldEnd(); + if (struct.aString != null) { + oprot.writeFieldBegin(A_STRING_FIELD_DESC); + oprot.writeString(struct.aString); + oprot.writeFieldEnd(); + } + if (struct.lint != null) { + oprot.writeFieldBegin(LINT_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, struct.lint.size())); + for (int _iter13 : struct.lint) + { + oprot.writeI32(_iter13); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lString != null) { + oprot.writeFieldBegin(L_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.lString.size())); + for (String _iter14 : struct.lString) + { + oprot.writeString(_iter14); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lintString != null) { + oprot.writeFieldBegin(LINT_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, struct.lintString.size())); + for (IntString _iter15 : struct.lintString) + { + _iter15.write(oprot); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.mStringString != null) { + oprot.writeFieldBegin(M_STRING_STRING_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, struct.mStringString.size())); + for (Map.Entry _iter16 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter16.getKey()); + oprot.writeString(_iter16.getValue()); + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class ComplexTupleSchemeFactory implements SchemeFactory { + public ComplexTupleScheme getScheme() { + return new ComplexTupleScheme(); + } + } + + private static class ComplexTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + BitSet 
optionals = new BitSet(); + if (struct.isSetAint()) { + optionals.set(0); + } + if (struct.isSetAString()) { + optionals.set(1); + } + if (struct.isSetLint()) { + optionals.set(2); + } + if (struct.isSetLString()) { + optionals.set(3); + } + if (struct.isSetLintString()) { + optionals.set(4); + } + if (struct.isSetMStringString()) { + optionals.set(5); + } + oprot.writeBitSet(optionals, 6); + if (struct.isSetAint()) { + oprot.writeI32(struct.aint); + } + if (struct.isSetAString()) { + oprot.writeString(struct.aString); + } + if (struct.isSetLint()) { + { + oprot.writeI32(struct.lint.size()); + for (int _iter17 : struct.lint) + { + oprot.writeI32(_iter17); + } + } + } + if (struct.isSetLString()) { + { + oprot.writeI32(struct.lString.size()); + for (String _iter18 : struct.lString) + { + oprot.writeString(_iter18); + } + } + } + if (struct.isSetLintString()) { + { + oprot.writeI32(struct.lintString.size()); + for (IntString _iter19 : struct.lintString) + { + _iter19.write(oprot); + } + } + } + if (struct.isSetMStringString()) { + { + oprot.writeI32(struct.mStringString.size()); + for (Map.Entry _iter20 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter20.getKey()); + oprot.writeString(_iter20.getValue()); + } + } + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + BitSet incoming = iprot.readBitSet(6); + if (incoming.get(0)) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } + if (incoming.get(1)) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } + if (incoming.get(2)) { + { + org.apache.thrift.protocol.TList _list21 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.lint = new ArrayList(_list21.size); + for (int _i22 = 0; _i22 < _list21.size; ++_i22) + { + int _elem23; // required + _elem23 = iprot.readI32(); + struct.lint.add(_elem23); + } + } + struct.setLintIsSet(true); + } + if (incoming.get(3)) { + { + org.apache.thrift.protocol.TList _list24 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.lString = new ArrayList(_list24.size); + for (int _i25 = 0; _i25 < _list24.size; ++_i25) + { + String _elem26; // required + _elem26 = iprot.readString(); + struct.lString.add(_elem26); + } + } + struct.setLStringIsSet(true); + } + if (incoming.get(4)) { + { + org.apache.thrift.protocol.TList _list27 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); + struct.lintString = new ArrayList(_list27.size); + for (int _i28 = 0; _i28 < _list27.size; ++_i28) + { + IntString _elem29; // required + _elem29 = new IntString(); + _elem29.read(iprot); + struct.lintString.add(_elem29); + } + } + struct.setLintStringIsSet(true); + } + if (incoming.get(5)) { + { + org.apache.thrift.protocol.TMap _map30 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.mStringString = new HashMap(2*_map30.size); + for (int _i31 = 0; _i31 < _map30.size; ++_i31) + { + String _key32; // required + String _val33; // required + _key32 = iprot.readString(); + _val33 = iprot.readString(); + struct.mStringString.put(_key32, _val33); + } + } + struct.setMStringStringIsSet(true); + } + } + } + +} + diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 
b/sql/hive/src/test/resources/data/files/testUDF/part-00000 similarity index 100% rename from sql/hive/src/test/resources/data/files/testUdf/part-00000 rename to sql/hive/src/test/resources/data/files/testUDF/part-00000 diff --git a/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 new file mode 100644 index 000000000000..9a276bc794c0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 @@ -0,0 +1,3 @@ +476 +172 +622 diff --git a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d deleted file mode 100644 index 98da82fa8938..000000000000 --- a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d +++ /dev/null @@ -1 +0,0 @@ -1970-01-01 1970-01-01 1969-12-31 16:00:00 1969-12-31 16:00:00 1970-01-01 00:00:00 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 deleted file mode 100644 index 27ba77ddaf61..000000000000 --- a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 +++ /dev/null @@ -1 +0,0 @@ -true diff --git a/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f new file mode 100644 index 000000000000..444039e75fba --- /dev/null +++ b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f @@ -0,0 +1,500 @@ +476 +172 +622 +54 +330 +818 +510 +556 +196 +968 +530 +386 +802 +300 +546 +448 +738 +132 +256 +426 +292 +812 +858 +748 +304 +938 +290 +990 +74 +654 +562 +554 +418 +30 +164 +806 +332 +834 +860 +504 +584 +438 +574 +306 +386 +676 +892 +918 +788 +474 +964 +348 +826 +988 +414 +398 +932 +416 +348 +798 +792 +494 +834 +978 +324 +754 +794 +618 +730 +532 +878 +684 +734 +650 +334 +390 +950 +34 +226 +310 +406 +678 +0 +910 +256 +622 +632 +114 +604 +410 +298 +876 +690 +258 +340 +40 +978 +314 +756 +442 +184 +222 +94 +144 +8 +560 +70 +854 +554 +416 +712 +798 +338 +764 +996 +250 +772 +874 +938 +384 +572 +374 +352 +108 +918 +102 +276 +206 +478 +426 +432 +860 +556 +352 +578 +442 +130 +636 +664 +622 +550 +274 +482 +166 +666 +360 +568 +24 +460 +362 +134 +520 +808 +768 +978 +706 +746 +544 +276 +434 +168 +696 +932 +116 +16 +822 +460 +416 +696 +48 +926 +862 +358 +344 +84 +258 +316 +238 +992 +0 +644 +394 +936 +786 +908 +200 +596 +398 +382 +836 +192 +52 +330 +654 +460 +410 +240 +262 +102 +808 +86 +872 +312 +938 +936 +616 +190 +392 +576 +962 +914 +196 +564 +394 +374 +636 +636 +818 +940 +274 +738 +632 +338 +826 +170 
+154 +0 +980 +174 +728 +358 +236 +268 +790 +564 +276 +476 +838 +30 +236 +144 +180 +614 +38 +870 +20 +554 +546 +612 +448 +618 +778 +654 +484 +738 +784 +544 +662 +802 +484 +904 +354 +452 +10 +994 +804 +792 +634 +790 +116 +70 +672 +190 +22 +336 +68 +458 +466 +286 +944 +644 +996 +320 +390 +84 +642 +860 +238 +978 +916 +156 +152 +82 +446 +984 +298 +898 +436 +456 +276 +906 +60 +418 +128 +936 +152 +148 +684 +138 +460 +66 +736 +206 +592 +226 +432 +734 +688 +334 +548 +438 +478 +970 +232 +446 +512 +526 +140 +974 +960 +802 +576 +382 +10 +488 +876 +256 +934 +864 +404 +632 +458 +938 +926 +560 +4 +70 +566 +662 +470 +160 +88 +386 +642 +670 +208 +932 +732 +350 +806 +966 +106 +210 +514 +812 +818 +380 +812 +802 +228 +516 +180 +406 +524 +696 +848 +24 +792 +402 +434 +328 +862 +908 +956 +596 +250 +862 +328 +848 +374 +764 +10 +140 +794 +960 +582 +48 +702 +510 +208 +140 +326 +876 +238 +828 +400 +982 +474 +878 +720 +496 +958 +610 +834 +398 +888 +240 +858 +338 +886 +646 +650 +554 +460 +956 +356 +936 +620 +634 +666 +986 +920 +414 +498 +530 +960 +166 +272 +706 +344 +428 +924 +466 +812 +266 +350 +378 +908 +750 +802 +842 +814 +768 +512 +52 +268 +134 +768 +758 +36 +924 +984 +200 +596 +18 +682 +996 +292 +916 +724 +372 +570 +696 +334 +36 +546 +366 +562 +688 +194 +938 +630 +168 +56 +74 +896 +304 +696 +614 +388 +828 +954 +444 +252 +180 +338 +806 +800 +400 +194 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 new file mode 100644 index 000000000000..dac1b84b916d --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 new file mode 100644 index 000000000000..c7cb747c0a65 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c new file mode 100644 index 000000000000..c7cb747c0a65 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a new file mode 100644 index 000000000000..dac1b84b916d --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 
b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 new file mode 100644 index 000000000000..1eea4a9b2368 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce new file mode 100644 index 000000000000..1eea4a9b2368 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad similarity index 100% rename from sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 rename to sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad diff --git a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 index d35bf9093ca9..2383bef94097 100644 --- a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 +++ b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 @@ -15,9 +15,9 @@ my_enum_structlist_map map from deserializer my_structlist array>> from deserializer my_enumlist array from deserializer -my_stringset struct<> from deserializer -my_enumset struct<> from deserializer -my_structset struct<> from deserializer +my_stringset array from deserializer +my_enumset array from deserializer +my_structset array>> from deserializer optionals struct<> from deserializer b string diff --git a/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f b/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f new file mode 100644 index 000000000000..1dcda4315a14 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #1-0-f01b340b5662c45bb5f1e3b7c6900e1f @@ -0,0 +1 @@ +{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. 
Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}},"email":"amy@only_for_json_udf_test.net","owner":"amy","zip code":"94025","fb:testid":"1234"} diff --git a/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc b/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc new file mode 100644 index 000000000000..81c545efebe5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #10-0-f3f47d06d7c51d493d68112b0bd6c1fc @@ -0,0 +1 @@ +1234 diff --git a/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a b/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a new file mode 100644 index 000000000000..99127db9e311 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #2-0-e84c2f8136919830fd665a278e4158a @@ -0,0 +1 @@ +amy {"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}} diff --git a/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb b/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb new file mode 100644 index 000000000000..0bc03998296a --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #3-0-bf140c65c31f8d892ec23e41e16e58bb @@ -0,0 +1 @@ +{"price":19.95,"color":"red"} [{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 b/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 new file mode 100644 index 000000000000..4f7e09bd3fa7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #4-0-f0bd902edc1990c9a6c65a6bb672c4d5 @@ -0,0 +1 @@ +{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95} [{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference","price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. 
Tolkien","title":"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab b/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab new file mode 100644 index 000000000000..b2d212a597d9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #5-0-3c09f4316a1533049aee8af749cdcab @@ -0,0 +1 @@ +reference ["reference","fiction","fiction"] ["0-553-21311-3","0-395-19395-8"] [{"age":25,"name":"bob"},{"age":26,"name":"jack"}] diff --git a/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da b/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da new file mode 100644 index 000000000000..21d88629fcdb --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #6-0-8334d1ddbe0f41fc7b80d4e6b45409da @@ -0,0 +1 @@ +25 [25,26] diff --git a/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 b/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 new file mode 100644 index 000000000000..e60721e1dd24 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #7-0-40d7dff94b26a2e3f4ab71baee3d3ce0 @@ -0,0 +1 @@ +2 [[1,2,{"b":"y","a":"x"}],[3,4],[5,6]] 1 [1,2,{"b":"y","a":"x"}] [1,2,{"b":"y","a":"x"},3,4,5,6] y ["y"] diff --git a/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 b/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 new file mode 100644 index 000000000000..356fcdf7139b --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #8-0-180b4b6fdb26011fec05a7ca99fd9844 @@ -0,0 +1 @@ +NULL NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b b/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b new file mode 100644 index 000000000000..ef4a39675ed6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/get_json_object #9-0-47c451a969d856f008f4d6b3d378d94b @@ -0,0 +1 @@ +94025 diff --git a/sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 b/sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 similarity index 100% rename from sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 rename to sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 index 501bb6ab32f2..7bb2c0ab4398 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` smallint, `value` float) COMMENT 'temporary table' diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 index 90f8415a1c6b..3cc1a57ee3a4 100644 --- 
a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_feng.tmp_showcrt`( +CREATE TABLE `tmp_feng.tmp_showcrt`( `key` string, `value` int) ROW FORMAT SERDE diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 index 4ee22e523031..b51c71a71f91 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 index 6fda2570b53f..29189e1d860a 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 index 3049cd6243ad..1b283db3e774 100644 --- a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 +++ b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 @@ -17,6 +17,7 @@ ^ abs acos +add_months and array array_contains @@ -29,6 +30,7 @@ base64 between bin case +cbrt ceil ceiling coalesce @@ -47,7 +49,11 @@ covar_samp create_union cume_dist current_database +current_date +current_timestamp +current_user date_add +date_format date_sub datediff day @@ -65,6 +71,7 @@ ewah_bitmap_empty ewah_bitmap_or exp explode +factorial field find_in_set first_value @@ -73,6 +80,7 @@ format_number from_unixtime from_utc_timestamp get_json_object +greatest hash hex histogram_numeric @@ -81,6 +89,7 @@ if in in_file index +initcap inline instr isnotnull @@ -88,10 +97,13 @@ isnull java_method json_tuple lag +last_day last_value lcase lead +least length +levenshtein like ln locate @@ -109,11 +121,15 @@ max min minute month +months_between named_struct negative +next_day ngrams noop +noopstreaming noopwithmap +noopwithmapstreaming not ntile nvl @@ -147,10 +163,14 @@ rpad rtrim second sentences +shiftleft +shiftright +shiftrightunsigned sign sin size sort_array +soundex space split sqrt @@ -170,6 +190,7 @@ to_unix_timestamp to_utc_timestamp translate trim +trunc ucase unbase64 unhex diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae index 0f6cc6f44f1f..fdf701f96280 100644 --- a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae +++ b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae @@ -1 +1 @@ -Table tmpfoo does not have property: bar +Table default.tmpfoo does not have property: bar diff --git 
a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 deleted file mode 100644 index c6f275a0db13..000000000000 --- a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 +++ /dev/null @@ -1 +0,0 @@ -0.0 NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 index 3c91e138d7bd..d8ec084f0b2b 100644 --- a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 +++ b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 @@ -1,5 +1,5 @@ date_add(start_date, num_days) - Returns the date that is num_days after start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_add('2009-30-07', 1) FROM src LIMIT 1; - '2009-31-07' + > SELECT date_add('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-31' diff --git a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 index 29d663f35c58..169c50003625 100644 --- a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 +++ b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 @@ -1,5 +1,5 @@ date_sub(start_date, num_days) - Returns the date that is num_days before start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_sub('2009-30-07', 1) FROM src LIMIT 1; - '2009-29-07' + > SELECT date_sub('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-29' diff --git a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 index 7ccaee7ad3bd..42197f7ad3e5 100644 --- a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 +++ b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 @@ -1,5 +1,5 @@ datediff(date1, date2) - Returns the number of days between date1 and date2 date1 and date2 are strings in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. The time parts are ignored.If date1 is earlier than date2, the result is negative. 
Example: - > SELECT datediff('2009-30-07', '2009-31-07') FROM src LIMIT 1; + > SELECT datediff('2009-07-30', '2009-07-31') FROM src LIMIT 1; 1 diff --git a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 index d4017178b4e6..09703d10eab7 100644 --- a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 +++ b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 @@ -1 +1 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 index 6135aafa5086..7c0ec1dc3be5 100644 --- a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 +++ b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 @@ -1,6 +1,9 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: dayofmonth -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT day('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. A day-time interval valueExample: + > SELECT day('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 index 47a7018d9d5a..c37eb0ec2e96 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 @@ -1 +1 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 index d9490e20a3b6..9e931f649914 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 @@ -1,6 +1,9 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: day -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT dayofmonth('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. 
A day-time interval valueExample: + > SELECT dayofmonth('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. 
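
As a minimal illustration of the IF() semantics documented in the udf_if golden files above — a hypothetical sketch, not taken from this patch, assuming the Spark 1.x HiveContext API used by the surrounding test suites:

```
// Hypothetical sketch: exercising Hive's IF(expr1, expr2, expr3) through a HiveContext,
// matching the semantics described in the udf_if golden files above.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object IfUdfSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("if-udf-sketch").setMaster("local"))
    val hc = new HiveContext(sc)
    // The first predicate is TRUE, so IF() returns expr2 ("yes");
    // the second is FALSE, so IF() returns expr3 ("no").
    hc.sql("SELECT IF(1 = 1, 'yes', 'no'), IF(1 = 2, 'yes', 'no')").show()
    sc.stop()
  }
}
```
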
diff --git a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. 
A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 index d54ebfbd6fb1..a529b107ff21 100644 --- a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 +++ b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 @@ -1,2 +1,2 @@ std(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, stddev +Synonyms: stddev, stddev_pop diff --git a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d index 5f674788180e..ac3176a38254 100644 --- a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d +++ b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d @@ -1,2 +1,2 @@ stddev(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, std +Synonyms: std, stddev_pop diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 new file mode 100644 index 000000000000..44b2a42cc26c --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 @@ -0,0 +1 @@ +unhex(str) - Converts hexadecimal argument to binary diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 new file mode 100644 index 000000000000..97af3b812a42 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 @@ -0,0 +1,14 @@ +unhex(str) - Converts hexadecimal argument to binary +Performs the inverse operation of HEX(str). That is, it interprets +each pair of hexadecimal digits in the argument as a number and +converts it to the byte representation of the number. The +resulting characters are returned as a binary string. + +Example: +> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1; +'MySQL' + +The characters in the argument string must be legal hexadecimal +digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters +any nonhexadecimal digits in the argument, it returns NULL. Also, +if there are an odd number of characters a leading 0 is appended. 
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e new file mode 100644 index 000000000000..b4a6f2b69222 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e @@ -0,0 +1 @@ +MySQL 1267 a -4 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 new file mode 100644 index 000000000000..3a67adaf0a9a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 @@ -0,0 +1 @@ +NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d b/sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c similarity index 100% rename from sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d rename to sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c diff --git a/sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 b/sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 rename to sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 diff --git a/sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 b/sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 rename to sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 new file mode 100644 index 000000000000..7e5fceeddeee --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 20. 
testSTATs-0-6dfcd7925fb267699c4bf82737d4609 @@ -0,0 +1,97 @@ +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 2 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 6 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 34 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 2 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 34 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 2 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 6 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 28 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 34 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 2 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 6 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 28 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 34 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 42 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 6 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 28 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 34 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 42 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 6 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 28 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 42 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 2 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 14 20231.169866666663 -0.49369526554523185 -1113.7466666666658 
+Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 40 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 2 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 14 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 25 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 40 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 2 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 14 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 18 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 25 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 40 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 2 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 18 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 25 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 40 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 2 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 18 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 25 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 14 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 17 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 19 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 1 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 14 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 17 75702.81305 -0.6720833036576083 
-1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 19 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 1 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 14 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 17 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 19 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 45 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 1 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 14 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 19 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 45 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 1 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 19 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 45 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 10 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 27 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 39 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 7 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 10 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 27 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 39 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 7 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 10 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 12 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 27 54802.817784000035 -0.6046935574240581 
-1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 39 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 7 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 12 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 27 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 39 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 7 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 12 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 27 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 2 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 6 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 31 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 2 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 6 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 31 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 46 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 2 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 6 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 23 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 31 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 46 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 2 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 6 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 23 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 
285.43749038756283 285.43749038756283 46 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 2 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 23 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 46 99807.08486666664 -0.9978877469246936 -5664.856666666666 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 deleted file mode 100644 index 1f7e8a5d6703..000000000000 --- a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 +++ /dev/null @@ -1,26 +0,0 @@ -Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 [34,2] 74912.8826888888 1.0 4128.782222222221 -Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 [34,2,6] 66619.10876874991 0.811328754177887 2801.7074999999995 -Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 [34,2,6,28] 53315.51002399992 0.695639377397664 2210.7864 -Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 [34,2,6,42,28] 41099.896184 0.630785977101214 2009.9536000000007 -Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 [34,6,42,28] 14788.129118750014 0.2036684720435979 331.1337500000004 -Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 [6,42,28] 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 -Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 [2,40,14] 20231.169866666663 -0.49369526554523185 -1113.7466666666658 -Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 [2,25,40,14] 18978.662075 -0.5205630897335946 -1004.4812499999995 -Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 [2,18,25,40,14] 16910.329504000005 -0.46908967495720255 -766.1791999999995 -Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 [2,18,25,40] 18374.07627499999 -0.6091405874714462 -1128.1787499999987 -Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 [2,18,25] 24473.534488888927 -0.9571686373491608 -1441.4466666666676 -Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 [17,19,14] 38720.09628888887 0.5557168646224995 224.6944444444446 -Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 [17,1,19,14] 75702.81305 -0.6720833036576083 -1296.9000000000003 -Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 [17,1,19,14,45] 67722.117896 -0.5703526513979519 -2129.0664 -Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 [1,19,14,45] 76128.53331875012 -0.577476899644802 -2547.7868749999993 -Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 [1,19,45] 67902.76602222225 -0.8710736366736884 -4099.731111111111 -Manufacturer#4 almond 
antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 [39,27,10] 28944.25735555559 -0.6656975320098423 -1347.4777777777779 -Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 [39,7,27,10] 58693.95151875002 -0.8051852719193339 -2537.328125 -Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 [39,7,27,10,12] 54802.817784000035 -0.6046935574240581 -1719.8079999999995 -Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 [39,7,27,12] 61174.24181875003 -0.5508665654707869 -1719.0368749999975 -Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 [7,27,12] 80278.40095555557 -0.7755740084632333 -1867.4888888888881 -Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 [2,6,31] 7005.487488888913 0.39004303087285047 418.9233333333353 -Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 [2,6,46,31] 100286.53662500004 -0.713612911776183 -4090.853749999999 -Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 [2,23,6,46,31] 81456.04997600002 -0.712858514567818 -3297.2011999999986 -Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 [2,23,6,46] 81474.56091875004 -0.984128787153391 -4871.028125000002 -Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 [2,23,46] 99807.08486666664 -0.9978877469246936 -5664.856666666666 diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties index 92eaf1f2795b..fea3404769d9 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -48,9 +48,14 @@ log4j.logger.hive.log=OFF log4j.additivity.parquet.hadoop.ParquetRecordReader=false log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF + log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR - diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q index 9e036c1a91d3..e911fbf2d2c5 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q @@ -5,6 +5,6 @@ SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY key, value)t ORDER BY ke SELECT key, value FROM src CLUSTER BY (key, value); -SELECT key, value FROM src ORDER BY (key ASC, value ASC); +SELECT key, value FROM src ORDER BY key ASC, value ASC; SELECT key, value FROM src SORT BY (key, value); SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY (key, value))t ORDER BY key, value; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q index b26a2e2799f7..a989800cbf85 100644 --- 
a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q @@ -1,42 +1,41 @@ +-- SORT_QUERY_RESULTS explain SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; CREATE TABLE union_out (id int); -insert overwrite table union_out +insert overwrite table union_out SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; -select * from union_out cluster by id; +select * from union_out; diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index e1715177e3f1..2590040f2ec1 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.hive.HiveContext */ object Main { def main(args: Array[String]) { + // scalastyle:off println println("Running regression test for SPARK-8489.") val sc = new SparkContext("local", "testing") val hc = new HiveContext(sc) @@ -38,6 +39,8 @@ object Main { val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() println("Regression test for SPARK-8489 success!") + // scalastyle:on println + sc.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar index 4f59fba9eab5..5944aa6076a5 100644 Binary files a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 39d315aaeab5..9adb3780a2c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest { +class CachedTableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan @@ -95,18 +95,18 @@ class CachedTableSuite extends QueryTest { test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestHive.uncacheTable("src") + hiveContext.uncacheTable("src") } } test("'CACHE 
TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.sql("CACHE TABLE src") + sql("CACHE TABLE src") assertCached(table("src")) - assert(TestHive.isCached("src"), "Table 'src' should be cached") + assert(hiveContext.isCached("src"), "Table 'src' should be cached") - TestHive.sql("UNCACHE TABLE src") + sql("UNCACHE TABLE src") assertCached(table("src"), 0) - assert(!TestHive.isCached("src"), "Table 'src' should not be cached") + assert(!hiveContext.isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala new file mode 100644 index 000000000000..34b2edb44b03 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.net.URL + +import org.apache.spark.SparkFunSuite + +/** + * Verify that some classes load and that others are not found on the classpath. + * + * + * This is used to detect classpath and shading conflict, especially between + * Spark's required Kryo version and that which can be found in some Hive versions. + */ +class ClasspathDependenciesSuite extends SparkFunSuite { + private val classloader = this.getClass.getClassLoader + + private def assertLoads(classname: String): Unit = { + val resourceURL: URL = Option(findResource(classname)).getOrElse { + fail(s"Class $classname not found as ${resourceName(classname)}") + } + + logInfo(s"Class $classname at $resourceURL") + classloader.loadClass(classname) + } + + private def assertLoads(classes: String*): Unit = { + classes.foreach(assertLoads) + } + + private def findResource(classname: String): URL = { + val resource = resourceName(classname) + classloader.getResource(resource) + } + + private def resourceName(classname: String): String = { + classname.replace(".", "/") + ".class" + } + + private def assertClassNotFound(classname: String): Unit = { + Option(findResource(classname)).foreach { resourceURL => + fail(s"Class $classname found at $resourceURL") + } + + intercept[ClassNotFoundException] { + classloader.loadClass(classname) + } + } + + private def assertClassNotFound(classes: String*): Unit = { + classes.foreach(assertClassNotFound) + } + + private val KRYO = "com.esotericsoftware.kryo.Kryo" + + private val SPARK_HIVE = "org.apache.hive." + private val SPARK_SHADED = "org.spark-project.hive.shaded." 
+ + test("shaded Protobuf") { + assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") + } + + test("hive-common") { + assertLoads("org.apache.hadoop.hive.conf.HiveConf") + } + + test("hive-exec") { + assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException") + } + + private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" + + test("unshaded kryo") { + assertLoads(KRYO, STD_INSTANTIATOR) + } + + test("Forbidden Dependencies") { + assertClassNotFound( + SPARK_HIVE + KRYO, + SPARK_SHADED + KRYO, + "org.apache.hive." + KRYO, + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR + ) + } + + test("parquet-hadoop-bundle") { + assertLoads( + "parquet.hadoop.ParquetOutputFormat", + "parquet.hadoop.ParquetInputFormat" + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 30f5313d2b81..cf737836939f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,12 +22,12 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter { +class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ before { Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") @@ -122,7 +122,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { test(name) { val error = intercept[AnalysisException] { - quietly(sql(query)) + quietly(hiveContext.sql(query)) } assert(!error.getMessage.contains("Seq(")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index fb10f8583da9..2e5cae415e54 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,24 +19,25 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. 
-class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { +class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext.implicits._ + import hiveContext.sql + private var testData: DataFrame = _ override def beforeAll() { testData = Seq((1, 2), (2, 4)).toDF("a", "b") - TestHive.registerDataFrameAsTable(testData, "mytable") + hiveContext.registerDataFrameAsTable(testData, "mytable") } override def afterAll(): Unit = { - TestHive.dropTempTable("mytable") + hiveContext.dropTempTable("mytable") } test("rollup") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 52e782768cb7..f621367eb553 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameJoinSuite extends QueryTest { +class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { + import hiveContext.implicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index efb3f2545db8..2c98f1c3cc49 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton -class HiveDataFrameWindowSuite extends QueryTest { +class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { + import hiveContext.implicits._ + import hiveContext.sql test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -183,13 +184,13 @@ class HiveDataFrameWindowSuite extends QueryTest { } test("aggregation and range betweens with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) .equalTo("2") .as("last_v"), avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) @@ -203,7 +204,7 @@ class HiveDataFrameWindowSuite extends QueryTest { """SELECT | key, | last_value(value) OVER - | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", | avg(key) OVER | (PARTITION BY value ORDER BY key RANGE 
BETWEEN unbounded preceding and 1 following), | avg(key) OVER @@ -212,4 +213,47 @@ class HiveDataFrameWindowSuite extends QueryTest { | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) | FROM window_table""".stripMargin).collect()) } + + test("reverse sliding range frame") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window. + partitionBy($"category"). + orderBy($"revenue".desc). + rangeBetween(-2000L, 1000L) + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse unbounded range frame") { + val df = Seq(1, 2, 4, 3, 2, 1). + map(Tuple1.apply). + toDF("value") + val window = Window.orderBy($"value".desc) + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Long.MinValue, 1)), + sum($"value").over(window.rangeBetween(1, Long.MaxValue))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: + Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index aff0456b37ed..81a70b8d4226 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -28,7 +28,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Literal, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.types._ import org.apache.spark.sql.Row @@ -47,7 +48,11 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] val a = unwrap(state, soi).asInstanceOf[InternalRow] - val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] + + val dt = new StructType() + .add("counts", MapType(LongType, LongType)) + .add("percentiles", ArrayType(DoubleType)) + val b = wrap(a, soi, dt).asInstanceOf[UDAFPercentile.State] val sfCounts = soi.getStructFieldRef("counts") val sfPercentiles = soi.getStructFieldRef("percentiles") @@ -128,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { - row1.zip(row2.toSeq).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { + row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -142,6 +147,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => r1.zip(r2).foreach { case 
(b1, b2) => assert(b1 === b2) } + // We don't support equality & ordering for map type, so skip it. + case (r1: MapData, r2: MapData) => case (r1, r2) => assert(r1 === r2) } } @@ -157,44 +164,45 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val writableOIs = dataTypes.map(toWritableInspector) val nullRow = data.map(d => null) - checkValues(nullRow, nullRow.zip(writableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(nullRow, nullRow.zip(writableOIs).zip(dataTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) // struct couldn't be constant, sweep it out val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType]) + val constantTypes = constantExprs.map(_.dataType) val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) - checkValues(constantData, constantData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantData, constantData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantData.zip(constantNullWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantNullData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) } test("wrap / unwrap primitive writable object inspector") { val writableOIs = dataTypes.map(toWritableInspector) - checkValues(row, row.zip(writableOIs).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(writableOIs).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } test("wrap / unwrap primitive java object inspector") { val ois = dataTypes.map(toInspector) - checkValues(row, row.zip(ois).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(ois).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } @@ -202,33 +210,37 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val dt = StructType(dataTypes.zipWithIndex.map { case (t, idx) => StructField(s"c_$idx", t) }) - - checkValues(row, - unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[InternalRow]) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + val inspector = toInspector(dt) + checkValues( + row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], + dt) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) - val d = row(0) :: row(0) :: Nil - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + val d = new GenericArrayData(Array(row(0), row(0))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, 
toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { val dt = MapType(dataTypes(0), dataTypes(1)) - val d = Map(row(0) -> row(1)) - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + val d = ArrayBasedMapData(Array(row(0)), Array(row(1))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index e9bb32667936..107457f79ec0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,31 +17,144 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.hive.test.TestHive +import java.io.File -import org.apache.spark.sql.test.ExamplePointUDT -import org.apache.spark.sql.types.StructType +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -class HiveMetastoreCatalogSuite extends SparkFunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { + import hiveContext.implicits._ test("struct field should accept underscore in sub-column name") { - val metastr = "struct" - - val datatype = HiveMetastoreTypes.toDataType(metastr) - assert(datatype.isInstanceOf[StructType]) + val hiveTypeStr = "struct" + val dateType = HiveMetastoreTypes.toDataType(hiveTypeStr) + assert(dateType.isInstanceOf[StructType]) } test("udt to metastore type conversion") { val udt = new ExamplePointUDT - assert(HiveMetastoreTypes.toMetastoreType(udt) === - HiveMetastoreTypes.toMetastoreType(udt.sqlType)) + assertResult(HiveMetastoreTypes.toMetastoreType(udt.sqlType)) { + HiveMetastoreTypes.toMetastoreType(udt) + } } test("duplicated metastore relations") { - import TestHive.implicits._ - val df = TestHive.sql("SELECT * FROM src") - println(df.queryExecution) + val df = hiveContext.sql("SELECT * FROM src") + logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } + +class DataSourceWithHiveMetastoreCatalogSuite + extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import testImplicits._ + + private val testDF = range(1, 3).select( + ('id + 
0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) + + Seq( + "parquet" -> ( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ), + + "orc" -> ( + "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcSerde" + ) + ).foreach { case (provider, (inputFormat, outputFormat, serde)) => + test(s"Persist non-partitioned $provider relation into metastore as managed table") { + withTable("t") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(!hiveTable.isPartitioned) + assert(hiveTable.tableType === ManagedTable) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + } + } + + test(s"Persist non-partitioned $provider relation into metastore as external table") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalFile + + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .option("path", path.toString) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.location.get === path.toURI.toString.stripSuffix(File.separator)) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + } + } + } + + test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalPath + + sql( + s"""CREATE TABLE t USING $provider + |OPTIONS (path '$path') + |AS SELECT 1 AS d1, "val_1" AS d2 + """.stripMargin) + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.isPartitioned === false) + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.partitionColumns.length === 0) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + checkAnswer(table("t"), Row(1, "val_1")) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index af68615e8e9d..5596ec6882ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,75 +17,62 @@ package 
org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) -class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive +class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton { - import sqlContext._ + test("Case insensitive attribute names") { + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { + val expected = (1 to 4).map(i => Row(i.toString)) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) + } + } - def run(prefix: String): Unit = { - test(s"$prefix: Case insensitive attribute names") { - withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { - val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) - } + test("SELECT on Parquet table") { + val data = (1 to 4).map(i => (i, s"val_$i")) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) } + } - test(s"$prefix: SELECT on Parquet table") { - val data = (1 to 4).map(i => (i, s"val_$i")) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) - } + test("Simple column projection + filter on Parquet table") { + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { + checkAnswer( + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + Seq(Row(true, "val_2"), Row(true, "val_4"))) } + } - test(s"$prefix: Simple column projection + filter on Parquet table") { - withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { + test("Converting Hive to Parquet Table via saveAsParquetFile") { + withTempPath { dir => + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p") + withTempTable("p") { checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), - Seq(Row(true, "val_2"), Row(true, "val_4"))) + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) } } + } - test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { - withTempPath { dir => - sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") + test("INSERT OVERWRITE TABLE Parquet table") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + withTempPath { file => + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { - checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) + // let's do three overwrites for good measure + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) } } } - - test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { - withTempPath { file => - sql("SELECT 
* FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) - } - } - } - } - } - - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { - run("Parquet data source enabled") - } - - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "false") { - run("Parquet data source disabled") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index f765395e148a..79cf40aba4bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -175,4 +175,19 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + HiveQl.parseSql(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index ab443032be20..97df249bdb6d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -18,20 +18,31 @@ package org.apache.spark.sql.hive import java.io.File +import java.sql.Timestamp +import java.util.Date + +import scala.collection.mutable.ArrayBuffer -import org.apache.spark._ -import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} -import org.apache.spark.util.{ResetSystemProperties, Utils} import org.scalatest.Matchers import org.scalatest.concurrent.Timeouts +import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ +import org.apache.spark._ +import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.util.{ResetSystemProperties, Utils} + /** * This suite tests spark-submit with applications using HiveContext. */ class HiveSparkSubmitSuite extends SparkFunSuite with Matchers + // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we + // add a timestamp to provide more diagnosis information. 
with ResetSystemProperties with Timeouts { @@ -45,13 +56,15 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -62,7 +75,9 @@ class HiveSparkSubmitSuite val args = Seq( "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -74,7 +89,21 @@ class HiveSparkSubmitSuite // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" - val args = Seq("--class", "Main", testJar) + val args = Seq( + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--class", "Main", + testJar) + runSparkSubmit(args) + } + + test("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) runSparkSubmit(args) } @@ -82,15 +111,55 @@ class HiveSparkSubmitSuite // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val process = Utils.executeCommand( - Seq("./bin/spark-submit") ++ args, - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + val history = ArrayBuffer.empty[String] + val commands = Seq("./bin/spark-submit") ++ args + val commandLine = commands.mkString("'", "' '", "'") + + val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) + val env = builder.environment() + env.put("SPARK_TESTING", "1") + env.put("SPARK_HOME", sparkHome) + + def captureOutput(source: String)(line: String): Unit = { + // This test suite has some weird behaviors when executed on Jenkins: + // + // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a + // timestamp to provide more diagnosis information. + // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print + // them out for debugging purposes. 
+ val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" + // scalastyle:off println + println(logLine) + // scalastyle:on println + history += logLine + } + + val process = builder.start() + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() + try { - val exitCode = failAfter(120 seconds) { process.waitFor() } + val exitCode = failAfter(180.seconds) { process.waitFor() } if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + // include logs in output. Note that logging is async and may not have completed + // at the time this exception is raised + Thread.sleep(1000) + val historyLog = history.mkString("\n") + fail { + s"""spark-submit returned with exit code $exitCode. + |Command line: $commandLine + | + |$historyLog + """.stripMargin + } } + } catch { + case to: TestFailedDueToTimeoutException => + val historyLog = history.mkString("\n") + fail(s"Timeout of $commandLine" + + s" See the log4j logs for more detail." + + s"\n$historyLog", to) + case t: Throwable => throw t } finally { // Ensure we still kill the process in case it timed out process.destroy() @@ -104,23 +173,26 @@ object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + logInfo("Testing load classes at the driver side.") // First, we load classes at driver side. try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => throw new Exception("Could not load user class from jar:\n", t) } // Second, we load classes at the executor side. + logInfo("Testing load classes at the executor side.") val result = df.mapPartitions { x => var exception: String = null try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString @@ -133,6 +205,7 @@ object SparkSubmitClassLoaderTest extends Logging { } // Load a Hive UDF from the jar. + logInfo("Registering temporary Hive UDF provided in a jar.") hiveContext.sql( """ |CREATE TEMPORARY FUNCTION example_max @@ -142,18 +215,23 @@ object SparkSubmitClassLoaderTest extends Logging { hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") source.registerTempTable("sourceTable") // Load a Hive SerDe from the jar. + logInfo("Creating a Hive table with a SerDe provided in a jar.") hiveContext.sql( """ |CREATE TABLE t1(key int, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' """.stripMargin) // Actually use the loaded UDF and SerDe. 
+ logInfo("Writing data into the table.") hiveContext.sql( "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") val count = hiveContext.table("t1").orderBy("key", "val").count() if (count != 10) { throw new Exception(s"table t1 should have 10 rows instead of $count rows") } + logInfo("Test finishes.") + sc.stop() } } @@ -168,7 +246,7 @@ object SparkSQLConfTest extends Logging { // before spark.sql.hive.metastore.jars get set, we will see the following exception: // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only // be used when hive execution version == hive metastore version. - // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. val conf = new SparkConf() { override def getAll: Array[(String, String)] = { @@ -187,9 +265,58 @@ object SparkSQLConfTest extends Logging { // For this simple test, we do not really clone this object. override def clone: SparkConf = this } + conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) // Run a simple command to make sure all lazy vals in hiveContext get instantiated. hiveContext.tables().collect() + sc.stop() + } +} + +object SPARK_9757 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven") + .set("spark.ui.enabled", "false")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + import hiveContext.implicits._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index aa5dbe2db690..80a61f82fd24 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,32 +19,30 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -/* Implicits */ -import org.apache.spark.sql.hive.test.TestHive._ - case class TestData(key: Int, value: String) 
case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.hive.test.TestHive.implicits._ - +class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ + import hiveContext.sql - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. - TestHive.reset() + hiveContext.reset() // Register the testData, which will be used in every test. testData.registerTempTable("testData") } @@ -86,8 +84,6 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { val message = intercept[QueryExecutionException] { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") }.getMessage - - println("message!!!!" + message) } test("Double create does not fail when allowExisting = true") { @@ -97,9 +93,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -115,6 +111,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() + val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + sql( s""" |CREATE TABLE table_with_partition(c1 string) @@ -147,7 +145,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() - val folders = dir.filter(_.isDirectory).toList + val folders = dir.filter { e => e.isDirectory && !e.getName().startsWith(stagingDir) }.toList if (folders.isEmpty) { List(acc.reverse) } else { @@ -160,7 +158,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } @@ -168,8 +166,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val rowRDD = hiveContext.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = hiveContext.createDataFrame(rowRDD, schema) 
df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -184,9 +182,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -201,9 +199,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -216,11 +214,11 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { } test("SPARK-5498:partition schema does not match table schema") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") - val testDatawithNull = TestHive.sparkContext.parallelize( + val testDatawithNull = hiveContext.sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 1c15997ea8e6..579631df772b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -19,22 +19,19 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row -class ListTablesSuite extends QueryTest with BeforeAndAfterAll { +class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ - import org.apache.spark.sql.hive.test.TestHive.implicits._ - - val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) - catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") @@ -42,7 +39,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { override def afterAll(): Unit = { catalog.unregisterTable(Seq("ListTablesSuiteTable")) - catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") @@ -55,7 +51,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hivelisttablessuitetable'"), Row("hivelisttablessuitetable", false)) @@ -69,9 +64,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - checkAnswer( - allTables.filter("tableName = 'indblisttablessuitetable'"), - Row("indblisttablessuitetable", true)) assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hiveindblisttablessuitetable'"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index cc294bc3e8bc..bf0db084906c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -17,22 +17,17 @@ package org.apache.spark.sql.hive -import java.io.File +import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterAll - import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -40,8 +35,9 @@ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. 
*/ -class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override val sqlContext = TestHive +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ var jsonFilePath: String = _ @@ -415,7 +411,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA |) """.stripMargin) - sql("DROP TABLE jsonTable").collect().foreach(println) + sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) } } @@ -462,23 +458,20 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA checkAnswer(sql("SELECT * FROM savedJsonTable"), df) - // Right now, we cannot append to an existing JSON table. - intercept[RuntimeException] { - df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") - } - // We can overwrite it. df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") checkAnswer(sql("SELECT * FROM savedJsonTable"), df) // When the save mode is Ignore, we will do nothing when the table already exists. df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") - assert(df.schema === table("savedJsonTable").schema) + // TODO in ResolvedDataSource, will convert the schema into nullable = true + // hence the df.schema is not exactly the same as table("savedJsonTable").schema + // assert(df.schema === table("savedJsonTable").schema) checkAnswer(sql("SELECT * FROM savedJsonTable"), df) // Drop table will also delete the data. sql("DROP TABLE savedJsonTable") - intercept[InvalidInputException] { + intercept[IOException] { read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) } } @@ -554,7 +547,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA "org.apache.spark.sql.json", schema, Map.empty[String, String]) - }.getMessage.contains("'path' must be specified for json data."), + }.getMessage.contains("key not found: path"), "We should complain that path is not specified.") } } @@ -562,10 +555,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } test("scan a parquet table created through a CTAS statement") { - withSQLConf( - HiveContext.CONVERT_METASTORE_PARQUET.key -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { - + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "true") { withTempTable("jt") { (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") @@ -580,9 +570,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK + case LogicalRelation(p: ParquetRelation) => // OK case _ => - fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala new file mode 100644 index 000000000000..f16c257ab5ab --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} + +class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + private lazy val df = sqlContext.range(10).coalesce(1) + + private def checkTablePath(dbName: String, tableName: String): Unit = { + val metastoreTable = hiveContext.catalog.client.getTable(dbName, tableName) + val expectedPath = hiveContext.catalog.client.getDatabase(dbName).location + "/" + tableName + + assert(metastoreTable.serdeProperties("path") === expectedPath) + } + + test(s"saveAsTable() to non-default database - with USE - Overwrite") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + test(s"saveAsTable() to non-default database - without USE - Overwrite") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + test(s"createExternalTable() to non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + + sqlContext.createExternalTable("t", path, "parquet") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table("t"), df) + + sql( + s""" + |CREATE TABLE t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table("t1"), df) + } + } + } + } + + test(s"createExternalTable() to non-default database - without USE") { + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + sqlContext.createExternalTable(s"$db.t", path, "parquet") + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + sql( + s""" + |CREATE TABLE $db.t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table(s"$db.t1"), df) + } + } + } + + test(s"saveAsTable() to non-default database - with USE - Append") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + df.write.mode(SaveMode.Append).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df.unionAll(df)) + } + + 
assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") + } + } + + test(s"saveAsTable() to non-default database - without USE - Append") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") + } + } + + test(s"insertInto() non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + } + + test(s"insertInto() non-default database - without USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + } + + assert(sqlContext.tableNames(db).contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test("Looks up tables in non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql("CREATE TABLE t (key INT)") + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + } + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + } + } + + test("Drops a table in a non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql(s"CREATE TABLE t (key INT)") + assert(sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(sqlContext.tableNames(db).contains("t")) + + activateDatabase(db) { + sql(s"DROP TABLE t") + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames(db).contains("t")) + } + } + + test("Refreshes a table in a non-default database - with USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + activateDatabase(db) { + sql( + s"""CREATE EXTERNAL TABLE t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql("ALTER TABLE t ADD PARTITION (p=1)") + sql("REFRESH TABLE t") + checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql("ALTER TABLE t ADD PARTITION (p=2)") + hiveContext.refreshTable("t") + checkAnswer( + sqlContext.table("t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + } + + test("Refreshes a table in a non-default database - without USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + sql( + s"""CREATE EXTERNAL TABLE $db.t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") + sql(s"REFRESH TABLE $db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + + 
df.write.parquet(s"$path/p=2") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") + hiveContext.refreshTable(s"$db.t") + checkAnswer( + sqlContext.table(s"$db.t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + + test("invalid database name and table names") { + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`t:a`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`table`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`t:a` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`table` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala new file mode 100644 index 000000000000..49aab85cf1aa --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp + +import org.apache.hadoop.hive.conf.HiveConf + +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { + /** + * Set the staging directory (and hence path to ignore Parquet files under) + * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. 
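+ * Files whose names start with this prefix are Hive staging output, and are filtered
+ * out below when reading back the schema of the written Parquet files.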
+ */ + private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + + override protected def logParquetSchema(path: String): Unit = { + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) + + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |$schema + """.stripMargin) + } + + private def testParquetHiveCompatibility(row: Row, hiveTypes: String*): Unit = { + withTable("parquet_compat") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // Hive columns are always nullable, so here we append a all-null row. + val rows = row :: Row(Seq.fill(row.length)(null): _*) :: Nil + + // Don't convert Hive metastore Parquet tables to let Hive write those Parquet files. + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + val fields = hiveTypes.zipWithIndex.map { case (typ, index) => s" col_$index $typ" } + + val ddl = + s"""CREATE TABLE parquet_compat( + |${fields.mkString(",\n")} + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin + + logInfo( + s"""Creating testing Parquet table with the following DDL: + |$ddl + """.stripMargin) + + sqlContext.sql(ddl) + + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } + + logParquetSchema(path) + + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(path), rows) + } + } + } + } + + test("simple primitives") { + testParquetHiveCompatibility( + Row(true, 1.toByte, 2.toShort, 3, 4.toLong, 5.1f, 6.1d, "foo"), + "BOOLEAN", "TINYINT", "SMALLINT", "INT", "BIGINT", "FLOAT", "DOUBLE", "STRING") + } + + test("SPARK-10177 timestamp") { + testParquetHiveCompatibility(Row(Timestamp.valueOf("2015-08-24 00:31:00")), "TIMESTAMP") + } + + test("array") { + testParquetHiveCompatibility( + Row( + Seq[Integer](1: Integer, null, 2: Integer, null), + Seq[String]("foo", null, "bar", null), + Seq[Seq[Integer]]( + Seq[Integer](1: Integer, null), + Seq[Integer](2: Integer, null))), + "ARRAY", + "ARRAY", + "ARRAY>") + } + + test("map") { + testParquetHiveCompatibility( + Row( + Map[Integer, String]( + (1: Integer) -> "foo", + (2: Integer) -> null)), + "MAP") + } + + // HIVE-11625: Parquet map entries with null keys are dropped by Hive + ignore("map entries with null keys") { + testParquetHiveCompatibility( + Row( + Map[Integer, String]( + null.asInstanceOf[Integer] -> "bar", + null.asInstanceOf[Integer] -> null)), + "MAP") + } + + test("struct") { + testParquetHiveCompatibility( + Row(Row(1, Seq("foo", "bar", null))), + "STRUCT>") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 017bc2adc103..f542a5a02508 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,49 +19,49 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.util.Utils +import 
org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ -class QueryPartitionSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ - import ctx.sql - - test("SPARK-5068: query data when path doesn't exist"){ - val testData = ctx.sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") + test("SPARK-5068: query data when path doesn't exist") { + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - sql("DROP TABLE table_with_partition") - sql("DROP TABLE createAndInsertTest") + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f067ea0d4fc7..6a692d6fce56 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,24 +17,15 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterAll - import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.test.TestHiveSingleton -class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - - private lazy val ctx: HiveContext = { - val ctx = org.apache.spark.sql.hive.test.TestHive - ctx.reset() - ctx.cacheTables = false - ctx - } - - import ctx.sql +class StatisticsSuite extends QueryTest with TestHiveSingleton { + import hiveContext.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -77,7 +68,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + hiveContext.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -111,7 +102,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === hiveContext.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -122,9 +113,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { - ctx.analyze("tempTable") + hiveContext.analyze("tempTable") } - ctx.catalog.unregisterTable(Seq("tempTable")) + hiveContext.catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { @@ -152,8 +143,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -164,15 +155,15 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedAnswer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + val shj = 
df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") @@ -208,8 +199,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -222,8 +213,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, answer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 4056dee77757..3ab457681119 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -18,19 +18,18 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive +class UDFSuite extends QueryTest with TestHiveSingleton { test("UDF case insensitive") { - ctx.udf.register("random0", () => { Math.random() }) - ctx.udf.register("RANDOM1", () => { Math.random() }) - ctx.udf.register("strlenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + hiveContext.udf.register("random0", () => { Math.random() }) + hiveContext.udf.register("RANDOM1", () => { Math.random() }) + hiveContext.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(hiveContext.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala new file mode 100644 index 000000000000..5e7b93d45710 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.util.Collections + +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A set of tests for the filter conversion logic used when pushing partition pruning into the + * metastore + */ +class FiltersSuite extends SparkFunSuite with Logging { + private val shim = new Shim_v0_13 + + private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") + private val varCharCol = new FieldSchema() + varCharCol.setName("varchar") + varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) + testTable.setPartCols(Collections.singletonList(varCharCol)) + + filterTest("string filter", + (a("stringcol", StringType) > Literal("test")) :: Nil, + "stringcol > \"test\"") + + filterTest("string filter backwards", + (Literal("test") > a("stringcol", StringType)) :: Nil, + "\"test\" > stringcol") + + filterTest("int filter", + (a("intcol", IntegerType) === Literal(1)) :: Nil, + "intcol = 1") + + filterTest("int filter backwards", + (Literal(1) === a("intcol", IntegerType)) :: Nil, + "1 = intcol") + + filterTest("int and string filter", + (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, + "1 = intcol and \"a\" = strcol") + + filterTest("skip varchar", + (Literal("") === a("varchar", StringType)) :: Nil, + "") + + private def filterTest(name: String, filters: Seq[Expression], result: String) = { + test(name){ + val converted = shim.convertFilters(testTable, filters) + if (converted != result) { + fail( + s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") + } + } + } + + private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9a571650b6e2..f0bb77092c0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,8 +17,13 @@ package org.apache.spark.sql.hive.client +import java.io.File + +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils /** @@ -28,6 +33,12 @@ import org.apache.spark.util.Utils * is not fully tested. */ class VersionsSuite extends SparkFunSuite with Logging { + + // Do not use a temp path here to speed up subsequent executions of the unit test during + // development. 
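+ // (The cache lives under java.io.tmpdir, e.g. /tmp/hive-ivy-cache on a typical Linux setup,
+ // so the Hive client jars resolved for each tested version are reused across runs instead of
+ // being downloaded into a fresh temporary directory every time.)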
+ private val ivyPath = Some( + new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() @@ -38,7 +49,9 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client + val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, + buildConf(), + ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -67,19 +80,22 @@ class VersionsSuite extends SparkFunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client } + val badClient = quietly { + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("12", "13", "14") + private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") private var client: ClientInterface = null versions.foreach { version => test(s"$version: create client") { client = null - client = IsolatedClientLoader.forVersion(version, buildConf()).client + System.gc() // Hack to avoid SEGV on some JVM versions. + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } test(s"$version: createDatabase") { @@ -141,6 +157,12 @@ class VersionsSuite extends SparkFunSuite with Logging { client.getAllPartitions(client.getTable("default", "src_part")) } + test(s"$version: getPartitionsByFilter") { + client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo( + AttributeReference("key", IntegerType, false)(NamedExpression.newExprId), + Literal(1)))) + } + test(s"$version: loadPartition") { client.loadPartition( emptyDir, @@ -170,5 +192,12 @@ class VersionsSuite extends SparkFunSuite with Logging { false, false) } + + test(s"$version: create index and reset") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + client.reset() + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala new file mode 100644 index 000000000000..a73b1bd52c09 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -0,0 +1,599 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + var originalUseAggregate2: Boolean = _ + + override def beforeAll(): Unit = { + originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") + val data1 = Seq[(Integer, Integer)]( + (1, 10), + (null, -60), + (1, 20), + (1, 30), + (2, 0), + (null, -10), + (2, -1), + (2, null), + (2, null), + (null, 100), + (3, null), + (null, null), + (3, null)).toDF("key", "value") + data1.write.saveAsTable("agg1") + + val data2 = Seq[(Integer, Integer, Integer)]( + (1, 10, -10), + (null, -60, 60), + (1, 30, -30), + (1, 30, 30), + (2, 1, 1), + (null, -10, 10), + (2, -1, null), + (2, 1, 1), + (2, null, 1), + (null, 100, -10), + (3, null, 3), + (null, null, null), + (3, null, null)).toDF("key", "value1", "value2") + data2.write.saveAsTable("agg2") + + val emptyDF = sqlContext.createDataFrame( + sparkContext.emptyRDD[Row], + StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) + emptyDF.registerTempTable("emptyTable") + + // Register UDAFs + sqlContext.udf.register("mydoublesum", new MyDoubleSum) + sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) + } + + override def afterAll(): Unit = { + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.dropTempTable("emptyTable") + sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) + } + + test("empty table") { + // If there is no GROUP BY clause and the table is empty, we will generate a single row. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key), + | COUNT(DISTINCT value) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil) + + // If there is a GROUP BY clause and the table is empty, there is no output. 
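+ // (Every group originates from at least one input row, so grouping an empty input
+ // produces zero groups and therefore an empty result.)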
+ checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(value), + | FIRST(value), + | LAST(value), + | MAX(value), + | MIN(value), + | SUM(value), + | COUNT(DISTINCT value) + |FROM emptyTable + |GROUP BY key + """.stripMargin), + Nil) + } + + test("null literal") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(null), + | COUNT(null), + | FIRST(null), + | LAST(null), + | MAX(null), + | MIN(null), + | SUM(null) + """.stripMargin), + Row(null, 0, null, null, null, null, null) :: Nil) + } + + test("only do grouping") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT value1, key + |FROM agg2 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + } + + test("case in-sensitive resolution") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), kEY - 100 + |FROM agg1 + |GROUP BY Key - 100 + """.stripMargin), + Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT sum(distinct value1), kEY - 100, count(distinct value1) + |FROM agg2 + |GROUP BY Key - 100 + """.stripMargin), + Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT valUe * key - 100 + |FROM agg1 + |GROUP BY vAlue * keY - 100 + """.stripMargin), + Row(-90) :: + Row(-80) :: + Row(-70) :: + Row(-100) :: + Row(-102) :: + Row(null) :: Nil) + } + + test("test average no key in output") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil) + } + + test("test average") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + 1.5, key + 10 + |FROM agg1 + |GROUP BY key + 10 + """.stripMargin), + Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) FROM agg1 + """.stripMargin), + Row(11.125) :: Nil) + } + + test("udaf") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | mydoubleavg(value), + | avg(value - key), + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) :: + Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) :: + Row(3, null, null, null, null, null) :: + Row(null, null, 110.0, null, null, 10.0) :: Nil) + } + + test("non-AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), 
key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value) FROM agg1 + """.stripMargin), + Row(89.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(null) + """.stripMargin), + Row(null) :: Nil) + } + + test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1, 20.0) :: + Row(-1.0, 2, -0.5) :: + Row(null, 3, null) :: + Row(30.0, null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoublesum(value + 1.5 * key), + | avg(value - key), + | key, + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(64.5, 19.0, 1, 55.5, 20.0) :: + Row(5.0, -2.5, 2, -7.0, -0.5) :: + Row(null, null, 3, null, null) :: + Row(null, null, null, null, 10.0) :: Nil) + } + + test("single distinct column set") { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) + } + + test("test count") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1) :: + Row(1, -60, 1, 1, null) :: + Row(2, 30, 2, 2, 1) :: + Row(2, 1, 2, 2, 2) :: + Row(1, -10, 1, 1, null) :: + Row(0, -1, 1, 1, 2) :: + Row(1, null, 1, 1, 2) :: + Row(1, 100, 1, 1, null) :: + Row(1, null, 2, 2, 3) :: + Row(0, null, 1, 1, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key, + | count(DISTINCT abs(value2)) + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1, 1) :: + Row(1, -60, 1, 1, null, 1) :: + Row(2, 30, 2, 2, 1, 1) :: + Row(2, 1, 2, 2, 2, 1) :: + Row(1, -10, 1, 1, null, 1) :: + Row(0, -1, 1, 1, 2, 0) :: + Row(1, null, 1, 1, 
2, 1) :: + Row(1, 100, 1, 1, null, 1) :: + Row(1, null, 2, 2, 3, 1) :: + Row(0, null, 1, 1, null, 0) :: Nil) + } + + test("test Last implemented based on AggregateExpression1") { + // TODO: Remove this test once we remove AggregateExpression1. + import org.apache.spark.sql.functions._ + val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + + checkAnswer( + df.groupBy("i").agg(last("j")), + df + ) + } + } + + test("error handling") { + withSQLConf("spark.sql.useAggregate2" -> "false") { + val errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | mydoublesum(value), + | mydoubleavg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + } + } +} + +class SortBasedAggregationQuerySuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } +} + +class TungstenAggregationQuerySuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + } +} + +class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + } + + override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + (0 to 2).foreach { fallbackStartsAt => + sqlContext.setConf( + "spark.sql.TungstenAggregate.testFallbackStartsAt", + fallbackStartsAt.toString) + + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } + } + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. 
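+ // (Delegating through the Seq[Row] variant ensures the fallback-aware checkAnswer defined
+ // above is exercised no matter which overload a test calls.)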
+ override protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index b0d3dd44daed..e38d1eb5779f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -25,8 +25,10 @@ class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") val ts = - new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) + new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", conf)) ts.executeSql("SHOW TABLES").toRdd.collect() ts.executeSql("SELECT * FROM src").toRdd.collect() ts.executeSql("SHOW TABLES").toRdd.collect() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index c9dd4c0935a7..aa95ba94fa87 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,14 +19,16 @@ package org.apache.spark.sql.hive.execution import java.io._ +import scala.util.control.NonFatal + import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.test.TestHive /** @@ -40,7 +42,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { /** * When set, any cache files that result in test failures will be deleted. Used when the test @@ -124,7 +126,7 @@ abstract class HiveComparisonTest protected val cacheDigest = java.security.MessageDigest.getInstance("MD5") protected def getMd5(str: String): String = { val digest = java.security.MessageDigest.getInstance("MD5") - digest.update(str.getBytes("utf-8")) + digest.update(str.replaceAll(System.lineSeparator(), "\n").getBytes("utf-8")) new java.math.BigInteger(1, digest.digest).toString(16) } @@ -370,7 +372,11 @@ abstract class HiveComparisonTest // Check that the results match unless its an EXPLAIN query. 
val preparedHive = prepareAnswer(hiveQuery, hive) - if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { + // We will ignore the ExplainCommand, ShowFunctions, DescribeFunction + if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && + (!hiveQuery.logical.isInstanceOf[ShowFunctions]) && + (!hiveQuery.logical.isInstanceOf[DescribeFunction]) && + preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst @@ -382,11 +388,45 @@ abstract class HiveComparisonTest hiveCacheFiles.foreach(_.delete()) } + // If this query is reading other tables that were created during this test run + // also print out the query plans and results for those. + val computedTablesMessages: String = try { + val tablesRead = new TestHive.QueryExecution(query).executedPlan.collect { + case ts: HiveTableScan => ts.relation.tableName + }.toSet + + TestHive.reset() + val executions = queryList.map(new TestHive.QueryExecution(_)) + executions.foreach(_.toRdd) + val tablesGenerated = queryList.zip(executions).flatMap { + case (q, e) => e.executedPlan.collect { + case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => + (q, e, i) + } + } + + tablesGenerated.map { case (hiveql, execution, insert) => + s""" + |=== Generated Table === + |$hiveql + |$execution + |== Results == + |${insert.child.execute().collect().mkString("\n")} + """.stripMargin + }.mkString("\n") + + } catch { + case NonFatal(e) => + logError("Failed to compute generated tables", e) + s"Couldn't compute dependent tables: $e" + } + val errorMessage = s""" |Results do not match for $testCaseName: |$hiveQuery\n${hiveQuery.analyzed.output.map(_.name).mkString("\t")} |$resultComparison + |$computedTablesMessages """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 697211222b90..94162da4eae1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton /** * A set of tests that validates support for Hive Explain command. 
*/ -class HiveExplainSuite extends QueryTest { +class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") @@ -36,7 +38,7 @@ class HiveExplainSuite extends QueryTest { "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "Code Generation", "== RDD ==") + "Code Generation") } test("explain create table command") { @@ -74,4 +76,30 @@ class HiveExplainSuite extends QueryTest { "Limit", "src") } + + test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + withTempTable("jt") { + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + hiveContext.read.json(rdd).registerTempTable("jt") + val outputs = sql( + s""" + |EXPLAIN EXTENDED + |CREATE TABLE t1 + |AS + |SELECT * FROM jt + """.stripMargin).collect().map(_.mkString).mkString + + val shouldContain = + "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: + "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: + "CreateTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil + for (key <- shouldContain) { + assert(outputs.contains(key), s"$key doesn't exist in result") + } + + val physicalIndex = outputs.indexOf("== Physical Plan ==") + assert(!outputs.substring(physicalIndex).contains("Subquery"), + "Physical Plan should not contain Subquery since it's eliminated by optimizer") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index efbef68cd444..0d4c7f86b315 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest { +class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ + test("SPARK-5324 query result of describe command") { - loadTestTable("src") + hiveContext.loadTestTable("src") // register a describe command to be a temp table sql("desc src").registerTempTable("mydesc") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index bdb53ddf59c1..cd055f9eca37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.hive.test.TestHiveSingleton -class HivePlanTest extends QueryTest { - import TestHive._ - import TestHive.implicits._ +class HivePlanTest extends QueryTest with TestHiveSingleton { + import hiveContext.sql + 
import hiveContext.implicits._ test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") @@ -31,4 +34,19 @@ class HivePlanTest extends QueryTest { comparePlans(optimized, correctAnswer) } + + test("window expressions sharing the same partition by and order by clause") { + val df = Seq.empty[(Int, String, Int, Int)].toDF("id", "grp", "seq", "val") + val window = Window. + partitionBy($"grp"). + orderBy($"val") + val query = df.select( + $"id", + sum($"val").over(window.rowsBetween(-1, 1)), + sum($"val").over(window.rangeBetween(-1, 1)) + ) + val plan = query.queryExecution.analyzed + assert(plan.collect{ case w: logical.Window => w }.size === 1, + "Should have only 1 Window operator.") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 51dabc67fa7c..fe63ad568319 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.test.TestHiveContext import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -52,14 +53,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) } override def afterAll() { @@ -69,15 +62,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TEMPORARY FUNCTION udtf_count2") } - createQueryTest("Test UDTF.close in Lateral Views", - """ - |SELECT key, cc - |FROM src LATERAL VIEW udtf_count2(value) dd AS cc - """.stripMargin, false) // false mean we have to keep the temp function in registry - - createQueryTest("Test UDTF.close in SELECT", - "SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false) - test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") @@ -85,6 +69,60 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #2", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM src group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #3", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM 
src group by key%5 WITH CUBE + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #2", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for GroupingSet", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + createQueryTest("insert table with generator with column name", """ | CREATE TABLE gen_tmp (key Int); @@ -122,8 +160,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("! operator", """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table + | SELECT 1 AS a UNION ALL SELECT 2 AS a) t |WHERE !(a>1) """.stripMargin) @@ -132,7 +169,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { lower("AA"), "10", repeat(lower("AA"), 3), "11", lower(repeat("AA", 3)), "12", - printf("Bb%d", 12), "13", + printf("bb%d", 12), "13", repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""") createQueryTest("NaN to Decimal", @@ -175,71 +212,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |FROM src LIMIT 1; """.stripMargin) - createQueryTest("count distinct 0 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 0) table - """.stripMargin) - - createQueryTest("count distinct 1 value strings", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL - | SELECT 'b' AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values including null", - """ - |SELECT COUNT(DISTINCT a, 1) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 2L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -324,20 +296,6 @@ class HiveQuerySuite extends HiveComparisonTest 
with BeforeAndAfter { | FROM src LIMIT 1 """.stripMargin) - createQueryTest("Date comparison test 2", - "SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1") - - createQueryTest("Date cast", - """ - | SELECT - | CAST(CAST(0 AS timestamp) AS date), - | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), - | CAST(0 AS timestamp), - | CAST(CAST(0 AS timestamp) AS string), - | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) - | FROM src LIMIT 1 - """.stripMargin) - createQueryTest("Simple Average", "SELECT AVG(key) FROM src") @@ -470,7 +428,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |USING 'cat' AS (tKey, tValue) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) test("transform with SerDe2") { @@ -489,7 +447,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('avro.schema.literal'='{"namespace": "testing.hive.avro.serde","name": |"src","type": "record","fields": [{"name":"key","type":"int"}]}') |FROM small_src - """.stripMargin.replaceAll("\n", " ")).collect().head + """.stripMargin.replaceAll(System.lineSeparator(), " ")).collect().head assert(expected(0) === res(0)) } @@ -501,7 +459,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' AS (tKey, tValue) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES ('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("transform with SerDe4", """ @@ -510,7 +468,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES |('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("LIKE", "SELECT * FROM src WHERE value LIKE '%1%'") @@ -630,11 +588,62 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |select * where key = 4 """.stripMargin) + // test get_json_object again Hive, because the HiveCompatabilitySuite cannot handle result + // with newline in it. 
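+ // For a hypothetical document such as {"store": {"book": [{"title": "t1"}]}}:
+ //   get_json_object(json, '$')               returns the whole document,
+ //   get_json_object(json, '$.store.book[0]') returns {"title":"t1"},
+ //   get_json_object(json, '$.missing')       returns NULL.
+ // The cases below cover nested fields, array indexing and wildcards, and missing keys.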
+ createQueryTest("get_json_object #1", + "SELECT get_json_object(src_json.json, '$') FROM src_json") + + createQueryTest("get_json_object #2", + "SELECT get_json_object(src_json.json, '$.owner'), get_json_object(src_json.json, '$.store')" + + " FROM src_json") + + createQueryTest("get_json_object #3", + "SELECT get_json_object(src_json.json, '$.store.bicycle'), " + + "get_json_object(src_json.json, '$.store.book') FROM src_json") + + createQueryTest("get_json_object #4", + "SELECT get_json_object(src_json.json, '$.store.book[0]'), " + + "get_json_object(src_json.json, '$.store.book[*]') FROM src_json") + + createQueryTest("get_json_object #5", + "SELECT get_json_object(src_json.json, '$.store.book[0].category'), " + + "get_json_object(src_json.json, '$.store.book[*].category'), " + + "get_json_object(src_json.json, '$.store.book[*].isbn'), " + + "get_json_object(src_json.json, '$.store.book[*].reader') FROM src_json") + + createQueryTest("get_json_object #6", + "SELECT get_json_object(src_json.json, '$.store.book[*].reader[0].age'), " + + "get_json_object(src_json.json, '$.store.book[*].reader[*].age') FROM src_json") + + createQueryTest("get_json_object #7", + "SELECT get_json_object(src_json.json, '$.store.basket[0][1]'), " + + "get_json_object(src_json.json, '$.store.basket[*]'), " + + // Hive returns wrong result with [*][0], so this expression is change to make test pass + "get_json_object(src_json.json, '$.store.basket[0][0]'), " + + "get_json_object(src_json.json, '$.store.basket[0][*]'), " + + "get_json_object(src_json.json, '$.store.basket[*][*]'), " + + "get_json_object(src_json.json, '$.store.basket[0][2].b'), " + + "get_json_object(src_json.json, '$.store.basket[0][*].b') FROM src_json") + + createQueryTest("get_json_object #8", + "SELECT get_json_object(src_json.json, '$.non_exist_key'), " + + "get_json_object(src_json.json, '$..no_recursive'), " + + "get_json_object(src_json.json, '$.store.book[10]'), " + + "get_json_object(src_json.json, '$.store.book[0].non_exist_key'), " + + "get_json_object(src_json.json, '$.store.basket[*].non_exist_key'), " + + "get_json_object(src_json.json, '$.store.basket[0][*].non_exist_key') FROM src_json") + + createQueryTest("get_json_object #9", + "SELECT get_json_object(src_json.json, '$.zip code') FROM src_json") + + createQueryTest("get_json_object #10", + "SELECT get_json_object(src_json.json, '$.fb:testid') FROM src_json") + test("predicates contains an empty AttributeSet() references") { sql( """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 ) table + | SELECT 1 AS a FROM src LIMIT 1 ) t |WHERE abs(20141202) is not null """.stripMargin).collect() } @@ -947,7 +956,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .zip(parts) .map { case (k, v) => if (v == "NULL") { - s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultStrVal}" } else { s"$k=$v" } @@ -1096,18 +1105,19 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... 
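+ // (TestHiveContext now pre-populates a set of override configs, so a bare SET returns those
+ // defaults rather than an empty result; the expectations below account for them.)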
- assert(sql("SET").collect().size == 0) + assert(sql("SET").collect().size === TestHiveContext.overrideConfs.size) + val defaults = collectResults(sql("SET")) assertResult(Set(testKey -> testVal)) { collectResults(sql(s"SET $testKey=$testVal")) } - assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) + assert(hiveconf.get(testKey, "") === testVal) + assertResult(defaults ++ Set(testKey -> testVal))(collectResults(sql("SET"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + assertResult(defaults ++ Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index f0f04f8c73fb..197e9bfb02c4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } - - test("COALESCE with different types") { - intercept[RuntimeException] { - TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala similarity index 54% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ce5985888f54..d9ba895e1ece 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.{DataInput, DataOutput} -import java.util -import java.util.Properties +import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} @@ -28,13 +27,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.util.Utils -import scala.collection.JavaConversions._ - case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. @@ -46,10 +43,10 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. 
*/ -class HiveUdfSuite extends QueryTest { +class HiveUDFSuite extends QueryTest with TestHiveSingleton { - import TestHive.{udf, sql} - import TestHive.implicits._ + import hiveContext.{udf, sql} + import hiveContext.implicits._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -73,7 +70,7 @@ class HiveUdfSuite extends QueryTest { test("hive struct udf") { sql( """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( + |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) @@ -82,15 +79,56 @@ class HiveUdfSuite extends QueryTest { """. stripMargin.format(classOf[PairSerDe].getName)) - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") + } + + test("Max/Min on named_struct") { + def testOrderInStruct(): Unit = { + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + + // nested struct cases + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + } + val codegenDefault = hiveContext.getConf(SQLConf.CODEGEN_ENABLED) + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, true) + testOrderInStruct() + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, false) + testOrderInStruct() + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) } test("SPARK-6409 UDAFAverage test") { @@ -99,7 +137,7 @@ class HiveUdfSuite extends QueryTest { sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), Seq(Row(1.0, 260.182))) sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - TestHive.reset() + hiveContext.reset() } test("SPARK-2693 udaf aggregates test") { @@ -119,7 +157,7 @@ class HiveUdfSuite extends QueryTest { } test("UDFIntegerToString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") @@ -130,11 +168,73 @@ class HiveUdfSuite extends QueryTest { Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - TestHive.reset() + hiveContext.reset() + } + + test("UDFToListString") { + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") + val 
errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToListString(s) FROM inputTable") + } + assert(errMsg.getMessage contains "List type in java is unsupported because " + + "JVM type erasure makes spark fail to catch a component type in List<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") + hiveContext.reset() + } + + test("UDFToListInt") { + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + val errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToListInt(s) FROM inputTable") + } + assert(errMsg.getMessage contains "List type in java is unsupported because " + + "JVM type erasure makes spark fail to catch a component type in List<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") + hiveContext.reset() + } + + test("UDFToStringIntMap") { + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + + s"AS '${classOf[UDFToStringIntMap].getName}'") + val errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToStringIntMap(s) FROM inputTable") + } + assert(errMsg.getMessage contains "Map type in java is unsupported because " + + "JVM type erasure makes spark fail to catch key and value types in Map<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") + hiveContext.reset() + } + + test("UDFToIntIntMap") { + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + + s"AS '${classOf[UDFToIntIntMap].getName}'") + val errMsg = intercept[AnalysisException] { + sql("SELECT testUDFToIntIntMap(s) FROM inputTable") + } + assert(errMsg.getMessage contains "Map type in java is unsupported because " + + "JVM type erasure makes spark fail to catch key and value types in Map<>;") + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") + hiveContext.reset() } test("UDFListListInt") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() @@ -146,11 +246,11 @@ class HiveUdfSuite extends QueryTest { Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - TestHive.reset() + hiveContext.reset() } test("UDFListString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") @@ -161,25 +261,30 @@ class HiveUdfSuite extends QueryTest { Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - TestHive.reset() + hiveContext.reset() } test("UDFStringString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS 
'${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") - TestHive.reset() + checkAnswer( + sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) FROM stringTable"), + Seq(Row(" hello world"), Row(" hello goodbye"))) + + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") + + hiveContext.reset() } test("UDFTwoListList") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: @@ -192,7 +297,7 @@ class HiveUdfSuite extends QueryTest { Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - TestHive.reset() + hiveContext.reset() } } @@ -218,11 +323,11 @@ class PairSerDe extends AbstractSerDe { override def getObjectInspector: ObjectInspector = { ObjectInspectorFactory .getStandardStructObjectInspector( - Seq("pair"), - Seq(ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + Arrays.asList("pair"), + Arrays.asList(ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) )) } @@ -235,25 +340,24 @@ class PairSerDe extends AbstractSerDe { override def deserialize(value: Writable): AnyRef = { val pair = value.asInstanceOf[TestPair] - val row = new util.ArrayList[util.ArrayList[AnyRef]] - row.add(new util.ArrayList[AnyRef](2)) - row(0).add(Integer.valueOf(pair.entry._1)) - row(0).add(Integer.valueOf(pair.entry._2)) + val row = new ArrayList[ArrayList[AnyRef]] + row.add(new ArrayList[AnyRef](2)) + row.get(0).add(Integer.valueOf(pair.entry._1)) + row.get(0).add(Integer.valueOf(pair.entry._2)) row } } -class PairUdf extends GenericUDF { +class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { - println("Type = %s".format(args(0).getClass.getName)) Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index de6a41ce5bfc..210d56674541 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.TestHive -/* Implicit conversions */ -import scala.collection.JavaConversions._ 
- /** * A set of test cases that validate partition and column pruning. */ @@ -82,16 +81,16 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { Seq.empty) createPruningTest("Column pruning - non-trivial top project with aliases", - "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", - Seq("double"), + "SELECT c1 * 2 AS dbl FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("dbl"), Seq("key"), Seq.empty) // Partition pruning tests createPruningTest("Partition pruning - non-partitioned, non-trivial project", - "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL", - Seq("double"), + "SELECT key * 2 AS dbl FROM src WHERE value IS NOT NULL", + Seq("dbl"), Seq("key", "value"), Seq.empty) @@ -151,7 +150,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) val partValues = if (relation.table.isPartitioned) { - p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) + p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) } else { Seq.empty } @@ -161,7 +160,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch") assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") - val actualPartitions = actualPartValues.map(_.toSeq.mkString(",")).sorted + val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted assert(actualPartitions === expectedPartitions, "Partitions selected do not match") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index a2e666586c18..8126d0233521 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.sql.hive.execution +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -59,7 +63,29 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. 
*/ -class SQLQuerySuite extends QueryTest { +class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ + + test("UDTF") { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") @@ -133,6 +159,50 @@ class SQLQuerySuite extends QueryTest { (1 to 6).map(_ => Row("CA", 20151))) } + test("show functions") { + val allFunctions = + (FunctionRegistry.builtin.listFunction().toSet[String] ++ + org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted + checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) + checkAnswer(sql("SHOW functions abs"), Row("abs")) + checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) + checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `~`"), Row("~")) + checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) + checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + // this probably will failed if we add more function with `sha` prefixing. 
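The `SHOW FUNCTIONS` checks here, including the `sha.*` assertion that follows, rely on the quoted pattern being matched as a regular expression over every registered function name. A small self-contained illustration of that matching (plain Scala, no Spark required):

```scala
// Registered names trimmed down to the ones the assertions mention.
val registered = Seq("sha", "sha1", "sha2", "upper", "weekofyear")

def matching(pattern: String): Seq[String] =
  registered.filter(name => pattern.r.pattern.matcher(name).matches())

assert(matching("sha.*") == Seq("sha", "sha1", "sha2"))
assert(matching("weekofyea.*") == Seq("weekofyear"))
```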
+ checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + } + + test("describe functions") { + // The Spark SQL built-in functions + checkExistence(sql("describe function extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql')", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + + checkExistence(sql("describe functioN `~`"), true, + "Function: ~", + "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", + "Usage: ~ n - Bitwise not") + } + test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") @@ -157,6 +227,24 @@ class SQLQuerySuite extends QueryTest { checkAnswer(query, Row(1, 1) :: Nil) } + test("CTAS with WITH clause") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + sql( + """ + |CREATE TABLE with_table1 AS + |WITH T AS ( + | SELECT * + | FROM table1 + |) + |SELECT * + |FROM T + """.stripMargin) + val query = sql("SELECT * FROM with_table1") + checkAnswer(query, Row(1, 1) :: Nil) + } + test("explode nested Field") { Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") checkAnswer( @@ -175,17 +263,17 @@ class SQLQuerySuite extends QueryTest { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { - case LogicalRelation(r: ParquetRelation2) => + case LogicalRelation(r: ParquetRelation) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + - s"${ParquetRelation2.getClass.getCanonicalName}.") + s"${ParquetRelation.getClass.getCanonicalName}.") } case r: MetastoreRelation => if (isDataSourceParquet) { fail( - s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " + + s"${ParquetRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") } } @@ -195,47 +283,51 @@ class SQLQuerySuite extends QueryTest { setConf(HiveContext.CONVERT_CTAS, true) - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - var message = intercept[AnalysisException] { + try { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("ctas1 already exists")) - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - // Specifying database name for query can be converted to data source write path - // is not allowed right now. 
- message = intercept[AnalysisException] { - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert( - message.contains("Cannot specify database name in a CTAS statement"), - "When spark.sql.hive.convertCTAS is true, we should not allow " + - "database name specified.") - - sql("CREATE TABLE ctas1 stored as textfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql( - "CREATE TABLE ctas1 stored as sequencefile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + var message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("ctas1 already exists")) + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + // Specifying database name for query can be converted to data source write path + // is not allowed right now. + message = intercept[AnalysisException] { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert( + message.contains("Cannot specify database name in a CTAS statement"), + "When spark.sql.hive.convertCTAS is true, we should not allow " + + "database name specified.") + + sql("CREATE TABLE ctas1 stored as textfile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE IF EXISTS ctas1") + } } test("SQL Dialect Switching") { @@ -330,16 +422,14 @@ class SQLQuerySuite extends QueryTest { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { checkExistence(sql("DESC EXTENDED ctas5"), true, "name:key", 
"type:string", "name:value", "ctas5", "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", @@ -347,16 +437,13 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", "MANAGED_TABLE" ) + } - val default = convertMetastoreParquet - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") + // use the Hive SerDe for parquet tables + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( sql("SELECT key, value FROM ctas5 ORDER BY key, value"), sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource) } } @@ -420,19 +507,19 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) - checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), - Seq.empty[Row]) + + sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested") checkAnswer( sql("SELECT * FROM test_ctas_1234"), sql("SELECT * FROM nested").collect().toSeq) intercept[AnalysisException] { - sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() + sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect() } } test("test CTAS") { - checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src") checkAnswer( sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) @@ -525,7 +612,7 @@ class SQLQuerySuite extends QueryTest { val rowRdd = sparkContext.parallelize(row :: Nil) - TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") + hiveContext.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -573,7 +660,7 @@ class SQLQuerySuite extends QueryTest { test("resolve udtf in projection #2") { val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) intercept[AnalysisException] { @@ -588,7 +675,7 @@ class SQLQuerySuite extends QueryTest { // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive test("TGF with non-TGF in projection") { val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer( sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), Row("1", "1", "1", "1") :: Nil) @@ -606,22 +693,25 @@ class SQLQuerySuite extends QueryTest { val originalConf = convertCTAS setConf(HiveContext.CONVERT_CTAS, false) - sql("CREATE TABLE explodeTest (key bigInt)") - table("explodeTest").queryExecution.analyzed match { - case metastoreRelation: MetastoreRelation => // OK - case _ => - fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") - } + try { + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { + case metastoreRelation: MetastoreRelation => // OK + case _ => + fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") + } - 
sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") - checkAnswer( - sql("SELECT key from explodeTest"), - (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) - ) + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + checkAnswer( + sql("SELECT key from explodeTest"), + (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) + ) - sql("DROP TABLE explodeTest") - dropTempTable("data") - setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE explodeTest") + dropTempTable("data") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + } } test("sanity test for SPARK-6618") { @@ -638,7 +728,7 @@ class SQLQuerySuite extends QueryTest { test("SPARK-5203 union with different decimal precision") { Seq.empty[(Decimal, Decimal)] .toDF("d1", "d2") - .select($"d1".cast(DecimalType(10, 15)).as("d")) + .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") sql("select d from dn union all select d * 2 from dn") @@ -653,7 +743,7 @@ class SQLQuerySuite extends QueryTest { .queryExecution.toRdd.count()) } - ignore("test script transform for stderr") { + test("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === @@ -661,6 +751,16 @@ class SQLQuerySuite extends QueryTest { .queryExecution.toRdd.count()) } + test("test script transform data type") { + val data = (1 to 5).map { i => (i, i) } + data.toDF("key", "value").registerTempTable("test") + checkAnswer( + sql("""FROM + |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t + |SELECT thing1 + 1 + """.stripMargin), (2 to 6).map(i => Row(i))) + } + test("window function: udaf with aggregate expressin") { val data = Seq( WindowData(1, "a", 5), @@ -848,6 +948,8 @@ class SQLQuerySuite extends QueryTest { } test("SPARK-7595: Window will cause resolve failed with self join") { + sql("SELECT * FROM src") // Force loading of src table. + checkAnswer(sql( """ |with @@ -940,10 +1042,10 @@ class SQLQuerySuite extends QueryTest { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. 
- TestHive.sql( - s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + sql( + s"ADD JAR ${hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") try { - TestHive.sql( + sql( """ |CREATE TEMPORARY FUNCTION example_max |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' @@ -962,4 +1064,110 @@ class SQLQuerySuite extends QueryTest { case None => // OK } } + + test("SPARK-6785: HiveQuerySuite - Date comparison test 2") { + checkAnswer( + sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), + Row(false)) + } + + test("SPARK-6785: HiveQuerySuite - Date cast") { + // new Date(0) == 1970-01-01 00:00:00.0 GMT == 1969-12-31 16:00:00.0 PST + checkAnswer( + sql( + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 + """.stripMargin), + Row( + Date.valueOf("1969-12-31"), + String.valueOf("1969-12-31"), + Timestamp.valueOf("1969-12-31 16:00:00"), + String.valueOf("1969-12-31 16:00:00"), + Timestamp.valueOf("1970-01-01 00:00:00"))) + + } + + test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { + val df = + createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + df.toDF("id", "datef").registerTempTable("test_SPARK8588") + checkAnswer( + sql( + """ + |select id, concat(year(datef)) + |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') + """.stripMargin), + Row(1, "2014") :: Row(2, "2015") :: Nil + ) + dropTempTable("test_SPARK8588") + } + + test("SPARK-9371: fix the support for special chars in column names for hive context") { + read.json(sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } + + test("Convert hive interval term into Literal of CalendarIntervalType") { + checkAnswer(sql("select interval '10-9' year to month"), + Row(CalendarInterval.fromString("interval 10 years 9 months"))) + checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), + Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + + "32 seconds 99 milliseconds 899 microseconds"))) + checkAnswer(sql("select interval '30' year"), + Row(CalendarInterval.fromString("interval 30 years"))) + checkAnswer(sql("select interval '25' month"), + Row(CalendarInterval.fromString("interval 25 months"))) + checkAnswer(sql("select interval '-100' day"), + Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) + checkAnswer(sql("select interval '40' hour"), + Row(CalendarInterval.fromString("interval 1 days 16 hours"))) + checkAnswer(sql("select interval '80' minute"), + Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) + checkAnswer(sql("select interval '299.889987299' second"), + Row(CalendarInterval.fromString( + "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) + } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + 
s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala new file mode 100644 index 000000000000..cb8d0fca8e69 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.types.StringType + +class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { + import hiveContext.implicits._ + + private val noSerdeIOSchema = HiveScriptIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + schemaLess = false + ) + + private val serdeIOSchema = noSerdeIOSchema.copy( + inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), + outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) + ) + + test("cat without SerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = noSerdeIOSchema + )(hiveContext), + rowsDf.collect()) + } + + test("cat with LazySimpleSerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = serdeIOSchema + )(hiveContext), + rowsDf.collect()) + } + + test("script 
transformation should not swallow errors from upstream operators (no serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = noSerdeIOSchema + )(hiveContext), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } + + test("script transformation should not swallow errors from upstream operators (with serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = serdeIOSchema + )(hiveContext), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } +} + +private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. + Thread.sleep(1000) // This sleep gives the external process time to start. + throw new IllegalArgumentException("intentional exception") + } + } + override def output: Seq[Attribute] = child.output +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 080af5bb23c1..92043d66c914 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,10 +24,17 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { + import testImplicits._ + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName - import sqlContext._ - import sqlContext.implicits._ + // ORC does not play well with NullType and UDT. 
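The `supportsDataType` override that follows prunes NullType, CalendarIntervalType, and UDTs out of the type matrix the shared `HadoopFsRelationTest` exercises. A minimal sketch of how such a predicate filters candidate types (the filtering helper itself is hypothetical; the predicate mirrors the override below):

```scala
import org.apache.spark.sql.types._

// Hypothetical filtering step; the predicate matches the ORC override below.
val orcSupports: DataType => Boolean = {
  case _: NullType => false
  case _: CalendarIntervalType => false
  case _: UserDefinedType[_] => false
  case _ => true
}

val candidates: Seq[DataType] = Seq(IntegerType, StringType, NullType, CalendarIntervalType)
assert(candidates.filter(orcSupports) == Seq(IntegerType, StringType))
```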
+ override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _: UserDefinedType[_] => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -41,19 +48,16 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") .write - .format("orc") - .save(partitionDir.toString) + .orc(partitionDir.toString) } val dataSchemaWithPartition = StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + hiveContext.read.options(Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 8707f9f936be..52e09f9496f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -18,19 +18,16 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.InternalRow -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.util.Utils -import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.scalatest.BeforeAndAfterAll +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. 
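The partition-discovery tests introduced below write ORC files under Hive-style `pi=.../ps=...` directories and then read the base path back. A sketch of that round trip, assuming a `hiveContext` and a `basePath` laid out by helpers like the suite's `makeOrcFile`/`makePartitionDir`:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.hive.HiveContext

// `basePath` is assumed to look like:
//   <basePath>/pi=1/ps=foo/part-r-00000.orc
//   <basePath>/pi=2/ps=bar/part-r-00000.orc
// Reading the base path recovers `pi` and `ps` as partition columns even
// though they are never stored inside the ORC files themselves.
def readPartitioned(hiveContext: HiveContext, basePath: String): DataFrame = {
  val df = hiveContext.read.orc(basePath)
  df.printSchema()                                   // intField, stringField, pi, ps
  df.filter(df("pi") === 1).select("stringField", "ps")
}
```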
case class OrcParData(intField: Int, stringField: String) @@ -39,8 +36,11 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { - val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal +class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ + + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile @@ -49,17 +49,17 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath) + data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) } def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.write.format("orc").mode("overwrite").save(path.getCanonicalPath) + df.write.mode("overwrite").orc(path.getCanonicalPath) } protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally TestHive.dropTempTable(tableName) + try f finally hiveContext.dropTempTable(tableName) } protected def makePartitionDir( @@ -90,7 +90,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -137,7 +137,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -187,9 +187,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } read - .format("orc") .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) - .load(base.getCanonicalPath) + .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { @@ -230,9 +229,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } read - .format("orc") .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) - .load(base.getCanonicalPath) + .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 267d22c6b5f1..8bc33fcf5d90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -23,10 +23,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.InternalRow -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import 
org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -66,14 +63,14 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - sqlContext.read.format("orc").load(file), + sqlContext.read.orc(file), data.toDF().collect()) } } test("Read/write binary data") { withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => - val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0) + val bytes = read.orc(file).head().getAs[Array[Byte]](0) assert(new String(bytes, "utf8") === "test") } } @@ -91,7 +88,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + read.orc(file), data.toDF().collect()) } } @@ -161,7 +158,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + read.orc(file), Row(Seq.fill(5)(null): _*)) } } @@ -170,7 +167,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Default compression options for writing to an ORC file") { withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file => assertResult(CompressionKind.ZLIB) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } } @@ -183,21 +180,21 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") withOrcFile(data) { file => assertResult(CompressionKind.SNAPPY) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE") withOrcFile(data) { file => assertResult(CompressionKind.NONE) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO") withOrcFile(data) { file => assertResult(CompressionKind.LZO) { - OrcFileOperator.getFileReader(file).getCompression + OrcFileOperator.getFileReader(file).get.getCompression } } } @@ -289,4 +286,62 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { List(Row("same", "run_5", 100))) } } + + test("SPARK-9170: Don't implicitly lowercase of user-provided columns") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(0, 10).select('id as "Acol").write.format("orc").save(path) + sqlContext.read.format("orc").load(path).schema("Acol") + intercept[IllegalArgumentException] { + sqlContext.read.format("orc").load(path).schema("acol") + } + checkAnswer(sqlContext.read.format("orc").load(path).select("acol").sort("acol"), + (0 until 10).map(Row(_))) + } + } + + test("SPARK-8501: Avoids discovery schema from empty ORC files") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempTable("empty", "single") { + sqlContext.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '$path' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.registerTempTable("empty") + + // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. 
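The `.get` calls added after `OrcFileOperator.getFileReader(file)` above reflect that the reader lookup now returns an `Option`, which is presumably how a location containing only schema-less (empty) ORC files gets reported cleanly; this SPARK-8501 test goes on to assert the user-facing `AnalysisException`. A sketch of the calling pattern, written as if alongside the suite so `OrcFileOperator` is in scope:

```scala
// Callers now unwrap the Option instead of dereferencing a reader directly.
def compressionOf(path: String): String =
  OrcFileOperator.getFileReader(path) match {
    case Some(reader) => reader.getCompression.toString
    case None =>
      // No readable ORC footer under this path; the test above checks the
      // "Failed to discover schema from ORC files" message users see instead.
      sys.error(s"no ORC footer found under $path")
  }
```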
+ sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val errorMessage = intercept[AnalysisException] { + sqlContext.read.orc(path) + }.getMessage + + assert(errorMessage.contains("Failed to discover schema from ORC files")) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.registerTempTable("single") + + sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = sqlContext.read.orc(path) + assert(df.schema === singleRowDF.schema.asNullable) + checkAnswer(df, singleRowDF) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 82e08caf4645..7a34cf731b4c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -21,12 +21,14 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { +abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + var orcTableDir: File = null var orcTableAsDir: File = null @@ -121,13 +123,42 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM normal_orc_as_source"), (6 to 10).map(i => Row(i, s"part-$i"))) } + + test("write null values") { + sql("DROP TABLE IF EXISTS orcNullValues") + + val df = sql( + """ + |SELECT + | CAST(null as TINYINT), + | CAST(null as SMALLINT), + | CAST(null as INT), + | CAST(null as BIGINT), + | CAST(null as FLOAT), + | CAST(null as DOUBLE), + | CAST(null as DECIMAL(7,2)), + | CAST(null as TIMESTAMP), + | CAST(null as DATE), + | CAST(null as STRING), + | CAST(null as VARCHAR(10)) + |FROM orc_temp_table limit 1 + """.stripMargin) + + df.write.format("orc").saveAsTable("orcNullValues") + + checkAnswer( + sql("SELECT * FROM orcNullValues"), + Row.fromSeq(Seq.fill(11)(null))) + + sql("DROP TABLE IF EXISTS orcNullValues") + } } class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( @@ -135,7 +166,7 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_as_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 5daf691aa8c5..88a0ed511749 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,14 +22,12 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql._ +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton -private[sql] trait OrcTest extends SQLTestUtils { - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - - import 
sqlContext.sparkContext - import sqlContext.implicits._ +private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { + import testImplicits._ /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` @@ -39,7 +37,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -51,7 +49,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path))) + withOrcFile(data)(path => f(sqlContext.read.orc(path))) } /** @@ -70,11 +68,11 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + data.toDF().write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c2e09800933b..6842ec2b5eb3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -19,16 +19,14 @@ package org.apache.spark.sql.hive import java.io.File -import org.scalatest.BeforeAndAfterAll - +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} -import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SQLConf, SaveMode} import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. @@ -55,10 +53,19 @@ case class ParquetDataWithKeyAndComplexTypes( * A suite to test the automatic conversion of metastore tables with parquet data to use the * built in parquet support. 
*/ -class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { +class ParquetMetastoreSuite extends ParquetPartitioningTest { + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() - + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") sql(s""" create external table partitioned_parquet ( @@ -132,6 +139,19 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' """) + sql( + """ + |create table test_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + (1 to 10).foreach { p => sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") } @@ -157,13 +177,14 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { } override def afterAll(): Unit = { - sql("DROP TABLE partitioned_parquet") - sql("DROP TABLE partitioned_parquet_with_key") - sql("DROP TABLE partitioned_parquet_with_complextypes") - sql("DROP TABLE partitioned_parquet_with_key_and_complextypes") - sql("DROP TABLE normal_parquet") - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS jt_array") + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } @@ -174,40 +195,9 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { }.isEmpty) assert( sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: ParquetTableScan => true case _: PhysicalRDD => true }.nonEmpty) } -} - -class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - - sql( - """ - |create table test_parquet - |( - | intField INT, - | stringField STRING - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override def afterAll(): Unit = { - super.afterAll() - sql("DROP TABLE IF EXISTS test_parquet") - - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("scan an empty parquet table") { checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) @@ -218,6 +208,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { } test("insert into an empty parquet table") { + dropTables("test_insert_parquet") sql( """ |create table test_insert_parquet @@ -243,7 +234,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") // Create it again. 
sql( @@ -270,166 +261,166 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql(s"SELECT intField, stringField FROM test_insert_parquet"), (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") } test("scan a parquet table created through a CTAS statement") { - sql( - """ - |create table test_parquet_ctas ROW FORMAT - |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - |AS select * from jt - """.stripMargin) + withTable("test_parquet_ctas") { + sql( + """ + |create table test_parquet_ctas ROW FORMAT + |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |AS select * from jt + """.stripMargin) - checkAnswer( - sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), - Seq(Row(1, "str1")) - ) + checkAnswer( + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + Seq(Row(1, "str1")) + ) - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation2) => // OK - case _ => fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(_: ParquetRelation) => // OK + case _ => fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation].getCanonicalName }") + } } - - sql("DROP TABLE IF EXISTS test_parquet_ctas") } test("MetastoreRelation in InsertIntoTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | intField INT - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. 
" + - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) } - - checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("MetastoreRelation in InsertIntoHiveTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | int_array array - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." 
+ - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) } - - checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("SPARK-6450 regression test") { - sql( - """CREATE TABLE IF NOT EXISTS ms_convert (key INT) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("ms_convert") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed - // This shouldn't throw AnalysisException - val analyzed = sql( - """SELECT key FROM ms_convert - |UNION ALL - |SELECT key FROM ms_convert - """.stripMargin).queryExecution.analyzed - - assertResult(2) { - analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation2) => r - }.size + assertResult(2) { + analyzed.collect { + case r@LogicalRelation(_: ParquetRelation) => r + }.size + } } - - sql("DROP TABLE ms_convert") } - def collectParquetRelation(df: DataFrame): ParquetRelation2 = { + def collectParquetRelation(df: DataFrame): ParquetRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: ParquetRelation2) => r + case LogicalRelation(r: ParquetRelation) => r }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$plan") } } test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE nonPartitioned ( - | key INT, - | value STRING - |) - |STORED AS PARQUET - """.stripMargin) - - // First lookup fills the cache - val r1 = collectParquetRelation(table("nonPartitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("nonPartitioned")) - // They should be the same instance - assert(r1 eq r2) - - sql("DROP TABLE nonPartitioned") + withTable("nonPartitioned") { + sql( + s"""CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("nonPartitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("nonPartitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE partitioned ( - | key INT, - | value STRING - |) - |PARTITIONED BY (part INT) - |STORED AS PARQUET + withTable("partitioned") { + sql( + s"""CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET """.stripMargin) - // First lookup fills the cache - val r1 = collectParquetRelation(table("partitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("partitioned")) - // They should be the same 
instance - assert(r1 eq r2) - - sql("DROP TABLE partitioned") + // First lookup fills the cache + val r1 = collectParquetRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("Caching converted data source Parquet Relations") { @@ -437,7 +428,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + @@ -445,8 +436,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { } } - sql("DROP TABLE IF EXISTS test_insert_parquet") - sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") sql( """ @@ -494,7 +484,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | intField INT, | stringField STRING |) - |PARTITIONED BY (date string) + |PARTITIONED BY (`date` string) |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' |STORED AS | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' @@ -506,7 +496,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-01') + |PARTITION (`date`='2015-04-01') |select a, b from jt """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. @@ -515,7 +505,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-02') + |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) @@ -525,7 +515,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { checkCached(tableIdentifier) // Make sure we can read the data. 
checkAnswer( - sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), sql( """ |select b, '2015-04-01', a FROM jt @@ -536,88 +526,24 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { invalidateTable("test_parquet_partitioned_cache_test") assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - sql("DROP TABLE test_insert_parquet") - sql("DROP TABLE test_parquet_partitioned_cache_test") - } -} - -class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } - - test("MetastoreRelation in InsertIntoTable will not be converted") { - sql( - """ - |create table test_insert_parquet - |( - | intField INT - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case insert: execution.InsertIntoHiveTable => // OK - case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + - s"However, found ${o.toString}.") - } - - checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") - } - - // TODO: enable it after the fix of SPARK-5950. - ignore("MetastoreRelation in InsertIntoHiveTable will not be converted") { - sql( - """ - |create table test_insert_parquet - |( - | int_array array - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case insert: execution.InsertIntoHiveTable => // OK - case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + - s"However, found ${o.toString}.") - } - - checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } } /** * A suite of tests for the Parquet support through the data sources API. 
*/ -class ParquetSourceSuiteBase extends ParquetPartitioningTest { +class ParquetSourceSuite extends ParquetPartitioningTest { + import testImplicits._ + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet") sql( s""" create temporary table partitioned_parquet @@ -685,19 +611,30 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { sql("drop table spark_6016_fix") } -} -class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { - val originalConf = conf.parquetUseDataSourceApi + test("SPARK-8811: compatibility with array of struct in Hive") { + withTempPath { dir => + val path = dir.getCanonicalPath - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } + withTable("array_of_struct") { + val conf = Seq( + HiveContext.CONVERT_METASTORE_PARQUET.key -> "false", + SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) + withSQLConf(conf: _*) { + sql( + s"""CREATE TABLE array_of_struct + |STORED AS PARQUET LOCATION '$path' + |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) + """.stripMargin) + + checkAnswer( + sqlContext.read.parquet(path), + Row("1st", "2nd", Seq(Row("val_a", "val_b")))) + } + } + } } test("values in arrays and maps stored in parquet are always nullable") { @@ -707,25 +644,25 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val expectedSchema1 = StructType( StructField("m", mapType1, nullable = true) :: - StructField("a", arrayType1, nullable = true) :: Nil) + StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) - df.write.format("parquet").saveAsTable("alwaysNullable") - - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val arrayType2 = ArrayType(IntegerType, containsNull = true) - val expectedSchema2 = - StructType( - StructField("m", mapType2, nullable = true) :: - StructField("a", arrayType2, nullable = true) :: Nil) + withTable("alwaysNullable") { + df.write.format("parquet").saveAsTable("alwaysNullable") - assert(table("alwaysNullable").schema === expectedSchema2) + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = true) :: Nil) - checkAnswer( - sql("SELECT m, a FROM alwaysNullable"), - Row(Map(2 -> 3), Seq(4, 5, 6))) + assert(table("alwaysNullable").schema === expectedSchema2) - sql("DROP TABLE alwaysNullable") + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + } } test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { @@ -745,24 +682,12 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { } } -class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override def afterAll(): Unit = { - super.afterAll() - 
setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} - /** * A collection of tests for parquet data with various forms of partitioning. */ -abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + var partitionedTableDir: File = null var normalTableDir: File = null var partitionedTableDirWithKey: File = null @@ -825,6 +750,16 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll partitionedTableDirWithKeyAndComplexTypes.delete() } + /** + * Drop named tables if they exist + * @param tableNames tables to drop + */ + def dropTables(tableNames: String*): Unit = { + tableNames.foreach { name => + sql(s"DROP TABLE IF EXISTS $name") + } + } + Seq( "partitioned_parquet", "partitioned_parquet_with_key", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala new file mode 100644 index 000000000000..dc0531a6d4bc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + + +class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton { + + // When committing a task, `CommitFailureTestSource` throws an exception for testing purposes. + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + // Here we coalesce the partition number to 1 to ensure that only a single task is issued. This + // prevents a race condition that happens when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details.
+ val df = sqlContext.range(0, 10).coalesce(1) + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..ef37787137d0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.math.BigDecimal + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "json" + + // JSON does not write data of NullType and does not play well with BinaryType. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: BinaryType => false + case _: CalendarIntervalType => false + case _ => true + } + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + hiveContext.read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-9894: save complex types to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("array", ArrayType(LongType)) + .add("map", MapType(StringType, new StructType().add("innerField", LongType))) + + val data = + Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: + Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. 
+ checkAnswer( + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } + + test("SPARK-10196: save decimal type to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("decimal", DecimalType(7, 2)) + + val data = + Row(new BigDecimal("10.02")) :: + Row(new BigDecimal("20000.99")) :: + Row(new BigDecimal("10000")) :: Nil + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..e2d754e80640 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import com.google.common.io.Files +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{execution, AnalysisException, SaveMode} +import org.apache.spark.sql.types._ + + +class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { + import testImplicits._ + + override val dataSourceName: String = "parquet" + + // Parquet does not play well with NullType. 
+ override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _ => true + } + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write.parquet(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + hiveContext.read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(hiveContext.read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[AnalysisException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(hiveContext.read.format("parquet").load(path), df) + } + } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. 
+ hiveContext.range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(hadoopConfiguration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } + + test("SPARK-10334 Projections and filters should be kept in physical plan") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) + val df = sqlContext.read.parquet(path).filter('a === 0).select('b) + val physicalPlan = df.queryExecution.executedPlan + + assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) + assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..a3a124488d98 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.types._ + +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + + // We have a very limited number of supported types here since this is just a + // test relation and we only do very basic testing here. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: BinaryType => false + // We are using a random data generator and the generated strings are not really valid strings.
+ case _: StringType => false + case _: BooleanType => false // see https://issues.apache.org/jira/browse/SPARK-10442 + case _: CalendarIntervalType => false + case _: DateType => false + case _: TimestampType => false + case _: ArrayType => false + case _: MapType => false + case _: StructType => false + case _: UserDefinedType[_] => false + case _ => true + } + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + hiveContext.read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 0f959b3d0b86..aeaaa3e1c522 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputForma import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types.{DataType, StructType} @@ -53,9 +54,12 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-${UUID.randomUUID()}") + new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } } @@ -64,7 +68,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) override def write(row: Row): Unit = { - val serialized = row.toSeq.map(_.toString).mkString(",") + val serialized = row.toSeq.map { v => + if (v == null) "" else v.toString + }.mkString(",") recordWriter.write(null, new Text(serialized)) } @@ -108,7 +114,8 @@ class SimpleTextRelation( val fields = dataSchema.map(_.dataType) sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => - Row(record.split(",").zip(fields).map { case (value, dataType) => + Row(record.split(",", -1).zip(fields).map { case (v, dataType) => + val value = if (v == "") null else v // `Cast`ed values are always of Catalyst 
types (i.e. UTF8String instead of String, etc.) val catalystValue = Cast(Literal(value), dataType).eval() // Here we're converting Catalyst values to Scala values to test `needsConversion` @@ -118,6 +125,8 @@ class SimpleTextRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) + override def newInstance( path: String, dataSchema: StructType, @@ -156,6 +165,7 @@ class CommitFailureTestRelation( context: TaskAttemptContext): OutputWriter = { new SimpleTextOutputWriter(path, context) { override def close(): Unit = { + super.close() sys.error("Intentional task commitment failure for testing purpose.") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 76469d7a3d6a..8ffcef85668d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,25 +17,28 @@ package org.apache.spark.sql.sources -import java.io.File +import scala.collection.JavaConverters._ -import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.ParquetOutputCommitter -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override lazy val sqlContext: SQLContext = TestHive - import sqlContext.sql +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton { import sqlContext.implicits._ - val dataSourceName = classOf[SimpleTextSource].getCanonicalName + val dataSourceName: String + + protected def supportsDataType(dataType: DataType): Boolean = true val dataSchema = StructType( @@ -97,6 +100,83 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } + ignore("test all data types") { + withTempPath { file => + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + // TODO: add CalendarIntervalType to here once we can save it out. + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. 
+ val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schema, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schema")) + val data = (1 to 10).map { i => + dataGenerator.apply() match { + case row: Row => row + case null => Row.fromSeq(Seq.fill(schema.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 10) + val df = sqlContext.createDataFrame(rdd, schema) + + // All columns that have supported data types of this source. + val supportedColumns = schema.fields.collect { + case StructField(name, dataType, _, _) if supportsDataType(dataType) => name + } + val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) + + val dfToBeSaved = df.selectExpr(selectedColumns: _*) + + // Save the data out. + dfToBeSaved + .write + .format(dataSourceName) + .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. + .save(file.getCanonicalPath) + + val loadedDF = + sqlContext + .read + .format(dataSourceName) + .schema(dfToBeSaved.schema) + .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. + .load(file.getCanonicalPath) + .selectExpr(selectedColumns: _*) + + // Read the data back. + checkAnswer( + loadedDF, + dfToBeSaved + ) + } + } + test("save()/load() - non-partitioned table - Overwrite") { withTempPath { file => testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) @@ -126,7 +206,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { test("save()/load() - non-partitioned table - ErrorIfExists") { withTempDir { file => - intercept[RuntimeException] { + intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) } } @@ -225,7 +305,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { test("save()/load() - partitioned table - ErrorIfExists") { withTempDir { file => - intercept[RuntimeException] { + intercept[AnalysisException] { partitionedTestDF.write .format(dataSourceName) .mode(SaveMode.ErrorIfExists) @@ -295,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } + test("saveAsTable()/load() - partitioned table - boolean type") { + sqlContext.range(2) + .select('id, ('id % 2 === 0).as("b")) + .write.partitionBy("b").saveAsTable("t") + + withTable("t") { + checkAnswer( + sqlContext.table("t").sort('id), + Row(0, true) :: Row(1, false) :: Nil + ) + } + } + test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.write .format(dataSourceName) @@ -440,7 +533,9 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } - test("Partition column type casting") { + // HadoopFsRelation.discoverPartitions() called by refresh(), which will ignore + // the given partition data type. 
+ ignore("Partition column type casting") { withTempPath { file => val input = partitionedTestDF.select('a, 'b, 'p1.cast(StringType).as('ps), 'p2) @@ -470,143 +565,155 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) } } -} - -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - import sqlContext._ + // NOTE: This test suite is not super deterministic. On nodes with only relatively few cores + // (4 or even 1), it's hard to reproduce the data loss issue. But on nodes with for example 8 or + // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this + // requirement. We probably want to move this test case to spark-integration-tests or spark-perf + // later. + test("SPARK-8406: Avoids name collision while writing files") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext + .range(10000) + .repartition(250) + .write + .mode(SaveMode.Overwrite) + .format(dataSourceName) + .save(path) - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") - .saveAsTextFile(partitionDir.toString) + assertResult(10000) { + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) + .load(path) + .count() } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) } } -} -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - import TestHive.implicits._ - - override val sqlContext = TestHive - - val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName - - test("SPARK-7684: commitTask() failure should fallback to abortTask()") { - withTempPath { file => - val df = (1 to 3).map(i => i -> s"val_$i").toDF("a", "b") - intercept[SparkException] { - df.write.format(dataSourceName).save(file.getCanonicalPath) + test("SPARK-8578 specified custom output committer will not be used to append data") { + val clonedConf = new Configuration(hadoopConfiguration) + try { + val df = sqlContext.range(1, 10).toDF("i") + withTempPath { dir => + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + hadoopConfiguration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there data already exists, + // this append should succeed because we will use the output committer associated + // with file format and AlwaysFailOutputCommitter will not be used. 
+ df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext.read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df.unionAll(df)) + + // This will fail because AlwaysFailOutputCommitter is used when we do append. + intercept[Exception] { + df.write.mode("overwrite").format(dataSourceName).save(dir.getCanonicalPath) + } } - - val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) - assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + withTempPath { dir => + hadoopConfiguration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there is no existing data, + // this append will fail because AlwaysFailOutputCommitter is used when we do append + // and there is no existing data. + intercept[Exception] { + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + } + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } -} -class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName - - import sqlContext._ - import sqlContext.implicits._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { + test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { + val df = Seq( + (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), + (2, "v2", Array(4, 5, 6), Map("k2" -> "v2"), Tuple2(2, "5")), + (3, "v3", Array(7, 8, 9), Map("k3" -> "v3"), Tuple2(3, "6"))).toDF("a", "b", "c", "d", "e") withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) - .toDF("a", "b", "p1") - .write.parquet(partitionDir.toString) + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").save(file.getCanonicalPath) } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) + } + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") } } - test("SPARK-7868: _temporary directories should be ignored") { - withTempPath { dir => - val df = Seq("a", "b", "c").zipWithIndex.toDF() - - df.write - .format("parquet") - .save(dir.getCanonicalPath) - - df.write - .format("parquet") - .save(s"${dir.getCanonicalPath}/_temporary") - - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + test("SPARK-9899 Disable customized output committer when speculation is on") { + val clonedConf = new Configuration(hadoopConfiguration) + val speculationEnabled = + 
sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + + try { + withTempPath { dir => + // Enables task speculation + sqlContext.sparkContext.conf.set("spark.speculation", "true") + + // Uses a customized output committer which always fails + hadoopConfiguration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + + // Code below shouldn't throw since customized output committer should be disabled. + val df = sqlContext.range(10).coalesce(1) + df.write.format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } +} - test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { - withTempDir { dir => - val path = dir.getCanonicalPath - val df = Seq(1 -> "a").toDF() - - // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw - // since it's not a valid Parquet file. - val emptyFile = new File(path, "empty") - Files.createParentDirs(emptyFile) - Files.touch(emptyFile) - - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Ignore).save(path) - - // This should only complain that the destination directory already exists, rather than file - // "empty" is not a Parquet file. - assert { - intercept[RuntimeException] { - df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - }.getMessage.contains("already exists") - } +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends FileOutputCommitter(outputPath, context) { - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) - } + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") } +} - test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { - withTempPath { dir => - intercept[AnalysisException] { - // Parquet doesn't allow field names with spaces. Here we are intentionally making an - // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger - // the bug. Please refer to spark-8079 for more details. - range(1, 10) - .withColumnRenamed("id", "a b") - .write - .format("parquet") - .save(dir.getCanonicalPath) - } - } +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. 
+class AlwaysFailParquetOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") } } diff --git a/streaming/pom.xml b/streaming/pom.xml index 697895e72fe5..5cc9001b0e9a 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java b/streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java similarity index 100% rename from streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java rename to streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 8c0fdfa9c747..3738fc1a235c 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -21,6 +21,8 @@ import java.util.Iterator; /** + * :: DeveloperApi :: + * * This abstract class represents a write ahead log (aka journal) that is used by Spark Streaming * to save the received data (by receivers) and associated metadata to a reliable storage, so that * they can be recovered after driver failures. See the Spark documentation for more information diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java index 02324189b782..662889e779fb 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java @@ -18,6 +18,8 @@ package org.apache.spark.streaming.util; /** + * :: DeveloperApi :: + * * This abstract class represents a handle that refers to a record written in a * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}. 
* It must contain all the information necessary for the record to be read and returned by diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index 75251f493ad2..4886b68eeaf7 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -31,6 +31,8 @@ var maxXForHistogram = 0; var histogramBinCount = 10; var yValueFormat = d3.format(",.2f"); +var unitLabelYOffset = -10; + // Show a tooltip "text" for "node" function showBootstrapTooltip(node, text) { $(node).tooltip({title: text, trigger: "manual", container: "body"}); @@ -133,7 +135,7 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("class", "y axis") .call(yAxis) .append("text") - .attr("transform", "translate(0," + (-3) + ")") + .attr("transform", "translate(0," + unitLabelYOffset + ")") .text(unitY); @@ -223,10 +225,10 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { .style("border-left", "0px solid white"); var margin = {top: 20, right: 30, bottom: 30, left: 10}; - var width = 300 - margin.left - margin.right; + var width = 350 - margin.left - margin.right; var height = 150 - margin.top - margin.bottom; - var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width]); + var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width - 50]); var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); var xAxis = d3.svg.axis().scale(x).orient("top").ticks(5); @@ -248,7 +250,7 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { .attr("class", "x axis") .call(xAxis) .append("text") - .attr("transform", "translate(" + (margin.left + width - 40) + ", 15)") + .attr("transform", "translate(" + (margin.left + width - 45) + ", " + unitLabelYOffset + ")") .text("#batches"); svg.append("g") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index d8dc4e410166..8a6050f5227b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -25,13 +25,14 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) +class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.appName @@ -44,11 +45,27 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val sparkConfPairs = ssc.conf.getAll def createSparkConf(): SparkConf = { + + // Reload properties for the checkpoint application since user wants to set a reload property + // or spark had changed its value and user wants to set it back. 
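[Editor's note] The `Checkpoint.createSparkConf()` change that continues below re-reads a small whitelist of environment-specific properties (driver host/port, YARN app and attempt ids, master, keytab/principal) from the current launch instead of the stale checkpointed values. The sketch below shows the recovery pattern this serves, using `StreamingContext.getOrCreate` exactly as it is declared later in this patch; the checkpoint directory is a hypothetical placeholder:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object CheckpointedApp {
  // Hypothetical checkpoint location; any reliable filesystem path works.
  val checkpointDir = "hdfs:///tmp/streaming-checkpoint"

  def createContext(): StreamingContext = {
    val conf = new SparkConf().setAppName("CheckpointedApp")
    val ssc = new StreamingContext(conf, Seconds(10))
    ssc.checkpoint(checkpointDir)
    // ... build DStreams here; they are serialized into the checkpoint ...
    ssc
  }

  def main(args: Array[String]): Unit = {
    // Cold start: createContext() runs. Restart: the checkpoint is deserialized and
    // createSparkConf() re-reads the reloaded properties from the new environment.
    val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
    ssc.start()
    ssc.awaitTermination()
  }
}
```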
+ val propertiesToReload = List( + "spark.yarn.app.id", + "spark.yarn.app.attemptId", + "spark.driver.host", + "spark.driver.port", + "spark.master", + "spark.yarn.keytab", + "spark.yarn.principal") + val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.port") - val newMasterOption = new SparkConf(loadDefaults = true).getOption("spark.master") - newMasterOption.foreach { newMaster => newSparkConf.setMaster(newMaster) } + val newReloadConf = new SparkConf(loadDefaults = true) + propertiesToReload.foreach { prop => + newReloadConf.getOption(prop).foreach { value => + newSparkConf.set(prop, value) + } + } newSparkConf } @@ -86,7 +103,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) - val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) + val fs = fsOption.getOrElse(path.getFileSystem(SparkHadoopUtil.get.conf)) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -177,7 +194,9 @@ class CheckpointWriter( + "'") // Write checkpoint to temp file - fs.delete(tempFile, true) // just in case it exists + if (fs.exists(tempFile)) { + fs.delete(tempFile, true) // just in case it exists + } val fos = fs.create(tempFile) Utils.tryWithSafeFinally { fos.write(bytes) @@ -188,7 +207,9 @@ class CheckpointWriter( // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) { - fs.delete(backupFile, true) // just in case it exists + if (fs.exists(backupFile)){ + fs.delete(backupFile, true) // just in case it exists + } if (!fs.rename(checkpointFile, backupFile)) { logWarning("Could not rename " + checkpointFile + " to " + backupFile) } @@ -267,6 +288,15 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None. + */ + def read(checkpointDir: String): Option[Checkpoint] = { + read(checkpointDir, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = true) + } + /** * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint * files, then return None, else try to return the latest valid checkpoint object. 
If no @@ -291,7 +321,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) - val compressionCodec = CompressionCodec.createCodec(conf) + var readError: Exception = null checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { @@ -302,13 +332,15 @@ object CheckpointReader extends Logging { return Some(cp) } catch { case e: Exception => + readError = e logWarning("Error reading checkpoint from file " + file, e) } }) // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + throw new SparkException( + s"Failed to read checkpoint from directory $checkpointPath", readError) } None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 1708f309fc00..b496d1f341a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.FixedLengthBinaryInputFormat import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.serializer.SerializationDebugger @@ -43,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -110,7 +111,7 @@ class StreamingContext private[streaming] ( * Recreate a StreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(path, new Configuration) + def this(path: String) = this(path, SparkHadoopUtil.get.conf) /** * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. @@ -192,11 +193,8 @@ class StreamingContext private[streaming] ( None } - /** Register streaming source to metrics system */ + /* Initializing a streamingSource to register metrics */ private val streamingSource = new StreamingSource(this) - assert(env != null) - assert(env.metricsSystem != null) - env.metricsSystem.registerSource(streamingSource) private var state: StreamingContextState = INITIALIZED @@ -204,6 +202,8 @@ class StreamingContext private[streaming] ( private var shutdownHookRef: AnyRef = _ + conf.getOption("spark.streaming.checkpoint.directory").foreach(checkpoint) + /** * Return the associated Spark context */ @@ -477,6 +477,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. 
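[Editor's note] The NOTE added to the `queueStream` overloads warns that queue-based streams cannot be checkpointed: the enqueued RDDs are arbitrary, locally created objects whose data cannot be re-materialized on recovery. A minimal usage sketch (demo/test style, not taken from the patch):

```scala
import scala.collection.mutable

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.StreamingContext

object QueueStreamDemo {
  def attach(ssc: StreamingContext): Unit = {
    val rddQueue = new mutable.Queue[RDD[Int]]()
    val stream = ssc.queueStream(rddQueue, oneAtATime = true)
    stream.map(_ * 2).print()

    // RDDs may be enqueued before or after ssc.start(); just don't enable checkpointing,
    // since none of these RDDs can be recovered after a driver failure.
    rddQueue += ssc.sparkContext.parallelize(1 to 100, 4)
  }
}
```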
+ * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -491,6 +495,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. @@ -596,8 +604,11 @@ class StreamingContext private[streaming] ( } StreamingContext.setActiveContext(this) } - shutdownHookRef = Utils.addShutdownHook( + shutdownHookRef = ShutdownHookManager.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + // Registering Streaming Metrics at the start of the StreamingContext + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) uiTab.foreach(_.attach()) logInfo("StreamingContext started") case ACTIVE => @@ -674,11 +685,13 @@ class StreamingContext private[streaming] ( logWarning("StreamingContext has already been stopped") case ACTIVE => scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) uiTab.foreach(_.detach()) StreamingContext.setActiveContext(null) waiter.notifyStop() if (shutdownHookRef != null) { - Utils.removeShutdownHook(shutdownHookRef) + ShutdownHookManager.removeShutdownHook(shutdownHookRef) } logInfo("StreamingContext stopped successfully") } @@ -712,7 +725,7 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() - private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val SHUTDOWN_HOOK_PRIORITY = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 private val activeContext = new AtomicReference[StreamingContext](null) @@ -791,7 +804,7 @@ object StreamingContext extends Logging { def getActiveOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { ACTIVATION_LOCK.synchronized { @@ -816,7 +829,7 @@ object StreamingContext extends Logging { def getOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { val checkpointOption = CheckpointReader.read( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 808dcc174cf9..edfa474677f1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -17,11 +17,10 @@ package org.apache.spark.streaming.api.java -import java.util import java.lang.{Long => JLong} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ 
-145,8 +144,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * an array. */ def glom(): JavaDStream[JList[T]] = - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - + new JavaDStream(dstream.glom().map(_.toSeq.asJava)) /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */ @@ -191,7 +189,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -204,7 +202,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -282,7 +280,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) */ def slice(fromTime: Time, toTime: Time): JList[R] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) + dstream.slice(fromTime, toTime).map(wrapRDD).asJava } /** @@ -291,7 +289,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction[R, Void]) { foreachRDD(foreachFunc) } @@ -302,7 +300,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction2[R, Time, Void]) { foreachRDD(foreachFunc) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 959ac9c177f8..e2aec6c2f63e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.api.java import java.lang.{Long => JLong, Iterable => JIterable} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -116,14 +116,14 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey().mapValues(asJavaIterable _) + dstream.groupByKey().mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. 
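[Editor's note] A recurring change across these Java API facades is the move from `scala.collection.JavaConversions._` (implicit conversions that can fire accidentally) to `scala.collection.JavaConverters._` (explicit `.asJava` / `.asScala` calls). A tiny Spark-independent illustration of the pattern:

```scala
import java.util.{List => JList}

import scala.collection.JavaConverters._

object ConvertersDemo {
  // Explicit, visible conversions at the Java/Scala boundary.
  def toJava(values: Seq[String]): JList[String] = values.asJava
  def toScala(values: JList[String]): Seq[String] = values.asScala.toSeq
}
```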
*/ def groupByKey(numPartitions: Int): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(numPartitions).mapValues(asJavaIterable _) + dstream.groupByKey(numPartitions).mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. @@ -132,7 +132,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(partitioner).mapValues(asJavaIterable _) + dstream.groupByKey(partitioner).mapValues(_.asJava) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are @@ -197,7 +197,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration).mapValues(_.asJava) } /** @@ -212,7 +212,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(_.asJava) } /** @@ -228,8 +228,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions).mapValues(_.asJava) } /** @@ -248,8 +247,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( slideDuration: Duration, partitioner: Partitioner ): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner).mapValues(_.asJava) } /** @@ -431,7 +429,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values + val list: JList[V] = values.asJava val scalaState: Optional[S] = JavaUtils.optionToOptional(state) val result: Optional[S] = in.apply(list, scalaState) result.isPresent match { @@ -539,7 +537,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream).mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -551,8 +549,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( numPartitions: Int ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, numPartitions) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, numPartitions).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -564,8 +561,7 @@ class JavaPairDStream[K, 
V](val dstream: DStream[(K, V)])( partitioner: Partitioner ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, partitioner) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, partitioner).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -788,7 +784,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[F], - conf: Configuration = new Configuration) { + conf: Configuration = dstream.context.sparkContext.hadoopConfiguration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 989e3a729ebc..13f371f29603 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} @@ -33,6 +33,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.api.java.function.{Function0 => JFunction0} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ @@ -114,7 +115,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, environment)) + this(new StreamingContext( + master, + appName, + batchDuration, + sparkHome, + jars, + environment.asScala)) /** * Create a JavaStreamingContext using an existing JavaSparkContext. @@ -136,7 +143,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Recreate a JavaStreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(new StreamingContext(path, new Configuration)) + def this(path: String) = this(new StreamingContext(path, SparkHadoopUtil.get.conf)) /** * Re-creates a JavaStreamingContext from a checkpoint file. @@ -196,7 +203,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).iterator().asScala implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -419,7 +426,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. 
In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @tparam T Type of objects in the RDD */ @@ -427,7 +438,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue) } @@ -435,7 +446,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -447,7 +462,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime) } @@ -455,7 +470,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty @@ -468,7 +487,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) } @@ -487,7 +506,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create a unified DStream from multiple DStreams of the same type and same slide duration. 
*/ def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = { - val dstreams: Seq[DStream[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[T]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[T] = first.classTag ssc.union(dstreams)(cm) } @@ -499,7 +518,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { first: JavaPairDStream[K, V], rest: JList[JavaPairDStream[K, V]] ): JavaPairDStream[K, V] = { - val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[(K, V)] = first.classTag implicit val kcm: ClassTag[K] = first.kManifest implicit val vcm: ClassTag[V] = first.vManifest @@ -521,12 +540,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ): JavaDStream[T] = { implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** @@ -546,12 +564,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index d06401245ff1..dfc569451df8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -20,14 +20,13 @@ package org.apache.spark.streaming.api.python import java.io.{ObjectInputStream, ObjectOutputStream} import java.lang.reflect.Proxy import java.util.{ArrayList => JArrayList, List => JList} -import scala.collection.JavaConversions._ + import scala.collection.JavaConverters._ import scala.language.existentials import py4j.GatewayServer import org.apache.spark.api.java._ -import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Interval, Duration, Time} @@ -161,7 +160,7 @@ private[python] object PythonDStream { */ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] - rdds.forall(queue.add(_)) + rdds.asScala.foreach(queue.add) queue } } @@ -171,7 +170,7 @@ private[python] object PythonDStream { */ private[python] abstract class PythonDStream( parent: DStream[_], - @transient pfunc: 
PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -188,7 +187,7 @@ private[python] abstract class PythonDStream( */ private[python] class PythonTransformedDStream ( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends PythonDStream(parent, pfunc) { override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { @@ -207,7 +206,7 @@ private[python] class PythonTransformedDStream ( private[python] class PythonTransformed2DStream( parent: DStream[_], parent2: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -231,7 +230,7 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - @transient reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction) extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) @@ -253,8 +252,8 @@ private[python] class PythonStateDStream( */ private[python] class PythonReducedWindowedDStream( parent: DStream[Array[Byte]], - @transient preduceFunc: PythonTransformFunction, - @transient pinvReduceFunc: PythonTransformFunction, + preduceFunc: PythonTransformFunction, + @transient private val pinvReduceFunc: PythonTransformFunction, _windowDuration: Duration, _slideDuration: Duration) extends PythonDStream(parent, preduceFunc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 192aa6a139bc..1da0b0a54df0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -720,12 +720,14 @@ abstract class DStream[T: ClassTag] ( def foreachFunc: (RDD[T], Time) => Unit = { (rdd: RDD[T], time: Time) => { val firstNum = rdd.take(num + 1) + // scalastyle:off println println("-------------------------------------------") println("Time: " + time) println("-------------------------------------------") firstNum.take(num).foreach(println) if (firstNum.length > num) println("...") println() + // scalastyle:on println } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 86a8e2beff57..40208a64861f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** @@ -69,7 +70,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true, @@ -85,8 +86,10 @@ class 
FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * Files with mod times older than this "window" of remembering will be ignored. So if new * files are visible within this window, then the file will get selected in the next batch. */ - private val minRememberDurationS = - Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.minRememberDuration", "60s")) + private val minRememberDurationS = { + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.fileStream.minRememberDuration", + ssc.conf.get("spark.streaming.minRememberDuration", "60s"))) + } // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock @@ -144,7 +147,14 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) batchTimeToSelectedFiles += ((validTime, newFiles)) recentlySelectedFiles ++= newFiles - Some(filesToRDD(newFiles)) + val rdds = Some(filesToRDD(newFiles)) + // Copy newFiles to immutable.List to prevent from being modified by the user + val metadata = Map( + "files" -> newFiles.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> newFiles.mkString("\n")) + val inputInfo = StreamInputInfo(id, 0, metadata) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + rdds } /** Clear the old time-to-files mappings along with old RDDs */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index d58c99a8ff32..95994c983c0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.scheduler.RateController +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils /** @@ -37,7 +39,7 @@ import org.apache.spark.util.Utils * * @param ssc_ Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) +abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) extends DStream[T](ssc_) { private[streaming] var lastValidTime: Time = null @@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + // Keep track of the freshest rate for this stream using the rateEstimator + protected[streaming] val rateController: Option[RateController] = None + /** A human-readable name of this InputDStream */ private[streaming] def name: String = { // e.g. 
FlumePollingDStream -> "Flume polling stream" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 186e1bf03a94..002aac9f4361 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -23,7 +23,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class PluggableInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, receiver: Receiver[T]) extends ReceiverInputDStream[T](ssc_) { def getReceiver(): Receiver[T] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index ed7da6dc1315..a2685046e03d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,16 +17,17 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.streaming.{Time, StreamingContext} +import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream} + +import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.streaming.{Time, StreamingContext} + private[streaming] class QueueInputDStream[T: ClassTag]( - @transient ssc: StreamingContext, + ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] @@ -36,6 +37,15 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def readObject(in: ObjectInputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing. 
" + + "Please don't use queueStream when checkpointing is enabled.") + } + + private def writeObject(oos: ObjectOutputStream): Unit = { + logWarning("queueStream doesn't support checkpointing") + } + override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() if (oneAtATime && queue.size > 0) { @@ -47,7 +57,7 @@ class QueueInputDStream[T: ClassTag]( if (oneAtATime) { Some(buffer.head) } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) + Some(new UnionRDD(context.sc, buffer.toSeq)) } } else if (defaultRDD != null) { Some(defaultRDD) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index e2925b9e03ec..5a9eda7c1277 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class RawInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index e76e7eb0dea1..87c20afd5c13 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,11 +21,12 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId -import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.streaming.scheduler.{ReceivedBlockInfo, RateController, StreamInputInfo} import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.streaming.{StreamingContext, Time} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -37,9 +38,20 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils * @param ssc_ Streaming context that will execute this input stream * @tparam T Class type of the object of this stream */ -abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) +abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) extends InputDStream[T](ssc_) { + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration))) + } else { + None + } + } + /** * Gets the receiver object that will be sent to the worker nodes * to receive data. 
This method needs to defined by any specific implementation @@ -67,47 +79,72 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont // for this batch val receiverTracker = ssc.scheduler.receiverTracker val blockInfos = receiverTracker.getBlocksOfBatch(validTime).getOrElse(id, Seq.empty) - val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker - val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum) + val inputInfo = StreamInputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) - if (blockInfos.nonEmpty) { - // Are WAL record handles present with all the blocks - val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + // Create the BlockRDD + createBlockRDD(validTime, blockInfos) + } + } + Some(blockRDD) + } - if (areWALRecordHandlesPresent) { - // If all the blocks have WAL record handle, then create a WALBackedBlockRDD - val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray - val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray - new WriteAheadLogBackedBlockRDD[T]( - ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) - } else { - // Else, create a BlockRDD. However, if there are some blocks with WAL info but not - // others then that is unexpected and log a warning accordingly. - if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { - if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { - logError("Some blocks do not have Write Ahead Log information; " + - "this is unexpected and data may not be recoverable after driver failures") - } else { - logWarning("Some blocks have Write Ahead Log information; this is unexpected") - } - } - new BlockRDD[T](ssc.sc, blockIds) - } - } else { - // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD - // according to the configuration + private[streaming] def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = { + + if (blockInfos.nonEmpty) { + val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray + + // Are WAL record handles present with all the blocks + val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + + if (areWALRecordHandlesPresent) { + // If all the blocks have WAL record handle, then create a WALBackedBlockRDD + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) + } else { + // Else, create a BlockRDD. However, if there are some blocks with WAL info but not + // others then that is unexpected and log a warning accordingly. 
+ if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { - new WriteAheadLogBackedBlockRDD[T]( - ssc.sparkContext, Array.empty, Array.empty, Array.empty) + logError("Some blocks do not have Write Ahead Log information; " + + "this is unexpected and data may not be recoverable after driver failures") } else { - new BlockRDD[T](ssc.sc, Array.empty) + logWarning("Some blocks have Write Ahead Log information; this is unexpected") } } + val validBlockIds = blockIds.filter { id => + ssc.sparkContext.env.blockManager.master.contains(id) + } + if (validBlockIds.size != blockIds.size) { + logWarning("Some blocks could not be recovered as they were not found in memory. " + + "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "for more details.") + } + new BlockRDD[T](ssc.sc, validBlockIds) + } + } else { + // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD + // according to the configuration + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, Array.empty, Array.empty, Array.empty) + } else { + new BlockRDD[T](ssc.sc, Array.empty) } } - Some(blockRDD) + } + + /** + * A RateController that sends the new rate to receivers, via the receiver tracker. + */ + private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = + ssc.scheduler.receiverTracker.sendRateUpdate(id, rate) } } + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 5ce5b7aae6e6..de84e0c9a498 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class SocketInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 31ce8e1ec14d..f811784b25c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -61,7 +61,7 @@ class WriteAheadLogBackedBlockRDDPartition( * * * @param sc SparkContext - * @param blockIds Ids of the blocks that contains this RDD's data + * @param _blockIds Ids of the blocks that contains this RDD's data * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark * executors). If not, then block lookups by the block ids will be skipped. 
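[Editor's note] Several hunks in this patch churn `@transient` on constructor parameters: it is dropped from `ssc_`-style parameters, and added as `@transient private val` to driver-side arrays such as `_blockIds` above. As a generic JVM-serialization refresher (not Spark-specific reasoning for each individual change): a constructor parameter that is referenced outside the constructor becomes a field and ships with the serialized object unless it is marked transient.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Not Spark code: the parameter becomes a field (it is used in a method), but the field is
// skipped during Java serialization and comes back as null after deserialization.
class WithTransient(@transient private val payload: Array[Byte]) extends Serializable {
  def payloadSize: Int = if (payload == null) 0 else payload.length
}

object TransientDemo {
  private def roundTrip(value: AnyRef): AnyRef = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(value)
    oos.close()
    new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray)).readObject()
  }

  def main(args: Array[String]): Unit = {
    val original = new WithTransient(new Array[Byte](1024))
    println(original.payloadSize)                                          // 1024
    println(roundTrip(original).asInstanceOf[WithTransient].payloadSize)   // 0
  }
}
```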
@@ -73,23 +73,23 @@ class WriteAheadLogBackedBlockRDDPartition( */ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient blockIds: Array[BlockId], - @transient walRecordHandles: Array[WriteAheadLogRecordHandle], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + sc: SparkContext, + @transient private val _blockIds: Array[BlockId], + @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, storeInBlockManager: Boolean = false, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) - extends BlockRDD[T](sc, blockIds) { + extends BlockRDD[T](sc, _blockIds) { require( - blockIds.length == walRecordHandles.length, - s"Number of block Ids (${blockIds.length}) must be " + - s" same as number of WAL record handles (${walRecordHandles.length}})") + _blockIds.length == walRecordHandles.length, + s"Number of block Ids (${_blockIds.length}) must be " + + s" same as number of WAL record handles (${walRecordHandles.length})") require( - isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, + isBlockIdValid.isEmpty || isBlockIdValid.length == _blockIds.length, s"Number of elements in isBlockIdValid (${isBlockIdValid.length}) must be " + - s" same as number of block Ids (${blockIds.length})") + s" same as number of block Ids (${_blockIds.length})") // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration @@ -99,9 +99,9 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { assertValid() - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), isValid, walRecordHandles(i)) + new WriteAheadLogBackedBlockRDDPartition(i, _blockIds(i), isValid, walRecordHandles(i)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index cd309788a771..7ec74016a1c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -144,7 +144,7 @@ private[streaming] class ActorReceiver[T: ClassTag]( receiverSupervisorStrategy: SupervisorStrategy ) extends Receiver[T](storageLevel) with Logging { - protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), + protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), "Supervisor" + streamId) class Supervisor extends Actor { @@ -191,11 +191,11 @@ private[streaming] class ActorReceiver[T: ClassTag]( } def onStart(): Unit = { - supervisor - logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + actorSupervisor + logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path) } def onStop(): Unit = { - supervisor ! PoisonPill + actorSupervisor ! 
PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 92b51ce39234..421d60ae359f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,10 +21,10 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{Clock, SystemClock} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -69,16 +69,35 @@ private[streaming] trait BlockGeneratorListener { * named blocks at regular intervals. This class starts two threads, * one to periodically start a new batch and prepare the previous batch of as a block, * the other to push the blocks into the block manager. + * + * Note: Do not create BlockGenerator instances directly inside receivers. Use + * `ReceiverSupervisor.createBlockGenerator` to create a BlockGenerator and use it. */ private[streaming] class BlockGenerator( listener: BlockGeneratorListener, receiverId: Int, - conf: SparkConf + conf: SparkConf, + clock: Clock = new SystemClock() ) extends RateLimiter(conf) with Logging { private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) - private val clock = new SystemClock() + /** + * The BlockGenerator can be in 5 possible states, in the order as follows. + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + */ + private object GeneratorState extends Enumeration { + type GeneratorState = Value + val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value + } + import GeneratorState._ + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value") @@ -89,70 +108,140 @@ private[streaming] class BlockGenerator( private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @volatile private var currentBuffer = new ArrayBuffer[Any] - @volatile private var stopped = false + @volatile private var state = Initialized /** Start block generating and pushing threads. */ - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Started BlockGenerator") + def start(): Unit = synchronized { + if (state == Initialized) { + state = Active + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Started BlockGenerator") + } else { + throw new SparkException( + s"Cannot start BlockGenerator as its not in the Initialized state [state = $state]") + } } - /** Stop all threads. */ - def stop() { + /** + * Stop everything in the right order such that all the data added is pushed out correctly. 
+ * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. + */ + def stop(): Unit = { + // Set the state to stop adding data + synchronized { + if (state == Active) { + state = StoppedAddingData + } else { + logWarning(s"Cannot stop BlockGenerator as its not in the Active state [state = $state]") + return + } + } + + // Stop generating blocks and set the state for block pushing thread to start draining the queue logInfo("Stopping BlockGenerator") blockIntervalTimer.stop(interruptTimer = false) - stopped = true - logInfo("Waiting for block pushing thread") + synchronized { state = StoppedGeneratingBlocks } + + // Wait for the queue to drain and mark generated as stopped + logInfo("Waiting for block pushing thread to terminate") blockPushingThread.join() + synchronized { state = StoppedAll } logInfo("Stopped BlockGenerator") } /** - * Push a single data item into the buffer. All received data items - * will be periodically pushed into BlockManager. + * Push a single data item into the buffer. */ - def addData (data: Any): Unit = synchronized { - waitToPush() - currentBuffer += data + def addData(data: Any): Unit = { + if (state == Active) { + waitToPush() + synchronized { + if (state == Active) { + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push a single data item into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. + * `BlockGeneratorListener.onAddData` callback will be called. */ - def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { - waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + def addDataWithCallback(data: Any, metadata: Any): Unit = { + if (state == Active) { + waitToPush() + synchronized { + if (state == Active) { + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push multiple data items into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. Note that all the data items is guaranteed - * to be present in a single block. + * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items + * are atomically added to the buffer, and are hence guaranteed to be present in a single block. 
*/ - def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { - dataIterator.foreach { data => - waitToPush() - currentBuffer += data + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = { + if (state == Active) { + // Unroll iterator into a temp buffer, and wait for pushing in the process + val tempBuffer = new ArrayBuffer[Any] + dataIterator.foreach { data => + waitToPush() + tempBuffer += data + } + synchronized { + if (state == Active) { + currentBuffer ++= tempBuffer + listener.onAddData(tempBuffer, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") } - listener.onAddData(dataIterator, metadata) } + def isActive(): Boolean = state == Active + + def isStopped(): Boolean = state == StoppedAll + /** Change the buffer to which single records are added to. */ - private def updateCurrentBuffer(time: Long): Unit = synchronized { + private def updateCurrentBuffer(time: Long): Unit = { try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[Any] - if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(receiverId, time - blockIntervalMs) - val newBlock = new Block(blockId, newBlockBuffer) - listener.onGenerateBlock(blockId) + var newBlock: Block = null + synchronized { + if (currentBuffer.nonEmpty) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[Any] + val blockId = StreamBlockId(receiverId, time - blockIntervalMs) + listener.onGenerateBlock(blockId) + newBlock = new Block(blockId, newBlockBuffer) + } + } + + if (newBlock != null) { blocksForPushing.put(newBlock) // put is blocking when queue is full - logDebug("Last element in " + blockId + " is " + newBlockBuffer.last) } } catch { case ie: InterruptedException => @@ -165,18 +254,25 @@ private[streaming] class BlockGenerator( /** Keep pushing blocks to the BlockManager. */ private def keepPushingBlocks() { logInfo("Started block pushing thread") + + def areBlocksBeingGenerated: Boolean = synchronized { + state != StoppedGeneratingBlocks + } + try { - while (!stopped) { - Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + // While blocks are being generated, keep polling for to-be-pushed blocks and push them. + while (areBlocksBeingGenerated) { + Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => } } - // Push out the blocks that are still left + + // At this point, state is StoppedGeneratingBlock. So drain the queue of to-be-pushed blocks. 
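Editor's note: the `addData`, `addDataWithCallback`, and `addMultipleDataWithCallback` hunks above all follow the same pattern: check the state, perform the potentially blocking rate-limited wait outside the lock, then re-check the state inside `synchronized` before mutating the shared buffer. Below is a minimal standalone sketch of that pattern, with hypothetical names and a single boolean flag standing in for the five-state enumeration; it is an illustration, not the Spark class.

```scala
import scala.collection.mutable.ArrayBuffer

// Toy generator showing the check / block / re-check idiom used by BlockGenerator.addData*.
object CheckRecheckSketch {
  @volatile private var active = true
  private val currentBuffer = new ArrayBuffer[Any]

  // Stand-in for RateLimiter.waitToPush(): may block, so it must not be called under the lock.
  private def waitToPush(): Unit = Thread.sleep(1)

  def addData(data: Any): Unit = {
    if (!active) {
      throw new IllegalStateException("Cannot add data: generator has been stopped")
    }
    waitToPush() // potentially blocking wait happens outside the lock
    synchronized {
      if (active) {
        currentBuffer += data // state re-checked under the lock before touching shared state
      } else {
        throw new IllegalStateException("Generator was stopped while waiting to push")
      }
    }
  }

  def stop(): Unit = synchronized { active = false }
}
```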
logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") while (!blocksForPushing.isEmpty) { - logDebug("Getting block ") val block = blocksForPushing.take() + logDebug(s"Pushing block $block") pushBlock(block) logInfo("Blocks left to push " + blocksForPushing.size()) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 8df542b367d2..bca1fbc8fda2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -34,12 +34,31 @@ import org.apache.spark.{Logging, SparkConf} */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate) + // treated as an upper limit + private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue) + private lazy val rateLimiter = GuavaRateLimiter.create(maxRateLimit.toDouble) def waitToPush() { - if (desiredRate > 0) { - rateLimiter.acquire() - } + rateLimiter.acquire() } + + /** + * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}. + */ + def getCurrentLimit: Long = rateLimiter.getRate.toLong + + /** + * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by + * {{{spark.streaming.receiver.maxRate}}}, even if `newRate` is higher than that. + * + * @param newRate A new rate in events per second. It has no effect if it's 0 or negative. + */ + private[receiver] def updateRate(newRate: Long): Unit = + if (newRate > 0) { + if (maxRateLimit > 0) { + rateLimiter.setRate(newRate.min(maxRateLimit)) + } else { + rateLimiter.setRate(newRate) + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index c8dd6e06812d..5f6c5b024085 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -222,7 +222,7 @@ private[streaming] object WriteAheadLogBasedBlockHandler { /** * A utility that will wrap the Iterator to get the count */ -private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { +private[streaming] class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { private var _count = 0 private def isFullyConsumed: Boolean = !iterator.hasNext diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5b5a3fe64860..2252e28f22af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi @@ -116,12 +116,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * being pushed into 
Spark's memory. */ def store(dataItem: T) { - executor.pushSingle(dataItem) + supervisor.pushSingle(dataItem) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def store(dataBuffer: ArrayBuffer[T]) { - executor.pushArrayBuffer(dataBuffer, None, None) + supervisor.pushArrayBuffer(dataBuffer, None, None) } /** @@ -130,12 +130,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataBuffer: ArrayBuffer[T], metadata: Any) { - executor.pushArrayBuffer(dataBuffer, Some(metadata), None) + supervisor.pushArrayBuffer(dataBuffer, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator, None, None) } /** @@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: java.util.Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator.asScala, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: java.util.Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator.asScala, None, None) } /** @@ -158,7 +158,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator, Some(metadata), None) } /** @@ -167,7 +167,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * that Spark is configured to use. */ def store(bytes: ByteBuffer) { - executor.pushBytes(bytes, None, None) + supervisor.pushBytes(bytes, None, None) } /** @@ -176,12 +176,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(bytes: ByteBuffer, metadata: Any) { - executor.pushBytes(bytes, Some(metadata), None) + supervisor.pushBytes(bytes, Some(metadata), None) } /** Report exceptions in receiving data. */ def reportError(message: String, throwable: Throwable) { - executor.reportError(message, throwable) + supervisor.reportError(message, throwable) } /** @@ -193,7 +193,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` will be reported to the driver. */ def restart(message: String) { - executor.restartReceiver(message) + supervisor.restartReceiver(message) } /** @@ -205,7 +205,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` and `exception` will be reported to the driver. */ def restart(message: String, error: Throwable) { - executor.restartReceiver(message, Some(error)) + supervisor.restartReceiver(message, Some(error)) } /** @@ -215,22 +215,22 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * in a background thread. */ def restart(message: String, error: Throwable, millisecond: Int) { - executor.restartReceiver(message, Some(error), millisecond) + supervisor.restartReceiver(message, Some(error), millisecond) } /** Stop the receiver completely. 
*/ def stop(message: String) { - executor.stop(message, None) + supervisor.stop(message, None) } /** Stop the receiver completely due to an exception */ def stop(message: String, error: Throwable) { - executor.stop(message, Some(error)) + supervisor.stop(message, Some(error)) } /** Check if the receiver has started or not. */ def isStarted(): Boolean = { - executor.isReceiverStarted() + supervisor.isReceiverStarted() } /** @@ -238,7 +238,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * the receiving of data should be stopped. */ def isStopped(): Boolean = { - executor.isReceiverStopped() + supervisor.isReceiverStopped() } /** @@ -257,7 +257,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable private var id: Int = -1 /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ - private[streaming] var executor_ : ReceiverSupervisor = null + @transient private var _supervisor : ReceiverSupervisor = null /** Set the ID of the DStream that this receiver is associated with. */ private[streaming] def setReceiverId(id_ : Int) { @@ -265,15 +265,17 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable } /** Attach Network Receiver executor to this receiver. */ - private[streaming] def attachExecutor(exec: ReceiverSupervisor) { - assert(executor_ == null) - executor_ = exec + private[streaming] def attachSupervisor(exec: ReceiverSupervisor) { + assert(_supervisor == null) + _supervisor = exec } - /** Get the attached executor. */ - private def executor = { - assert(executor_ != null, "Executor has not been attached to this receiver") - executor_ + /** Get the attached supervisor. */ + private[streaming] def supervisor: ReceiverSupervisor = { + assert(_supervisor != null, + "A ReceiverSupervisor have not been attached to the receiver yet. 
Maybe you are starting " + + "some computation in the receiver before the Receiver.onStart() has been called.") + _supervisor } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index 7bf3c3331949..1eb55affaa9d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -23,4 +23,5 @@ import org.apache.spark.streaming.Time private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage - +private[streaming] case class UpdateRateLimit(elementsPerSecond: Long) + extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 33be067ebdaf..158d1ba2f183 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -22,10 +22,11 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer import scala.concurrent._ +import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkEnv, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{Utils, ThreadUtils} /** * Abstract class that is responsible for supervising a Receiver in the worker. @@ -36,15 +37,15 @@ private[streaming] abstract class ReceiverSupervisor( conf: SparkConf ) extends Logging { - /** Enumeration to identify current state of the StreamingContext */ + /** Enumeration to identify current state of the Receiver */ object ReceiverState extends Enumeration { type CheckpointState = Value val Initialized, Started, Stopped = Value } import ReceiverState._ - // Attach the executor to the receiver - receiver.attachExecutor(this) + // Attach the supervisor to the receiver + receiver.attachSupervisor(this) private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) @@ -58,6 +59,9 @@ private[streaming] abstract class ReceiverSupervisor( /** Time between a receiver is stopped and started again */ private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) + /** The current maximum rate limit for this receiver. */ + private[streaming] def getCurrentRateLimit: Long = Long.MaxValue + /** Exception associated with the stopping of the receiver */ @volatile protected var stoppingError: Throwable = null @@ -88,17 +92,34 @@ private[streaming] abstract class ReceiverSupervisor( optionalBlockId: Option[StreamBlockId] ) + /** + * Create a custom [[BlockGenerator]] that the receiver implementation can directly control + * using their provided [[BlockGeneratorListener]]. + * + * Note: Do not explicitly start or stop the `BlockGenerator`, the `ReceiverSupervisorImpl` + * will take care of it. + */ + def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator + /** Report errors. 
*/ def reportError(message: String, throwable: Throwable) - /** Called when supervisor is started */ + /** + * Called when supervisor is started. + * Note that this must be called before the receiver.onStart() is called to ensure + * things like [[BlockGenerator]]s are started before the receiver starts sending data. + */ protected def onStart() { } - /** Called when supervisor is stopped */ + /** + * Called when supervisor is stopped. + * Note that this must be called after the receiver.onStop() is called to ensure + * things like [[BlockGenerator]]s are cleaned up after the receiver stops sending data. + */ protected def onStop(message: String, error: Option[Throwable]) { } - /** Called when receiver is started */ - protected def onReceiverStart() { } + /** Called when receiver is started. Return true if the driver accepts us */ + protected def onReceiverStart(): Boolean /** Called when receiver is stopped */ protected def onReceiverStop(message: String, error: Option[Throwable]) { } @@ -121,13 +142,17 @@ private[streaming] abstract class ReceiverSupervisor( /** Start receiver */ def startReceiver(): Unit = synchronized { try { - logInfo("Starting receiver") - receiver.onStart() - logInfo("Called receiver onStart") - onReceiverStart() - receiverState = Started + if (onReceiverStart()) { + logInfo("Starting receiver") + receiverState = Started + receiver.onStart() + logInfo("Called receiver onStart") + } else { + // The driver refused us + stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) + } } catch { - case t: Throwable => + case NonFatal(t) => stop("Error starting receiver " + streamId, Some(t)) } } @@ -136,12 +161,19 @@ private[streaming] abstract class ReceiverSupervisor( def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized { try { logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse("")) - receiverState = Stopped - receiver.onStop() - logInfo("Called receiver onStop") - onReceiverStop(message, error) + receiverState match { + case Initialized => + logWarning("Skip stopping receiver because it has not yet stared") + case Started => + receiverState = Stopped + receiver.onStop() + logInfo("Called receiver onStop") + onReceiverStop(message, error) + case Stopped => + logWarning("Receiver has been stopped") + } } catch { - case t: Throwable => + case NonFatal(t) => logError("Error stopping receiver " + streamId + t.getStackTraceString) } } @@ -167,7 +199,7 @@ private[streaming] abstract class ReceiverSupervisor( }(futureExecutionContext) } - /** Check if receiver has been marked for stopping */ + /** Check if receiver has been marked for starting */ def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started @@ -182,12 +214,12 @@ private[streaming] abstract class ReceiverSupervisor( /** Wait the thread until the supervisor is stopped */ def awaitTermination() { + logInfo("Waiting for receiver to be stopped") stopLatch.await() - logInfo("Waiting for executor stop is over") if (stoppingError != null) { - logError("Stopped executor with error: " + stoppingError) + logError("Stopped receiver with error: " + stoppingError) } else { - logWarning("Stopped executor without error") + logInfo("Stopped receiver without error") } if (stoppingError != null) { throw stoppingError diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 6078cdf8f879..59ef58d232ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -30,7 +31,7 @@ import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.RpcUtils import org.apache.spark.{Logging, SparkEnv, SparkException} /** @@ -46,6 +47,8 @@ private[streaming] class ReceiverSupervisorImpl( checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { + private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort + private val receivedBlockHandler: ReceivedBlockHandler = { if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { if (checkpointDirOption.isEmpty) { @@ -77,14 +80,22 @@ private[streaming] class ReceiverSupervisorImpl( case CleanupOldBlocks(threshTime) => logDebug("Received delete old batch signal") cleanupOldBlocks(threshTime) + case UpdateRateLimit(eps) => + logInfo(s"Received a new rate limit: $eps.") + registeredBlockGenerators.foreach { bg => + bg.updateRate(eps) + } } }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) + private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator] + with mutable.SynchronizedBuffer[BlockGenerator] + /** Divides received data records into data blocks for pushing in BlockManager. */ - private val blockGenerator = new BlockGenerator(new BlockGeneratorListener { + private val defaultBlockGeneratorListener = new BlockGeneratorListener { def onAddData(data: Any, metadata: Any): Unit = { } def onGenerateBlock(blockId: StreamBlockId): Unit = { } @@ -96,11 +107,15 @@ private[streaming] class ReceiverSupervisorImpl( def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { pushArrayBuffer(arrayBuffer, None, Some(blockId)) } - }, streamId, env.conf) + } + private val defaultBlockGenerator = createBlockGenerator(defaultBlockGeneratorListener) + + /** Get the current rate limit of the default block generator */ + override private[streaming] def getCurrentRateLimit: Long = defaultBlockGenerator.getCurrentLimit /** Push a single record of received data into block generator. */ def pushSingle(data: Any) { - blockGenerator.addData(data) + defaultBlockGenerator.addData(data) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. 
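Editor's note: in the hunk above, an `UpdateRateLimit(eps)` message from the driver is forwarded to every registered BlockGenerator, and the earlier RateLimiter change caps whatever rate arrives at `spark.streaming.receiver.maxRate`. The following is a standalone sketch of that clamping behaviour built directly on Guava's `RateLimiter` (assuming Guava on the classpath, as Spark already has); class and method names here are illustrative, not Spark's.

```scala
import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter}

// A rate limiter whose effective rate can be raised or lowered at runtime, but never above
// the statically configured maximum.
class ClampedRateLimiter(maxRate: Long) {
  private val limiter = GuavaRateLimiter.create(maxRate.toDouble)

  def waitToPush(): Unit = limiter.acquire() // blocks until a permit is available

  def currentLimit: Long = limiter.getRate.toLong

  def updateRate(newRate: Long): Unit = {
    if (newRate > 0) {
      limiter.setRate(math.min(newRate, maxRate).toDouble) // clamp to the configured maximum
    }
  }
}

object ClampedRateLimiterExample {
  def main(args: Array[String]): Unit = {
    val limiter = new ClampedRateLimiter(maxRate = 1000L)
    limiter.updateRate(5000L)        // the driver asks for 5000 events/s ...
    println(limiter.currentLimit)    // ... but the effective limit stays at 1000
  }
}
```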
*/ @@ -154,17 +169,17 @@ private[streaming] class ReceiverSupervisorImpl( } override protected def onStart() { - blockGenerator.start() + registeredBlockGenerators.foreach { _.start() } } override protected def onStop(message: String, error: Option[Throwable]) { - blockGenerator.stop() + registeredBlockGenerators.foreach { _.stop() } env.rpcEnv.stop(endpoint) } - override protected def onReceiverStart() { + override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + streamId, receiver.getClass.getSimpleName, hostPort, endpoint) trackerEndpoint.askWithRetry[Boolean](msg) } @@ -175,6 +190,16 @@ private[streaming] class ReceiverSupervisorImpl( logInfo("Stopped receiver " + streamId) } + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + // Cleanup BlockGenerators that have already been stopped + registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() } + + val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf) + registeredBlockGenerators += newBlockGenerator + newBlockGenerator + } + /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 5b9bfbf9b01e..9922b6bc1201 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Time * :: DeveloperApi :: * Class having information on completed batches. * @param batchTime Time of the batch - * @param streamIdToNumRecords A map of input stream id to record number + * @param streamIdToInputInfo A map of input stream id to its input info * @param submissionTime Clock time of when jobs of this batch was submitted to * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing @@ -33,12 +33,15 @@ import org.apache.spark.streaming.Time @DeveloperApi case class BatchInfo( batchTime: Time, - streamIdToNumRecords: Map[Int, Long], + streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] ) { + @deprecated("Use streamIdToInputInfo instead", "1.5.0") + def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) + /** * Time taken for the first job of this batch to start processing from the time this batch * was submitted to the streaming scheduler. Essentially, it is @@ -63,5 +66,5 @@ case class BatchInfo( /** * The number of recorders received by the receivers in this batch. 
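Editor's note: the `createBlockGenerator` implementation above keeps every generator it hands out in a synchronized buffer, prunes generators that have already stopped, and starts/stops the whole set from `onStart`/`onStop`; this is why the doc comment warns receivers not to start or stop generators themselves. A simplified sketch of that ownership pattern follows, using toy classes rather than Spark's.

```scala
import scala.collection.mutable

// The supervisor-owns-the-lifecycle pattern: receivers only ask for generators, the
// supervisor starts, stops, and prunes them.
object GeneratorRegistrySketch {
  final class ToyGenerator {
    @volatile private var stopped = false
    def start(): Unit = println("generator started")
    def stop(): Unit = { stopped = true; println("generator stopped") }
    def isStopped: Boolean = stopped
  }

  private val registered = mutable.ArrayBuffer.empty[ToyGenerator]

  def createGenerator(): ToyGenerator = registered.synchronized {
    registered --= registered.filter(_.isStopped) // prune generators that already finished
    val g = new ToyGenerator
    registered += g
    g
  }

  def onStart(): Unit = registered.synchronized { registered.foreach(_.start()) }
  def onStop(): Unit = registered.synchronized { registered.foreach(_.stop()) }
}
```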
*/ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 7c0db8a863c6..deb15d075975 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -20,11 +20,34 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.streaming.{Time, StreamingContext} -/** To track the information of input stream at specified batch time. */ -private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { +/** + * :: DeveloperApi :: + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + */ +@DeveloperApi +case class StreamInputInfo( + inputStreamId: Int, numRecords: Long, metadata: Map[String, Any] = Map.empty) { require(numRecords >= 0, "numRecords must not be negative") + + def metadataDescription: Option[String] = + metadata.get(StreamInputInfo.METADATA_KEY_DESCRIPTION).map(_.toString) +} + +@DeveloperApi +object StreamInputInfo { + + /** + * The key for description in `StreamInputInfo.metadata`. + */ + val METADATA_KEY_DESCRIPTION: String = "Description" } /** @@ -34,25 +57,26 @@ private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging { // Map to track all the InputInfo related to specific batch time and input stream. 
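Editor's note: the new `StreamInputInfo` above replaces the bare record count with a record count plus free-form metadata, and the UI picks up a human-readable description from the well-known "Description" key. A small self-contained illustration of that contract is below; the class name is hypothetical so the snippet compiles outside Spark, and the description string is made-up example data.

```scala
// Standalone copy of the metadata contract, not the Spark class itself.
case class InputInfoSketch(
    inputStreamId: Int,
    numRecords: Long,
    metadata: Map[String, Any] = Map.empty) {
  require(numRecords >= 0, "numRecords must not be negative")

  // Only the "Description" key is rendered by the UI.
  def metadataDescription: Option[String] = metadata.get("Description").map(_.toString)
}

object InputInfoSketchExample {
  def main(args: Array[String]): Unit = {
    val info = InputInfoSketch(
      inputStreamId = 0,
      numRecords = 128,
      metadata = Map("Description" -> "records 100 to 228 of the input source"))
    println(info.metadataDescription) // Some(records 100 to 228 of the input source)
  }
}
```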
- private val batchTimeToInputInfos = new mutable.HashMap[Time, mutable.HashMap[Int, InputInfo]] + private val batchTimeToInputInfos = + new mutable.HashMap[Time, mutable.HashMap[Int, StreamInputInfo]] /** Report the input information with batch time to the tracker */ - def reportInfo(batchTime: Time, inputInfo: InputInfo): Unit = synchronized { + def reportInfo(batchTime: Time, inputInfo: StreamInputInfo): Unit = synchronized { val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime, - new mutable.HashMap[Int, InputInfo]()) + new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + s"$batchTime is already added into InputInfoTracker, this is a illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) } /** Get the all the input stream's information of specified batch time */ - def getInfo(batchTime: Time): Map[Int, InputInfo] = synchronized { + def getInfo(batchTime: Time): Map[Int, StreamInputInfo] = synchronized { val inputInfos = batchTimeToInputInfos.get(batchTime) // Convert mutable HashMap to immutable Map for the caller - inputInfos.map(_.toMap).getOrElse(Map[Int, InputInfo]()) + inputInfos.map(_.toMap).getOrElse(Map[Int, StreamInputInfo]()) } /** Cleanup the tracked input information older than threshold batch time */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 9f93d6cbc3c2..2de035d166e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -22,7 +22,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Clock, EventLoop, ManualClock} +import org.apache.spark.util.{Utils, Clock, EventLoop, ManualClock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -47,11 +47,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.util.SystemClock") try { - Class.forName(clockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(clockClass).newInstance().asInstanceOf[Clock] } catch { case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") - Class.forName(newClockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(newClockClass).newInstance().asInstanceOf[Clock] } } @@ -79,6 +79,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { def start(): Unit = synchronized { if (eventLoop != null) return // generator has already been started + // Call checkpointWriter here to initialize it before eventLoop uses it to avoid a deadlock. 
+ // See SPARK-10125 + checkpointWriter + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) @@ -244,8 +248,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } match { case Success(jobs) => val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time) - val streamIdToNumRecords = streamIdToInputInfos.mapValues(_.numRecords) - jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToNumRecords)) + jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 4af9b6d3b56a..0cd39594ee92 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,15 +17,15 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.{Failure, Success} import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ -import org.apache.spark.util.EventLoop +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -40,9 +40,12 @@ private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends J private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { - private val jobSets = new ConcurrentHashMap[Time, JobSet] + // Use of ConcurrentHashMap.keySet later causes an odd runtime problem due to Java 7/8 diff + // https://gist.github.com/AlainODea/1375759b8720a3f9f094 + private val jobSets: java.util.Map[Time, JobSet] = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = + ThreadUtils.newDaemonFixedThreadPool(numConcurrentJobs, "streaming-job-executor") private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() @@ -66,6 +69,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } eventLoop.start() + // attach rate controllers of input streams to receive batch completion updates + for { + inputDStream <- ssc.graph.getInputStreams + rateController <- inputDStream.rateController + } ssc.addStreamingListener(rateController) + listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) @@ -119,7 +128,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def getPendingTimes(): Seq[Time] = { - jobSets.keySet.toSeq + jobSets.asScala.keys.toSeq } def reportError(msg: String, e: Throwable) { @@ -185,14 +194,25 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) try { - eventLoop.post(JobStarted(job)) - // Disable 
checks for existing output directories in jobs launched by the streaming - // scheduler, since we may need to write output to an existing directory during checkpoint - // recovery; see SPARK-4835 for more details. - PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - job.run() + // We need to assign `eventLoop` to a temp variable. Otherwise, because + // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then + // it's possible that when `post` is called, `eventLoop` happens to null. + var _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobStarted(job)) + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + job.run() + } + _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobCompleted(job)) + } + } else { + // JobScheduler has been stopped. } - eventLoop.post(JobCompleted(job)) } finally { ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index e6be63b2ddbd..95833efc9417 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,7 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - streamIdToNumRecords: Map[Int, Long] = Map.empty) { + streamIdToInputInfo: Map[Int, StreamInputInfo] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -64,7 +64,7 @@ case class JobSet( def toBatchInfo: BatchInfo = { new BatchInfo( time, - streamIdToNumRecords, + streamIdToInputInfo, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala new file mode 100644 index 000000000000..a46c0c1b25e7 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import java.io.ObjectInputStream +import java.util.concurrent.atomic.AtomicLong + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A StreamingListener that receives batch completion updates, and maintains + * an estimate of the speed at which this stream should ingest messages, + * given an estimate computation from a `RateEstimator` + */ +private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { + + init() + + protected def publish(rate: Long): Unit + + @transient + implicit private var executionContext: ExecutionContext = _ + + @transient + private var rateLimit: AtomicLong = _ + + /** + * An initialization method called both from the constructor and Serialization code. + */ + private def init() { + executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")) + rateLimit = new AtomicLong(-1L) + } + + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + init() + } + + /** + * Compute the new rate limit and publish it asynchronously. + */ + private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = + Future[Unit] { + val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay) + newRate.foreach { s => + rateLimit.set(s.toLong) + publish(getLatestRate()) + } + } + + def getLatestRate(): Long = rateLimit.get() + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + val elements = batchCompleted.batchInfo.streamIdToInputInfo + + for { + processingEnd <- batchCompleted.batchInfo.processingEndTime + workDelay <- batchCompleted.batchInfo.processingDelay + waitDelay <- batchCompleted.batchInfo.schedulingDelay + elems <- elements.get(streamUID).map(_.numRecords) + } computeAndPublish(processingEnd, elems, workDelay, waitDelay) + } +} + +object RateController { + def isBackPressureEnabled(conf: SparkConf): Boolean = + conf.getBoolean("spark.streaming.backpressure.enabled", false) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 7720259a5d79..f2711d1355e6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler import java.nio.ByteBuffer +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions @@ -28,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. 
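Editor's note: the new `RateController` above computes a fresh rate estimate off the listener thread, caches it in an `AtomicLong`, and only then publishes it. Here is a simplified, self-contained sketch of that pattern; the `estimate` and `publish` functions are stand-ins for Spark's `RateEstimator` and the subclass publish hook, and the class name is hypothetical.

```scala
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import scala.concurrent.{ExecutionContext, Future}

// Asynchronous compute-and-publish: batch updates never block on the rate estimation.
class RateControllerSketch(
    estimate: (Long, Long, Long, Long) => Option[Double],
    publish: Long => Unit) {

  private val pool = Executors.newSingleThreadExecutor()
  private implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
  private val latestRate = new AtomicLong(-1L)

  def getLatestRate: Long = latestRate.get()

  // Called on batch completion; the estimate is computed and published on the dedicated thread.
  def onBatchCompleted(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
    Future {
      estimate(time, elems, workDelay, waitDelay).foreach { newRate =>
        latestRate.set(newRate.toLong) // cache first, then publish the cached value
        publish(getLatestRate)
      }
    }

  def shutdown(): Unit = pool.shutdown()
}

object RateControllerSketchExample {
  def main(args: Array[String]): Unit = {
    val controller = new RateControllerSketch(
      estimate = (time, elems, workDelay, waitDelay) =>
        Some(elems.toDouble * 1000 / math.max(workDelay, 1)),
      publish = rate => println(s"publishing new rate: $rate records/s"))
    controller.onBatchCompleted(time = 0L, elems = 10000L, workDelay = 2000L, waitDelay = 0L)
    Thread.sleep(200) // give the background thread a moment in this toy example
    controller.shutdown()
  }
}
```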
*/ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -196,10 +197,10 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") - import scala.collection.JavaConversions._ - writeAheadLog.readAll().foreach { byteBuffer => + writeAheadLog.readAll().asScala.foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + Utils.deserialize[ReceivedBlockTrackerLogEvent]( + byteBuffer.array, Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index de85f24dd988..59df892397fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -28,7 +28,6 @@ import org.apache.spark.rpc.RpcEndpointRef case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val endpoint: RpcEndpointRef, active: Boolean, location: String, lastErrorMessage: String = "", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala new file mode 100644 index 000000000000..10b5a7f57a80 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.Map +import scala.collection.mutable + +import org.apache.spark.streaming.receiver.Receiver + +/** + * A class that tries to schedule receivers with evenly distributed. There are two phases for + * scheduling receivers. + * + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers with evenly distributed. ReceiverTracker should update its + * receiverTrackingInfoMap according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledExecutors` for each receiver will set to an executor list that + * contains the scheduled locations. Then when a receiver is starting, it will send a register + * request and `ReceiverTracker.registerReceiver` will be called. 
In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled executors is set, it should check + * if the location of this receiver is one of the scheduled executors, if not, the register will + * be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * executors mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are + * still alive in the list of scheduled executors, then use them to launch the receiver job. + * - If a receiver is restarting without a scheduled executors list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should + * not set `ReceiverTrackingInfo.scheduledExecutors` for this executor, instead, it should clear + * it. Then when this receiver is registering, we can know this is a local scheduling, and + * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching + * location is matching. + * + * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, + * otherwise do local scheduling. + */ +private[streaming] class ReceiverSchedulingPolicy { + + /** + * Try our best to schedule receivers with evenly distributed. However, if the + * `preferredLocation`s of receivers are not even, we may not be able to schedule them evenly + * because we have to respect them. + * + * Here is the approach to schedule executors: + *
+ *   1. First, schedule all the receivers with preferred locations (hosts), evenly among the
+ *      executors running on those hosts.
+ *   2. Then, schedule all other receivers evenly among all the executors such that overall
+ *      distribution over all the receivers is even.
    + * + * This method is called when we start to launch receivers at the first time. + */ + def scheduleReceivers( + receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = { + if (receivers.isEmpty) { + return Map.empty + } + + if (executors.isEmpty) { + return receivers.map(_.streamId -> Seq.empty).toMap + } + + val hostToExecutors = executors.groupBy(_.split(":")(0)) + val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // Set the initial value to 0 + executors.foreach(e => numReceiversOnExecutor(e) = 0) + + // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation", + // we need to make sure the "preferredLocation" is in the candidate scheduled executor list. + for (i <- 0 until receivers.length) { + // Note: preferredLocation is host but executors are host:port + receivers(i).preferredLocation.foreach { host => + hostToExecutors.get(host) match { + case Some(executorsOnHost) => + // preferredLocation is a known host. Select an executor that has the least receivers in + // this host + val leastScheduledExecutor = + executorsOnHost.minBy(executor => numReceiversOnExecutor(executor)) + scheduledExecutors(i) += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = + numReceiversOnExecutor(leastScheduledExecutor) + 1 + case None => + // preferredLocation is an unknown host. + // Note: There are two cases: + // 1. This executor is not up. But it may be up later. + // 2. This executor is dead, or it's not a host in the cluster. + // Currently, simply add host to the scheduled executors. + scheduledExecutors(i) += host + } + } + } + + // For those receivers that don't have preferredLocation, make sure we assign at least one + // executor to them. + for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) { + // Select the executor that has the least receivers + val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2) + scheduledExecutorsForOneReceiver += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1 + } + + // Assign idle executors to receivers that have less executors + val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1) + for (executor <- idleExecutors) { + // Assign an idle executor to the receiver that has least candidate executors. + val leastScheduledExecutors = scheduledExecutors.minBy(_.size) + leastScheduledExecutors += executor + } + + receivers.map(_.streamId).zip(scheduledExecutors).toMap + } + + /** + * Return a list of candidate executors to run the receiver. If the list is empty, the caller can + * run this receiver in arbitrary executor. + * + * This method tries to balance executors' load. Here is the approach to schedule executors + * for a receiver. + *
+ *   1. If preferredLocation is set, preferredLocation should be one of the candidate executors.
+ *   2. Every executor will be assigned to a weight according to the receivers running or
+ *      scheduling on it.
+ *      - If a receiver is running on an executor, it contributes 1.0 to the executor's weight.
+ *      - If a receiver is scheduled to an executor but has not yet run, it contributes
+ *        `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.
+ *      At last, if there are any idle executors (weight = 0), returns all idle executors.
+ *      Otherwise, returns the executors that have the minimum weight.
    + * + * This method is called when a receiver is registering with ReceiverTracker or is restarting. + */ + def rescheduleReceiver( + receiverId: Int, + preferredLocation: Option[String], + receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], + executors: Seq[String]): Seq[String] = { + if (executors.isEmpty) { + return Seq.empty + } + + // Always try to schedule to the preferred locations + val scheduledExecutors = mutable.Set[String]() + scheduledExecutors ++= preferredLocation + + val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo => + receiverTrackingInfo.state match { + case ReceiverState.INACTIVE => Nil + case ReceiverState.SCHEDULED => + val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get + // The probability that a scheduled receiver will run in an executor is + // 1.0 / scheduledLocations.size + scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size)) + case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) + } + }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + + val idleExecutors = executors.toSet -- executorWeights.keys + if (idleExecutors.nonEmpty) { + scheduledExecutors ++= idleExecutors + } else { + // There is no idle executor. So select all executors that have the minimum weight. + val sortedExecutors = executorWeights.toSeq.sortBy(_._2) + if (sortedExecutors.nonEmpty) { + val minWeight = sortedExecutors(0)._2 + scheduledExecutors ++= sortedExecutors.takeWhile(_._2 == minWeight).map(_._1) + } else { + // This should not happen since "executors" is not empty + } + } + scheduledExecutors.toSeq + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index e6cdbec11e94..f86fd44b4871 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,16 +17,27 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap} +import java.util.concurrent.{TimeUnit, CountDownLatch} + +import scala.collection.mutable.HashMap +import scala.concurrent.ExecutionContext import scala.language.existentials +import scala.util.{Failure, Success} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, - StopReceiver} -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.{ThreadUtils, SerializableConfiguration} + + +/** Enumeration to identify current state of a Receiver */ +private[streaming] object ReceiverState extends Enumeration { + type ReceiverState = Value + val INACTIVE, SCHEDULED, ACTIVE = Value +} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -36,7 +47,7 @@ private[streaming] sealed trait ReceiverTrackerMessage private[streaming] case class RegisterReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class 
AddBlock(receivedBlockInfo: ReceivedBlockInfo) @@ -45,6 +56,39 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage +/** + * Messages used by the driver and ReceiverTrackerEndpoint to communicate locally. + */ +private[streaming] sealed trait ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver. + */ +private[streaming] case class RestartReceiver(receiver: Receiver[_]) + extends ReceiverTrackerLocalMessage + +/** + * This message is sent to ReceiverTrackerEndpoint when we start to launch Spark jobs for receivers + * at the first time. + */ +private[streaming] case class StartAllReceivers(receiver: Seq[Receiver[_]]) + extends ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to send stop signals to all registered + * receivers. + */ +private[streaming] case object StopAllReceivers extends ReceiverTrackerLocalMessage + +/** + * A message used by ReceiverTracker to ask all receiver's ids still stored in + * ReceiverTrackerEndpoint. + */ +private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessage + +private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long) + extends ReceiverTrackerLocalMessage + /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() @@ -57,8 +101,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private val receiverInputStreams = ssc.graph.getReceiverInputStreams() private val receiverInputStreamIds = receiverInputStreams.map { _.id } - private val receiverExecutor = new ReceiverLauncher() - private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] private val receivedBlockTracker = new ReceivedBlockTracker( ssc.sparkContext.conf, ssc.sparkContext.hadoopConfiguration, @@ -69,35 +111,87 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus + /** Enumeration to identify current state of the ReceiverTracker */ + object TrackerState extends Enumeration { + type TrackerState = Value + val Initialized, Started, Stopping, Stopped = Value + } + import TrackerState._ + + /** State of the tracker. Protected by "trackerStateLock" */ + @volatile private var trackerState = Initialized + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null + private val schedulingPolicy = new ReceiverSchedulingPolicy() + + // Track the active receiver job number. When a receiver job exits ultimately, countDown will + // be called. + private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size) + + /** + * Track all receivers' information. The key is the receiver id, the value is the receiver info. + * It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverTrackingInfos = new HashMap[Int, ReceiverTrackingInfo] + + /** + * Store all preferred locations for all receivers. We need this information to schedule + * receivers. It's only accessed in ReceiverTrackerEndpoint. 
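Editor's note: the `ReceiverSchedulingPolicy.rescheduleReceiver` method shown earlier weights executors by the receivers running or scheduled on them and prefers idle executors. The sketch below restates just that weighting step with plain data types instead of `ReceiverTrackingInfo`; it is an illustration under those simplifying assumptions, not Spark's implementation.

```scala
object ReschedulingWeightSketch {
  sealed trait ReceiverState
  case object Inactive extends ReceiverState
  case class Scheduled(candidates: Seq[String]) extends ReceiverState
  case class Running(executor: String) extends ReceiverState

  // Returns the executors a restarting receiver may be placed on: all idle executors if any
  // exist, otherwise the executors carrying the minimum weight.
  def candidateExecutors(
      receivers: Iterable[ReceiverState], executors: Seq[String]): Seq[String] = {
    if (executors.isEmpty) return Seq.empty

    val weights: Map[String, Double] = receivers.flatMap {
      case Inactive => Nil
      case Running(executor) => Seq(executor -> 1.0)
      // A scheduled receiver runs on each candidate with probability 1 / #candidates.
      case Scheduled(candidates) => candidates.map(_ -> (1.0 / candidates.size))
    }.groupBy(_._1).map { case (executor, ws) => executor -> ws.map(_._2).sum }

    val idle = executors.filterNot(weights.contains)
    if (idle.nonEmpty) {
      idle
    } else {
      val minWeight = weights.values.min
      executors.filter(e => weights(e) == minWeight)
    }
  }

  def main(args: Array[String]): Unit = {
    val executors = Seq("host1:40001", "host2:40001", "host3:40001")
    val receivers = Seq(Running("host1:40001"), Scheduled(Seq("host1:40001", "host2:40001")))
    println(candidateExecutors(receivers, executors)) // List(host3:40001) -- the only idle one
  }
}
```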
+ */ + private val receiverPreferredLocations = new HashMap[Int, Option[String]] + /** Start the endpoint and receiver execution thread. */ def start(): Unit = synchronized { - if (endpoint != null) { + if (isTrackerStarted) { throw new SparkException("ReceiverTracker already started") } if (!receiverInputStreams.isEmpty) { endpoint = ssc.env.rpcEnv.setupEndpoint( "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) - if (!skipReceiverLaunch) receiverExecutor.start() + if (!skipReceiverLaunch) launchReceivers() logInfo("ReceiverTracker started") + trackerState = Started } } /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (!receiverInputStreams.isEmpty && endpoint != null) { + if (isTrackerStarted) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop(graceful) + trackerState = Stopping + if (!skipReceiverLaunch) { + // Send the stop signal to all the receivers + endpoint.askWithRetry[Boolean](StopAllReceivers) + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully. + receiverJobExitLatch.await(10, TimeUnit.SECONDS) + + if (graceful) { + logInfo("Waiting for receiver job to terminate gracefully") + receiverJobExitLatch.await() + logInfo("Waited for receiver job to terminate gracefully") + } + + // Check if all the receivers have been deregistered or not + val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds) + if (receivers.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receivers) + } else { + logInfo("All of the receivers have deregistered successfully") + } + } // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") + trackerState = Stopped } } @@ -115,9 +209,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Get the blocks allocated to the given batch and stream. 
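Editor's note: the `stop(graceful)` hunk above encodes a specific shutdown order: send `StopAllReceivers`, wait up to 10 seconds for the receiver jobs to exit, wait indefinitely only when a graceful stop was requested, then ask the endpoint for `AllReceiverIds` to verify that everything deregistered. A toy sketch of that ordering follows, with the RPC calls abstracted into function parameters; all names here are hypothetical.

```scala
import java.util.concurrent.{CountDownLatch, TimeUnit}

// Shutdown ordering: signal, bounded wait, optional unbounded wait, then verify.
class TrackerStopSketch(numReceiverJobs: Int) {
  private val receiverJobExitLatch = new CountDownLatch(numReceiverJobs)

  // Called once per receiver job when it finishes.
  def onReceiverJobExit(): Unit = receiverJobExitLatch.countDown()

  def stop(
      graceful: Boolean,
      stopAllReceivers: () => Unit,
      remainingReceiverIds: () => Seq[Int]): Unit = {
    stopAllReceivers()                               // ask every receiver to shut down
    receiverJobExitLatch.await(10, TimeUnit.SECONDS) // bounded wait for the receiver jobs
    if (graceful) {
      receiverJobExitLatch.await()                   // graceful stop waits indefinitely
    }
    val leftover = remainingReceiverIds()
    if (leftover.nonEmpty) {
      println(s"Not all of the receivers have deregistered: $leftover")
    } else {
      println("All of the receivers have deregistered successfully")
    }
  }
}
```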
*/ def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { - synchronized { - receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) - } + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) } /** @@ -131,8 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.endpoint) } - .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } + synchronized { + if (isTrackerStarted) { + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + } + } } } @@ -140,36 +235,64 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private def registerReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress - ) { + ): Boolean = { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } - receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + + if (isTrackerStopping || isTrackerStopped) { + return false + } + + val scheduledExecutors = receiverTrackingInfos(streamId).scheduledExecutors + val accetableExecutors = if (scheduledExecutors.nonEmpty) { + // This receiver is registering and it's scheduled by + // ReceiverSchedulingPolicy.scheduleReceivers. So use "scheduledExecutors" to check it. + scheduledExecutors.get + } else { + // This receiver is scheduled by "ReceiverSchedulingPolicy.rescheduleReceiver", so calling + // "ReceiverSchedulingPolicy.rescheduleReceiver" again to check it. 
+ scheduleReceiver(streamId) + } + + if (!accetableExecutors.contains(hostPort)) { + // Refuse it since it's scheduled to a wrong executor + false + } else { + val name = s"${typ}-${streamId}" + val receiverTrackingInfo = ReceiverTrackingInfo( + streamId, + ReceiverState.ACTIVE, + scheduledExecutors = None, + runningExecutor = Some(hostPort), + name = Some(name), + endpoint = Some(receiverEndpoint)) + receiverTrackingInfos.put(streamId, receiverTrackingInfo) + listenerBus.post(StreamingListenerReceiverStarted(receiverTrackingInfo.toReceiverInfo)) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + true + } } /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val lastErrorTime = + if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() + val errorInfo = ReceiverErrorInfo( + lastErrorMessage = message, lastError = error, lastErrorTime = lastErrorTime) + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + oldInfo.copy(state = ReceiverState.INACTIVE, errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo -= streamId - listenerBus.post(StreamingListenerReceiverStopped(newReceiverInfo)) + receiverTrackingInfos(streamId) = newReceiverTrackingInfo + listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -178,6 +301,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logError(s"Deregistered receiver for stream $streamId: $messageWithError") } + /** Update a receiver's maximum ingestion rate */ + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized { + if (isTrackerStarted) { + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + } + } + /** Add new blocks for the given stream */ private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { receivedBlockTracker.addBlock(receivedBlockInfo) @@ -185,16 +315,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Report error sent by a receiver */ private def reportError(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(lastErrorMessage = message, lastError = error) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = oldInfo.errorInfo.map(_.lastErrorTime).getOrElse(-1L)) + oldInfo.copy(errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = 
ssc.scheduler.clock.getTimeMillis()) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = ssc.scheduler.clock.getTimeMillis()) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo(streamId) = newReceiverInfo - listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + + receiverTrackingInfos(streamId) = newReceiverTrackingInfo + listenerBus.post(StreamingListenerReceiverError(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -203,134 +338,266 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logWarning(s"Error reported by receiver for stream $streamId: $messageWithError") } + private def scheduleReceiver(receiverId: Int): Seq[String] = { + val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None) + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiverId, preferredLocation, receiverTrackingInfos, getExecutors) + updateReceiverScheduledExecutors(receiverId, scheduledExecutors) + scheduledExecutors + } + + private def updateReceiverScheduledExecutors( + receiverId: Int, scheduledExecutors: Seq[String]): Unit = { + val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match { + case Some(oldInfo) => + oldInfo.copy(state = ReceiverState.SCHEDULED, + scheduledExecutors = Some(scheduledExecutors)) + case None => + ReceiverTrackingInfo( + receiverId, + ReceiverState.SCHEDULED, + Some(scheduledExecutors), + runningExecutor = None) + } + receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo) + } + /** Check if any blocks are left to be processed */ def hasUnallocatedBlocks: Boolean = { receivedBlockTracker.hasUnallocatedReceivedBlocks } + /** + * Get the list of executors excluding driver + */ + private def getExecutors: Seq[String] = { + if (ssc.sc.isLocal) { + Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort) + } else { + ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => + blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location + }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq + } + } + + /** + * Run the dummy Spark job to ensure that all slaves have registered. This avoids all the + * receivers to be scheduled on the same node. + * + * TODO Should poll the executor number and wait for executors according to + * "spark.scheduler.minRegisteredResourcesRatio" and + * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job. + */ + private def runDummySparkJob(): Unit = { + if (!ssc.sparkContext.isLocal) { + ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() + } + assert(getExecutors.nonEmpty) + } + + /** + * Get the receivers from the ReceiverInputDStreams, distributes them to the + * worker nodes as a parallel collection, and runs them. 
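`getExecutors` and `scheduleReceiver` above feed the `ReceiverSchedulingPolicy`, whose job is to spread receivers over the available executor `host:port` strings (honouring `preferredLocation` and balancing load; the policy itself is not part of this diff). The sketch below is only a hypothetical round-robin stand-in to show the shape of the result, a map from receiver id to candidate executors:

```scala
// Hypothetical round-robin assignment of receiver ids to executor "host:port" strings.
// The real ReceiverSchedulingPolicy also honours preferredLocation and load balance;
// only the shape of the result (receiver id -> candidate executors) is shown here.
object ReceiverAssignmentSketch {
  def roundRobin(receiverIds: Seq[Int], executors: Seq[String]): Map[Int, Seq[String]] = {
    require(executors.nonEmpty, "need at least one executor")
    receiverIds.zipWithIndex.map { case (id, i) =>
      id -> Seq(executors(i % executors.size))
    }.toMap
  }

  def main(args: Array[String]): Unit = {
    val executors = Seq("host1:40001", "host2:40001")
    println(roundRobin(Seq(0, 1, 2), executors))
    // Map(0 -> List(host1:40001), 1 -> List(host2:40001), 2 -> List(host1:40001))
  }
}
```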
+ */ + private def launchReceivers(): Unit = { + val receivers = receiverInputStreams.map(nis => { + val rcvr = nis.getReceiver() + rcvr.setReceiverId(nis.id) + rcvr + }) + + runDummySparkJob() + + logInfo("Starting " + receivers.length + " receivers") + endpoint.send(StartAllReceivers(receivers)) + } + + /** Check if tracker has been marked for starting */ + private def isTrackerStarted: Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping: Boolean = trackerState == Stopping + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped: Boolean = trackerState == Stopped + /** RpcEndpoint to receive messages from the receivers. */ private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged + private val submitJobThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + override def receive: PartialFunction[Any, Unit] = { + // Local messages + case StartAllReceivers(receivers) => + val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors) + for (receiver <- receivers) { + val executors = scheduledExecutors(receiver.streamId) + updateReceiverScheduledExecutors(receiver.streamId, executors) + receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation + startReceiver(receiver, executors) + } + case RestartReceiver(receiver) => + // Old scheduled executors minus the ones that are not active any more + val oldScheduledExecutors = getStoredScheduledExecutors(receiver.streamId) + val scheduledExecutors = if (oldScheduledExecutors.nonEmpty) { + // Try global scheduling again + oldScheduledExecutors + } else { + val oldReceiverInfo = receiverTrackingInfos(receiver.streamId) + // Clear "scheduledExecutors" to indicate we are going to do local scheduling + val newReceiverInfo = oldReceiverInfo.copy( + state = ReceiverState.INACTIVE, scheduledExecutors = None) + receiverTrackingInfos(receiver.streamId) = newReceiverInfo + schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + } + // Assume there is one receiver restarting at one time, so we don't need to update + // receiverTrackingInfos + startReceiver(receiver, scheduledExecutors) + case c: CleanupOldBlocks => + receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) + case UpdateReceiverRateLimit(streamUID, newRate) => + for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) { + eP.send(UpdateRateLimit(newRate)) + } + // Remote messages case ReportError(streamId, message, error) => reportError(streamId, message, error) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterReceiver(streamId, typ, host, receiverEndpoint) => - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) - context.reply(true) + // Remote messages + case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => + val successful = + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) + context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + // Local messages + 
case AllReceiverIds => + context.reply(receiverTrackingInfos.filter(_._2.state != ReceiverState.INACTIVE).keys.toSeq) + case StopAllReceivers => + assert(isTrackerStopping || isTrackerStopped) + stopReceivers() + context.reply(true) } - } - - /** This thread class runs all the receivers on the cluster. */ - class ReceiverLauncher { - @transient val env = ssc.env - @volatile @transient private var running = false - @transient val thread = new Thread() { - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") - } - } - } - - def start() { - thread.start() - } - - def stop(graceful: Boolean) { - // Send the stop signal to all the receivers - stopReceivers() - // Wait for the Spark job that runs the receivers to be over - // That is, for the receivers to quit gracefully. - thread.join(10000) - - if (graceful) { - val pollTime = 100 - logInfo("Waiting for receiver job to terminate gracefully") - while (receiverInfo.nonEmpty || running) { - Thread.sleep(pollTime) + /** + * Return the stored scheduled executors that are still alive. + */ + private def getStoredScheduledExecutors(receiverId: Int): Seq[String] = { + if (receiverTrackingInfos.contains(receiverId)) { + val scheduledExecutors = receiverTrackingInfos(receiverId).scheduledExecutors + if (scheduledExecutors.nonEmpty) { + val executors = getExecutors.toSet + // Only return the alive executors + scheduledExecutors.get.filter(executors) + } else { + Nil } - logInfo("Waited for receiver job to terminate gracefully") - } - - // Check if all the receivers have been deregistered or not - if (receiverInfo.nonEmpty) { - logWarning("Not all of the receivers have deregistered, " + receiverInfo) } else { - logInfo("All of the receivers have deregistered successfully") + Nil } } /** - * Get the receivers from the ReceiverInputDStreams, distributes them to the - * worker nodes as a parallel collection, and runs them. 
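`getStoredScheduledExecutors` above keeps a restarted receiver on its originally scheduled executors, but only those that are still alive; dead candidates are dropped by filtering against the current executor set. A standalone sketch of that filtering step (illustrative names, not the Spark API):

```scala
// Standalone sketch of the "keep only still-alive candidates" filtering done in
// getStoredScheduledExecutors (illustrative names, not the Spark API).
object AliveExecutorFilterSketch {
  def stillAlive(scheduled: Option[Seq[String]], alive: Set[String]): Seq[String] =
    scheduled match {
      case Some(executors) => executors.filter(alive) // Set[String] acts as the predicate
      case None            => Nil
    }

  def main(args: Array[String]): Unit = {
    val scheduled = Some(Seq("host1:40001", "host2:40001", "host3:40001"))
    val alive = Set("host1:40001", "host3:40001")
    println(stillAlive(scheduled, alive)) // List(host1:40001, host3:40001)
  }
}
```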
+ * Start a receiver along with its scheduled executors */ - private def startReceivers() { - val receivers = receiverInputStreams.map(nis => { - val rcvr = nis.getReceiver() - rcvr.setReceiverId(nis.id) - rcvr - }) - - // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.preferredLocation.isDefined).reduce(_ && _) - - // Create the parallel collection of receivers to distributed them on the worker nodes - val tempRDD = - if (hasLocationPreferences) { - val receiversWithPreferences = receivers.map(r => (r, Seq(r.preferredLocation.get))) - ssc.sc.makeRDD[Receiver[_]](receiversWithPreferences) - } else { - ssc.sc.makeRDD(receivers, receivers.size) - } + private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + def shouldStartReceiver: Boolean = { + // It's okay to start when trackerState is Initialized or Started + !(isTrackerStopping || isTrackerStopped) + } + + val receiverId = receiver.streamId + if (!shouldStartReceiver) { + onReceiverJobFinish(receiverId) + return + } val checkpointDirOption = Option(ssc.checkpointDir) val serializableHadoopConf = new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiver = (iterator: Iterator[Receiver[_]]) => { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") + val startReceiverFunc: Iterator[Receiver[_]] => Unit = + (iterator: Iterator[Receiver[_]]) => { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. + } } - val receiver = iterator.next() - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } - // Run the dummy Spark job to ensure that all slaves have registered. - // This avoids all the receivers to be scheduled on the same node. 
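`startReceiver` (continued just below) wraps each receiver in a one-element RDD, submits it with `SparkContext.submitJob`, and whenever that job finishes while the tracker is still running it sends `RestartReceiver` back to the endpoint, so the receiver keeps getting resubmitted until shutdown. Here is a simplified, Spark-free sketch of that resubmit-until-stopped loop, using an ordinary `Future` in place of the Spark job and a bounded attempt count so the example terminates:

```scala
// Simplified, Spark-free sketch of the resubmit-until-stopped loop. An ordinary Future
// stands in for the Spark job that SparkContext.submitJob launches for the receiver,
// and a bounded attempt count keeps the example finite.
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object RestartLoopSketch {
  private val pool = Executors.newCachedThreadPool()
  private implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
  private val stopped = new AtomicBoolean(false)

  def startReceiverJob(receiverId: Int, remaining: Int): Unit = {
    if (stopped.get() || remaining <= 0) return // mirrors onReceiverJobFinish
    val job = Future {
      // A real receiver runs until it fails or is told to stop; sleep stands in for that.
      Thread.sleep(100)
      s"receiver $receiverId exited"
    }
    job.onComplete {
      case Success(msg) =>
        println(s"$msg, restarting ($remaining attempts left)")
        startReceiverJob(receiverId, remaining - 1) // like self.send(RestartReceiver(...))
      case Failure(e) =>
        println(s"receiver $receiverId failed ($e), restarting")
        startReceiverJob(receiverId, remaining - 1)
    }
  }

  def main(args: Array[String]): Unit = {
    startReceiverJob(0, remaining = 3)
    Thread.sleep(1000) // let the three attempts run
    stopped.set(true)
    pool.shutdown()
  }
}
```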
- if (!ssc.sparkContext.isLocal) { - ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() - } - // Distribute the receivers and start them - logInfo("Starting " + receivers.length + " receivers") - running = true - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - running = false - logInfo("All of the receivers have been terminated") + // Create the RDD using the scheduledExecutors to run the receiver in a Spark job + val receiverRDD: RDD[Receiver[_]] = + if (scheduledExecutors.isEmpty) { + ssc.sc.makeRDD(Seq(receiver), 1) + } else { + ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) + } + receiverRDD.setName(s"Receiver $receiverId") + val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit]( + receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ()) + // We will keep restarting the receiver job until ReceiverTracker is stopped + future.onComplete { + case Success(_) => + if (!shouldStartReceiver) { + onReceiverJobFinish(receiverId) + } else { + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + case Failure(e) => + if (!shouldStartReceiver) { + onReceiverJobFinish(receiverId) + } else { + logError("Receiver has been stopped. Try to restart it.", e) + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + }(submitJobThreadPool) + logInfo(s"Receiver ${receiver.streamId} started") + } + + override def onStop(): Unit = { + submitJobThreadPool.shutdownNow() } - /** Stops the receivers. */ + /** + * Call when a receiver is terminated. It means we won't restart its Spark job. + */ + private def onReceiverJobFinish(receiverId: Int): Unit = { + receiverJobExitLatch.countDown() + receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo => + if (receiverTrackingInfo.state == ReceiverState.ACTIVE) { + logWarning(s"Receiver $receiverId exited but didn't deregister") + } + } + } + + /** Send stop signal to the receivers. */ private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.endpoint)} - .foreach { _.send(StopReceiver) } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") + receiverTrackingInfos.values.flatMap(_.endpoint).foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverTrackingInfos.size + " receivers") } } + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala new file mode 100644 index 000000000000..043ff4d0ff05 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.streaming.scheduler.ReceiverState._ + +private[streaming] case class ReceiverErrorInfo( + lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L) + +/** + * Class having information about a receiver. + * + * @param receiverId the unique receiver id + * @param state the current Receiver state + * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy + * @param runningExecutor the running executor if the receiver is active + * @param name the receiver name + * @param endpoint the receiver endpoint. It can be used to send messages to the receiver + * @param errorInfo the receiver error information if it fails + */ +private[streaming] case class ReceiverTrackingInfo( + receiverId: Int, + state: ReceiverState, + scheduledExecutors: Option[Seq[String]], + runningExecutor: Option[String], + name: Option[String] = None, + endpoint: Option[RpcEndpointRef] = None, + errorInfo: Option[ReceiverErrorInfo] = None) { + + def toReceiverInfo: ReceiverInfo = ReceiverInfo( + receiverId, + name.getOrElse(""), + state == ReceiverState.ACTIVE, + location = runningExecutor.getOrElse(""), + lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), + lastError = errorInfo.map(_.lastError).getOrElse(""), + lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) + ) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala new file mode 100644 index 000000000000..84a3ca9d74e5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.Logging + +/** + * Implements a proportional-integral-derivative (PID) controller which acts on + * the speed of ingestion of elements into Spark Streaming. A PID controller works + * by calculating an '''error''' between a measured output and a desired value. In the + * case of Spark Streaming the error is the difference between the measured processing + * rate (number of elements/processing delay) and the previous rate. + * + * @see https://en.wikipedia.org/wiki/PID_controller + * + * @param batchIntervalMillis the batch duration, in milliseconds + * @param proportional how much the correction should depend on the current + * error. 
This term usually provides the bulk of correction and should be positive or zero. + * A value too large would make the controller overshoot the setpoint, while a small value + * would make the controller too insensitive. The default value is 1. + * @param integral how much the correction should depend on the accumulation + * of past errors. This value should be positive or 0. This term accelerates the movement + * towards the desired value, but a large value may lead to overshooting. The default value + * is 0.2. + * @param derivative how much the correction should depend on a prediction + * of future errors, based on current rate of change. This value should be positive or 0. + * This term is not used very often, as it impacts stability of the system. The default + * value is 0. + * @param minRate what is the minimum rate that can be estimated. + * This must be greater than zero, so that the system always receives some data for rate + * estimation to work. + */ +private[streaming] class PIDRateEstimator( + batchIntervalMillis: Long, + proportional: Double, + integral: Double, + derivative: Double, + minRate: Double + ) extends RateEstimator with Logging { + + private var firstRun: Boolean = true + private var latestTime: Long = -1L + private var latestRate: Double = -1D + private var latestError: Double = -1L + + require( + batchIntervalMillis > 0, + s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.") + require( + proportional >= 0, + s"Proportional term $proportional in PIDRateEstimator should be >= 0.") + require( + integral >= 0, + s"Integral term $integral in PIDRateEstimator should be >= 0.") + require( + derivative >= 0, + s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + require( + minRate > 0, + s"Minimum rate in PIDRateEstimator should be > 0") + + logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " + + s"derivative = $derivative, min rate = $minRate") + + def compute( + time: Long, // in milliseconds + numElements: Long, + processingDelay: Long, // in milliseconds + schedulingDelay: Long // in milliseconds + ): Option[Double] = { + logTrace(s"\ntime = $time, # records = $numElements, " + + s"processing time = $processingDelay, scheduling delay = $schedulingDelay") + this.synchronized { + if (time > latestTime && numElements > 0 && processingDelay > 0) { + + // in seconds, should be close to batchDuration + val delaySinceUpdate = (time - latestTime).toDouble / 1000 + + // in elements/second + val processingRate = numElements.toDouble / processingDelay * 1000 + + // In our system `error` is the difference between the desired rate and the measured rate + // based on the latest batch information. We consider the desired rate to be latest rate, + // which is what this estimator calculated for the previous batch. + // in elements/second + val error = latestRate - processingRate + + // The error integral, based on schedulingDelay as an indicator for accumulated errors. + // A scheduling delay s corresponds to s * processingRate overflowing elements. Those + // are elements that couldn't be processed in previous batches, leading to this delay. + // In the following, we assume the processingRate didn't change too much. + // From the number of overflowing elements we can calculate the rate at which they would be + // processed by dividing it by the batch interval. 
This rate is our "historical" error, + // or integral part, since if we subtracted this rate from the previous "calculated rate", + // there wouldn't have been any overflowing elements, and the scheduling delay would have + // been zero. + // (in elements/second) + val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis + + // in elements/(second ^ 2) + val dError = (error - latestError) / delaySinceUpdate + + val newRate = (latestRate - proportional * error - + integral * historicalError - + derivative * dError).max(minRate) + logTrace(s""" + | latestRate = $latestRate, error = $error + | latestError = $latestError, historicalError = $historicalError + | delaySinceUpdate = $delaySinceUpdate, dError = $dError + """.stripMargin) + + latestTime = time + if (firstRun) { + latestRate = processingRate + latestError = 0D + firstRun = false + logTrace("First run, rate estimation skipped") + None + } else { + latestRate = newRate + latestError = error + logTrace(s"New rate = $newRate") + Some(newRate) + } + } else { + logTrace("Rate estimation skipped") + None + } + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala new file mode 100644 index 000000000000..d7210f64fcc3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.Duration + +/** + * A component that estimates the rate at wich an InputDStream should ingest + * elements, based on updates at every batch completion. + */ +private[streaming] trait RateEstimator extends Serializable { + + /** + * Computes the number of elements the stream attached to this `RateEstimator` + * should ingest per second, given an update on the size and completion + * times of the latest batch. + * + * @param time The timetamp of the current batch interval that just finished + * @param elements The number of elements that were processed in this batch + * @param processingDelay The time in ms that took for the job to complete + * @param schedulingDelay The time in ms that the job spent in the scheduling queue + */ + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] +} + +object RateEstimator { + + /** + * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. + * + * The only known estimator right now is `pid`. 
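To make the update rule in `compute` above concrete, here is a worked example with the default gains (`proportional = 1.0`, `integral = 0.2`, `derivative = 0.0`); the batch measurements are made-up numbers, not taken from any real run:

```scala
// Worked example of the update rule in PIDRateEstimator.compute with the default gains.
// The batch measurements below are made-up numbers, not from a real run.
object PidUpdateSketch {
  def main(args: Array[String]): Unit = {
    val batchIntervalMillis = 1000L
    val (proportional, integral, derivative, minRate) = (1.0, 0.2, 0.0, 100.0)

    // State carried over from the previous batch
    val latestRate = 1000.0    // elements/second the estimator asked for last time
    val latestError = 0.0
    val delaySinceUpdate = 1.0 // seconds between the two batch completions

    // Measurements for the batch that just finished
    val numElements = 900L
    val processingDelayMs = 1200L // processing took longer than the batch interval
    val schedulingDelayMs = 200L

    val processingRate = numElements.toDouble / processingDelayMs * 1000 // 750 elements/s
    val error = latestRate - processingRate                              // 250
    val historicalError =
      schedulingDelayMs.toDouble * processingRate / batchIntervalMillis  // 150
    val dError = (error - latestError) / delaySinceUpdate                // 250

    val newRate = (latestRate - proportional * error -
      integral * historicalError -
      derivative * dError).max(minRate)
    println(f"new rate = $newRate%.1f elements/second") // 1000 - 250 - 30 - 0 = 720.0
  }
}
```

As `RateEstimator.create` below shows, these gains come from `spark.streaming.backpressure.pid.proportional`, `spark.streaming.backpressure.pid.integral` and `spark.streaming.backpressure.pid.derived`, and the estimator itself is selected with `spark.streaming.backpressure.rateEstimator` (currently only `pid`).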
+ * + * @return An instance of RateEstimator + * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any + * known estimators. + */ + def create(conf: SparkConf, batchInterval: Duration): RateEstimator = + conf.get("spark.streaming.backpressure.rateEstimator", "pid") match { + case "pid" => + val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) + val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) + val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) + val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100) + new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) + + case estimator => + throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index f75067669abe..90d1b0fadecf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{NodeSeq, Node, Text} +import scala.xml.{NodeSeq, Node, Text, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -30,7 +28,7 @@ import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} import org.apache.spark.streaming.ui.StreamingJobProgressListener.{SparkJobId, OutputOpId} import org.apache.spark.ui.jobs.UIData.JobUIData -private case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) +private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private val streamingListener = parent.listener @@ -303,6 +301,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { batchUIData.processingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val formattedTotalDelay = batchUIData.totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + val inputMetadatas = batchUIData.streamIdToInputInfo.values.flatMap { inputInfo => + inputInfo.metadataDescription.map(desc => inputInfo.inputStreamId -> desc) + }.toSeq val summary: NodeSeq =
      @@ -326,6 +327,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { Total delay: {formattedTotalDelay} + { + if (inputMetadatas.nonEmpty) { +
+            <li>
+              <strong>Input Metadata:</strong>{generateInputMetadataTable(inputMetadatas)}
+            </li>
+          }
+        }
    @@ -340,4 +348,33 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) } + + def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { +
+    <table>
+      <thead>
+        <tr>
+          <th>Input</th>
+          <th>Metadata</th>
+        </tr>
+      </thead>
+      <tbody>
+        {inputMetadatas.flatMap(generateInputMetadataRow)}
+      </tbody>
+    </table>
+  }
+
+  def generateInputMetadataRow(inputMetadata: (Int, String)): Seq[Node] = {
+    val streamId = inputMetadata._1
+
+    <tr>
+      <td>{streamingListener.streamName(streamId).getOrElse(s"Stream-$streamId")}</td>
+      <td>{metadataDescriptionToHTML(inputMetadata._2)}</td>
+    </tr>
+  }
+
+  private def metadataDescriptionToHTML(metadataDescription: String): Seq[Node] = {
+    // tab to 4 spaces and "\n" to "<br/>"
+    Unparsed(StringEscapeUtils.escapeHtml4(metadataDescription).
+      replaceAllLiterally("\t", "&nbsp;&nbsp;&nbsp;&nbsp;").replaceAllLiterally("\n", "<br/>
    ")) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index a5514dfd71c9..ae508c0e9577 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -19,14 +19,14 @@ package org.apache.spark.streaming.ui import org.apache.spark.streaming.Time -import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} import org.apache.spark.streaming.ui.StreamingJobProgressListener._ private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) private[ui] case class BatchUIData( val batchTime: Time, - val streamIdToNumRecords: Map[Int, Long], + val streamIdToInputInfo: Map[Int, StreamInputInfo], val submissionTime: Long, val processingStartTime: Option[Long], val processingEndTime: Option[Long], @@ -58,7 +58,7 @@ private[ui] case class BatchUIData( /** * The number of recorders received by the receivers in this batch. */ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } private[ui] object BatchUIData { @@ -66,7 +66,7 @@ private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { new BatchUIData( batchInfo.batchTime, - batchInfo.streamIdToNumRecords, + batchInfo.streamIdToInputInfo, batchInfo.submissionTime, batchInfo.processingStartTime, batchInfo.processingEndTime diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 68e8ce98945e..78aeb004e18b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -148,6 +148,14 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) receiverInfos.size } + def numActiveReceivers: Int = synchronized { + receiverInfos.count(_._2.active) + } + + def numInactiveReceivers: Int = { + ssc.graph.getReceiverInputStreams().size - numActiveReceivers + } + def numTotalCompletedBatches: Long = synchronized { totalCompletedBatches } @@ -192,7 +200,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { val _retainedBatches = retainedBatches val latestBatches = _retainedBatches.map { batchUIData => - (batchUIData.batchTime.milliseconds, batchUIData.streamIdToNumRecords) + (batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords)) } streamIds.map { streamId => val eventRates = latestBatches.map { @@ -205,7 +213,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def lastReceivedBatchRecords: Map[Int, Long] = synchronized { - val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.streamIdToNumRecords) + val lastReceivedBlockInfoOption = + lastReceivedBatch.map(_.streamIdToInputInfo.mapValues(_.numRecords)) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => streamIds.map { streamId => (streamId, lastReceivedBlockInfo.getOrElse(streamId, 0L)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4ee7a486e370..96d943e75d27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -303,6 +303,7 @@ private[ui] class StreamingPage(parent: StreamingTab) val numCompletedBatches = listener.retainedCompletedBatches.size val numActiveBatches = batchTimes.length - numCompletedBatches + val numReceivers = listener.numInactiveReceivers + listener.numActiveReceivers val table = // scalastyle:off @@ -310,7 +311,7 @@ private[ui] class StreamingPage(parent: StreamingTab) - + @@ -330,6 +331,11 @@ private[ui] class StreamingPage(parent: StreamingTab) } } + { + if (numReceivers > 0) { +
    Receivers: {listener.numActiveReceivers} / {numReceivers} active
    + } + }
    Avg: {eventRateForAllStreams.formattedAvg} events/sec
    @@ -456,7 +462,7 @@ private[ui] class StreamingPage(parent: StreamingTab) - +
    Timelines (Last {batchTimes.length} batches, {numActiveBatches} active, {numCompletedBatches} completed)Histograms
    Histograms
    {receiverActive} {receiverLocation} {receiverLastErrorTime}
    {receiverLastError}
    {receiverLastError}
    diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index e0c0f57212f5..bc53f2a31f6d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming.ui -import org.eclipse.jetty.servlet.ServletContextHandler - import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.{JettyUtils, SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} import StreamingTab._ @@ -42,18 +40,14 @@ private[spark] class StreamingTab(val ssc: StreamingContext) attachPage(new StreamingPage(this)) attachPage(new BatchPage(this)) - var staticHandler: ServletContextHandler = null - def attach() { getSparkUI(ssc).attachTab(this) - staticHandler = JettyUtils.createStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") - getSparkUI(ssc).attachHandler(staticHandler) + getSparkUI(ssc).addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") } def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).detachHandler(staticHandler) - staticHandler = null + getSparkUI(ssc).removeStaticHandler("/static/streaming") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index fe6328b1ce72..9f4a4d6806ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps @@ -118,7 +119,6 @@ private[streaming] class FileBasedWriteAheadLog( * hence the implementation is kept simple. 
*/ def readAll(): JIterator[ByteBuffer] = synchronized { - import scala.collection.JavaConversions._ val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) @@ -126,7 +126,7 @@ private[streaming] class FileBasedWriteAheadLog( logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) - } flatMap { x => x } + }.flatten.asJava } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index ca2f319f174a..6addb9675203 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -35,7 +35,9 @@ private[streaming] object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { + // scalastyle:off println System.err.println("Usage: RawTextSender ") + // scalastyle:on println System.exit(1) } // Parse the arguments using a pattern match diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index c8eef833eb43..dd32ad5ad811 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -106,7 +106,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: } private[streaming] -object RecurringTimer { +object RecurringTimer extends Logging { def main(args: Array[String]) { var lastRecurTime = 0L @@ -114,7 +114,7 @@ object RecurringTimer { def onRecur(time: Long) { val currentTime = System.currentTimeMillis() - println("" + currentTime + ": " + (currentTime - lastRecurTime)) + logInfo("" + currentTime + ": " + (currentTime - lastRecurTime)) lastRecurTime = currentTime } val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 1077b1b2cb7e..c5217149224e 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,24 +18,22 @@ package org.apache.spark.streaming; import java.io.*; -import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import scala.Tuple2; + +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import scala.Tuple2; - import org.junit.Assert; -import static org.junit.Assert.*; import org.junit.Test; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -54,14 +52,14 @@ // see http://stackoverflow.com/questions/758570/. 
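The `readAll` change above swaps the implicit `JavaConversions` import for explicit `JavaConverters`: the per-file Scala iterators are flattened and then converted with an explicit `.asJava`. A small standalone sketch of the same conversion (the example data and names are made up; the real method flattens one lazy reader per write-ahead-log file):

```scala
// Standalone sketch of the JavaConversions -> JavaConverters change in readAll: flatten
// the per-file Scala iterators and convert with an explicit .asJava. Example data and
// names are made up; the real method flattens one lazy reader per write-ahead-log file.
import java.util.{Iterator => JIterator}

import scala.collection.JavaConverters._

object AsJavaSketch {
  def readAllSketch(perFileRecords: Seq[Seq[String]]): JIterator[String] =
    perFileRecords.iterator.flatMap(_.iterator).asJava

  def main(args: Array[String]): Unit = {
    val it = readAllSketch(Seq(Seq("a", "b"), Seq("c")))
    while (it.hasNext) println(it.next()) // prints a, b, c
  }
}
```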
public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { - public void equalIterator(Iterator a, Iterator b) { + public static void equalIterator(Iterator a, Iterator b) { while (a.hasNext() && b.hasNext()) { Assert.assertEquals(a.next(), b.next()); } Assert.assertEquals(a.hasNext(), b.hasNext()); } - public void equalIterable(Iterable a, Iterable b) { + public static void equalIterable(Iterable a, Iterable b) { equalIterator(a.iterator(), b.iterator()); } @@ -74,14 +72,14 @@ public void testInitialization() { @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); ssc.start(); - Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + Assert.assertEquals(StreamingContextState.ACTIVE, ssc.getState()); ssc.stop(); - Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } @SuppressWarnings("unchecked") @@ -118,7 +116,7 @@ public void testMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -180,7 +178,7 @@ public void testWindowWithSlideDuration() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -189,7 +187,7 @@ public void testFilter() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override - public Boolean call(String s) throws Exception { + public Boolean call(String s) { return s.contains("a"); } }); @@ -243,11 +241,11 @@ public void testRepartitionFewerPartitions() { public void testGlom() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); + Arrays.asList(Arrays.asList("yankees", "red sox"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream> glommed = stream.glom(); @@ -262,22 +260,22 @@ public void testGlom() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { - String out = ""; + StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out = out + in.next().toUpperCase(); + 
out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Lists.newArrayList(out); + return Arrays.asList(out.toString()); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -286,16 +284,16 @@ public Iterable call(Iterator in) { Assert.assertEquals(expected, result); } - private class IntegerSum implements Function2 { + private static class IntegerSum implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 + i2; } } - private class IntegerDifference implements Function2 { + private static class IntegerDifference implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 - i2; } } @@ -347,13 +345,13 @@ private void testReduceByWindow(boolean withInverse) { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = null; + JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -364,17 +362,25 @@ private void testReduceByWindow(boolean withInverse) { @SuppressWarnings("unchecked") @Test public void testQueueStream() { + ssc.stop(); + // Create a new JavaStreamingContext without checkpointing + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + List> expected = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), Arrays.asList(7,8,9)); JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sparkContext().parallelize(Arrays.asList(1, 2, 3)); - JavaRDD rdd2 = ssc.sparkContext().parallelize(Arrays.asList(4, 5, 6)); - JavaRDD rdd3 = ssc.sparkContext().parallelize(Arrays.asList(7,8,9)); + JavaRDD rdd1 = jsc.parallelize(Arrays.asList(1, 2, 3)); + JavaRDD rdd2 = jsc.parallelize(Arrays.asList(4, 5, 6)); + JavaRDD rdd3 = jsc.parallelize(Arrays.asList(7,8,9)); - LinkedList> rdds = Lists.newLinkedList(); + Queue> rdds = new LinkedList<>(); rdds.add(rdd1); rdds.add(rdd2); rdds.add(rdd3); @@ -402,10 +408,10 @@ public void testTransform() { JavaDStream transformed = stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return in.map(new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i + 2; } }); @@ -427,70 +433,70 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - JavaDStream transformed1 = stream.transform( + stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD 
call(JavaRDD in) { return null; } } ); - JavaDStream transformed2 = stream.transform( + stream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaRDD in, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream.transformToPair( + stream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) throws Exception { + @Override public JavaPairRDD call(JavaRDD in) { return null; } } ); - JavaPairDStream transformed4 = stream.transformToPair( + stream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaRDD in, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream.transform( + pairStream.transform( new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) throws Exception { + @Override public JavaRDD call(JavaPairRDD in) { return null; } } ); - JavaDStream pairTransformed2 = pairStream.transform( + pairStream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaPairRDD in, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream.transformToPair( + pairStream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream.transformToPair( + pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in, Time time) { return null; } } @@ -503,32 +509,32 @@ public JavaRDD call(JavaRDD in) throws Exception { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -544,14 +550,12 @@ public void testTransformWith() { JavaPairRDD, JavaPairRDD, 
Time, - JavaPairRDD> - >() { + JavaPairRDD>>() { @Override public JavaPairRDD> call( JavaPairRDD rdd1, JavaPairRDD rdd2, - Time time - ) throws Exception { + Time time) { return rdd1.join(rdd2); } } @@ -559,9 +563,9 @@ public JavaPairRDD> call( JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> res: result) { - unorderedResult.add(Sets.newHashSet(res)); + unorderedResult.add(Sets.newHashSet(res)); } Assert.assertEquals(expected, unorderedResult); @@ -579,89 +583,89 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - JavaDStream transformed1 = stream1.transformWith( + stream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream transformed2 = stream1.transformWith( + stream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream1.transformWithToPair( + stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed4 = stream1.transformWithToPair( + stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream1.transformWith( + pairStream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed2_ = pairStream1.transformWith( + pairStream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { 
return null; } } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } @@ -682,13 +686,13 @@ public void testStreamingContextTransform(){ ); List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -699,7 +703,7 @@ public void testStreamingContextTransform(){ List> listOfDStreams1 = Arrays.>asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( + ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { @Override @@ -725,8 +729,8 @@ public JavaPairRDD> call(List> listO JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = new PairFunction() { @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i); + public Tuple2 call(Integer i) { + return new Tuple2<>(i, i); } }; return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); @@ -755,7 +759,7 @@ public void testFlatMap() { JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); + return Arrays.asList(x.split("(?!^)")); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -774,39 +778,39 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); + public Iterable> call(String in) { + List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + out.add(new 
Tuple2<>(in.length(), letter)); } return out; } @@ -851,13 +855,13 @@ public void testUnion() { */ public static void assertOrderInvariantEquals( List> expected, List> actual) { - List> expectedSets = new ArrayList>(); + List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet(list))); + expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } - List> actualSets = new ArrayList>(); + List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet(list))); + actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } Assert.assertEquals(expectedSets, actualSets); } @@ -869,25 +873,25 @@ public static void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); + public Tuple2 call(String in) { + return new Tuple2<>(in, in.length()); } }); JavaPairDStream filtered = pairStream.filter( new Function, Boolean>() { @Override - public Boolean call(Tuple2 in) throws Exception { + public Boolean call(Tuple2 in) { return in._1().contains("a"); } }); @@ -898,28 +902,28 @@ public Boolean call(Tuple2 in) throws Exception { } @SuppressWarnings("unchecked") - private List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + private final List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); @SuppressWarnings("unchecked") - private List>> stringIntKVStream = Arrays.asList( + private final List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @SuppressWarnings("unchecked") @Test @@ -928,22 +932,22 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new 
Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return in.swap(); } }); @@ -961,23 +965,23 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); + public Iterable> call(Iterator> in) { + List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -1006,7 +1010,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream reversed = pairStream.map( new Function, Integer>() { @Override - public Integer call(Tuple2 in) throws Exception { + public Integer call(Tuple2 in) { return in._2(); } }); @@ -1022,23 +1026,23 @@ public Integer call(Tuple2 in) throws Exception { public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -1046,10 +1050,10 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); + public Iterable> call(Tuple2 in) { + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - 
out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; } @@ -1067,11 +1071,11 @@ public void testPairGroupByKey() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + new Tuple2<>("california", Arrays.asList("dodgers", "giants")), + new Tuple2<>("new york", Arrays.asList("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + new Tuple2<>("california", Arrays.asList("sharks", "ducks")), + new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1103,11 +1107,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1128,20 +1132,20 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey( + JavaPairDStream combined = pairStream.combineByKey( new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i; } }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); @@ -1162,13 +1166,13 @@ public void testCountByValue() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("moon", 1L)), Arrays.asList( - new Tuple2("hello", 1L))); + new Tuple2<>("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1185,16 +1189,16 @@ public void testGroupByKeyAndWindow() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) + new Tuple2<>("california", Arrays.asList(1, 3)), + new Tuple2<>("new york", Arrays.asList(1, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + new Tuple2<>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 1, 3, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) + new Tuple2<>("california", Arrays.asList(5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 3)) ) ); @@ -1212,16 +1216,16 @@ public void testGroupByKeyAndWindow() 
{ } } - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); + private static Set>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); } - return new HashSet>>(newListOfTuples); + return new HashSet<>(newListOfTuples); } - private Tuple2> convert(Tuple2> tuple) { - return new Tuple2>(tuple._1(), new HashSet(tuple._2())); + private static Tuple2> convert(Tuple2> tuple) { + return new Tuple2<>(tuple._1(), new HashSet<>(tuple._2())); } @SuppressWarnings("unchecked") @@ -1230,12 +1234,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1254,12 +1258,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1270,10 +1274,10 @@ public void testUpdateStateByKey() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1290,19 +1294,19 @@ public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; List> initial = Arrays.asList ( - new Tuple2 ("california", 1), - new Tuple2 ("new york", 2)); + new Tuple2<>("california", 1), + new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 5), - new Tuple2("new york", 7)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11))); + Arrays.asList(new Tuple2<>("california", 5), + new Tuple2<>("new york", 7)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1313,10 +1317,10 @@ public void testUpdateStateByKeyWithInitial() { public Optional call(List values, Optional 
state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1333,19 +1337,19 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1362,15 +1366,15 @@ public void testCountByValueAndWindow() { List>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 2L), + new Tuple2<>("world", 1L), + new Tuple2<>("moon", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); + new Tuple2<>("hello", 2L), + new Tuple2<>("moon", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1378,7 +1382,7 @@ public void testCountByValueAndWindow() { stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - List>> unorderedResult = Lists.newArrayList(); + List>> unorderedResult = new ArrayList<>(); for (List> res: result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -1391,27 +1395,27 @@ public void testCountByValueAndWindow() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1420,7 +1424,7 @@ public void testPairTransform() { JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { + public JavaPairRDD call(JavaPairRDD in) { return in.sortByKey(); } }); @@ 
-1436,15 +1440,15 @@ public JavaPairRDD call(JavaPairRDD in) thro public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3,1,4,2), @@ -1457,11 +1461,11 @@ public void testPairToNormalRDDTransform() { JavaDStream firstParts = pairStream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD in) throws Exception { + public JavaRDD call(JavaPairRDD in) { return in.map(new Function, Integer>() { @Override - public Integer call(Tuple2 in) { - return in._1(); + public Integer call(Tuple2 in2) { + return in2._1(); } }); } @@ -1479,14 +1483,14 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1494,8 +1498,8 @@ public void testMapValues() { JavaPairDStream mapped = pairStream.mapValues(new Function() { @Override - public String call(String s) throws Exception { - return s.toUpperCase(); + public String call(String s) { + return s.toUpperCase(Locale.ENGLISH); } }); @@ -1511,22 +1515,22 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", 
"islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1537,7 +1541,7 @@ public void testFlatMapValues() { new Function>() { @Override public Iterable call(String in) { - List out = new ArrayList(); + List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); return out; @@ -1554,29 +1558,29 @@ public Iterable call(String in) { @Test public void testCoGroup() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List, List>>>> expected = Arrays.asList( Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("yankees"), Arrays.asList("mets")))), Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1612,29 +1616,29 @@ public void testCoGroup() { @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new 
york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1656,13 +1660,13 @@ public void testJoin() { @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks") )); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks") )); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants") ), - Arrays.asList(new Tuple2("new york", "islanders") ) + Arrays.asList(new Tuple2<>("california", "giants") ), + Arrays.asList(new Tuple2<>("new york", "islanders") ) ); @@ -1705,7 +1709,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1727,6 +1731,7 @@ public Integer call(String s) throws Exception { @SuppressWarnings("unchecked") @Test public void testContextGetOrCreate() throws InterruptedException { + ssc.stop(); final SparkConf conf = new SparkConf() .setMaster("local[2]") @@ -1743,6 +1748,7 @@ public void testContextGetOrCreate() throws InterruptedException { // (used to detect the new context) final AtomicBoolean newContextCreated = new AtomicBoolean(false); Function0 creatingFunc = new Function0() { + @Override public JavaStreamingContext call() { newContextCreated.set(true); return new JavaStreamingContext(conf, Seconds.apply(1)); @@ -1756,20 +1762,20 @@ public JavaStreamingContext call() { newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); + new Configuration(), true); Assert.assertTrue("new context not created", newContextCreated.get()); ssc.stop(); newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); newContextCreated.set(false); JavaSparkContext sc = new JavaSparkContext(conf); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } @@ -1791,7 +1797,7 @@ public void testCheckpointofIndividualStream() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1809,29 +1815,26 @@ public Integer call(String s) throws Exception { // InputStream functionality is deferred to the existing Scala tests. 
@Test public void testSocketTextStream() { - JavaReceiverInputDStream test = ssc.socketTextStream("localhost", 12345); + ssc.socketTextStream("localhost", 12345); } @Test public void testSocketString() { - - class Converter implements Function> { - public Iterable call(InputStream in) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - return out; - } - } - - JavaDStream test = ssc.socketStream( + ssc.socketStream( "localhost", 12345, - new Converter(), + new Function>() { + @Override + public Iterable call(InputStream in) throws IOException { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); + } + } + return out; + } + }, StorageLevel.MEMORY_ONLY()); } @@ -1861,7 +1864,7 @@ public void testFileStream() throws IOException { TextInputFormat.class, new Function() { @Override - public Boolean call(Path v1) throws Exception { + public Boolean call(Path v1) { return Boolean.TRUE; } }, @@ -1870,7 +1873,7 @@ public Boolean call(Path v1) throws Exception { JavaDStream test = inputStream.map( new Function, String>() { @Override - public String call(Tuple2 v1) throws Exception { + public String call(Tuple2 v1) { return v1._2().toString(); } }); @@ -1883,19 +1886,15 @@ public String call(Tuple2 v1) throws Exception { @Test public void testRawSocketStream() { - JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); + ssc.rawSocketStream("localhost", 12345); } - private List> fileTestPrepare(File testDir) throws IOException { + private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); Files.write("0\n", existingFile, Charset.forName("UTF-8")); - assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); - - List> expected = Arrays.asList( - Arrays.asList("0") - ); - - return expected; + Assert.assertTrue(existingFile.setLastModified(1000)); + Assert.assertEquals(1000, existingFile.lastModified()); + return Arrays.asList(Arrays.asList("0")); } @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 1b0787fe69de..ec2bffd6a5b9 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -36,7 +36,6 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -64,16 +63,16 @@ public void testReceiver() throws InterruptedException { ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); JavaDStream mapped = input.map(new Function() { @Override - public String call(String v1) throws Exception { + public String call(String v1) { return v1 + "."; } }); mapped.foreachRDD(new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { - long count = rdd.count(); - dataCounter.addAndGet(count); - return null; + public Void call(JavaRDD rdd) { + long count = rdd.count(); + dataCounter.addAndGet(count); + return null; } 
}); @@ -83,7 +82,7 @@ public Void call(JavaRDD rdd) throws Exception { Thread.sleep(200); for (int i = 0; i < 6; i++) { - server.send("" + i + "\n"); // \n to make sure these are separate lines + server.send(i + "\n"); // \n to make sure these are separate lines Thread.sleep(100); } while (dataCounter.get() == 0 && System.currentTimeMillis() - startTime < timeout) { @@ -95,50 +94,49 @@ public Void call(JavaRDD rdd) throws Exception { server.stop(); } } -} -class JavaSocketReceiver extends Receiver { + private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + String host = null; + int port = -1; - public JavaSocketReceiver(String host_ , int port_) { - super(StorageLevel.MEMORY_AND_DISK()); - host = host_; - port = port_; - } + JavaSocketReceiver(String host_ , int port_) { + super(StorageLevel.MEMORY_AND_DISK()); + host = host_; + port = port_; + } - @Override - public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); - } + @Override + public void onStart() { + new Thread() { + @Override public void run() { + receive(); + } + }.start(); + } - @Override - public void onStop() { - } + @Override + public void onStop() { + } - private void receive() { - Socket socket = null; - try { - socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + private void receive() { + try { + Socket socket = new Socket(host, port); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + in.close(); + socket.close(); + } catch(ConnectException ce) { + ce.printStackTrace(); + restart("Could not connect", ce); + } catch(Throwable t) { + t.printStackTrace(); + restart("Error receiving data", t); } - in.close(); - socket.close(); - } catch(ConnectException ce) { - ce.printStackTrace(); - restart("Could not connect", ce); - } catch(Throwable t) { - t.printStackTrace(); - restart("Error receiving data", t); } } -} +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index bb80bff6dc2e..57b50bdfd652 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -17,16 +17,13 @@ package org.apache.spark.streaming -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import java.util.{List => JList} -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.api.java.{JavaDStreamLike, JavaDStream, JavaStreamingContext} /** Exposes streaming test functionality in a Java-friendly way. 
*/ trait JavaTestBase extends TestSuiteBase { @@ -39,7 +36,7 @@ trait JavaTestBase extends TestSuiteBase { ssc: JavaStreamingContext, data: JList[JList[T]], numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) + val seqData = data.asScala.map(_.asScala) implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] @@ -72,9 +69,7 @@ trait JavaTestBase extends TestSuiteBase { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] ssc.getState() val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out + res.map(_.asJava).asJava } /** @@ -90,12 +85,7 @@ trait JavaTestBase extends TestSuiteBase { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[JList[V]]]() - res.map{entry => - val lists = entry.map(new ArrayList[V](_)) - out.append(new ArrayList[JList[V]](lists)) - } - out + res.map(entry => entry.map(_.asJava).asJava).asJava } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 50e8f9fc159c..175b8a496b4e 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -17,13 +17,15 @@ package org.apache.spark.streaming; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.nio.ByteBuffer; import java.util.Arrays; -import java.util.Collection; +import java.util.Iterator; +import java.util.List; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.collections.Transformer; +import com.google.common.base.Function; +import com.google.common.collect.Iterators; import org.apache.spark.SparkConf; import org.apache.spark.streaming.util.WriteAheadLog; import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; @@ -32,40 +34,40 @@ import org.junit.Test; import org.junit.Assert; -class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle { - int index = -1; - public JavaWriteAheadLogSuiteHandle(int idx) { - index = idx; - } -} - public class JavaWriteAheadLogSuite extends WriteAheadLog { - class Record { + static class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle { + int index = -1; + JavaWriteAheadLogSuiteHandle(int idx) { + index = idx; + } + } + + static class Record { long time; int index; ByteBuffer buffer; - public Record(long tym, int idx, ByteBuffer buf) { + Record(long tym, int idx, ByteBuffer buf) { index = idx; time = tym; buffer = buf; } } private int index = -1; - private ArrayList records = new ArrayList(); + private final List records = new ArrayList<>(); // Methods for WriteAheadLog @Override - public WriteAheadLogRecordHandle write(java.nio.ByteBuffer record, long time) { + public WriteAheadLogRecordHandle write(ByteBuffer record, long time) { index += 1; - records.add(new org.apache.spark.streaming.JavaWriteAheadLogSuite.Record(time, index, record)); + records.add(new Record(time, index, record)); return new JavaWriteAheadLogSuiteHandle(index); } @Override - public java.nio.ByteBuffer read(WriteAheadLogRecordHandle handle) { + public ByteBuffer read(WriteAheadLogRecordHandle handle) { if (handle instanceof JavaWriteAheadLogSuiteHandle) { int reqdIndex = 
((JavaWriteAheadLogSuiteHandle) handle).index; for (Record record: records) { @@ -78,14 +80,13 @@ public java.nio.ByteBuffer read(WriteAheadLogRecordHandle handle) { } @Override - public java.util.Iterator readAll() { - Collection buffers = CollectionUtils.collect(records, new Transformer() { + public Iterator readAll() { + return Iterators.transform(records.iterator(), new Function() { @Override - public Object transform(Object input) { - return ((Record) input).buffer; + public ByteBuffer apply(Record input) { + return input.buffer; } }); - return buffers.iterator(); } @Override @@ -110,20 +111,21 @@ public void testCustomWAL() { WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; - WriteAheadLogRecordHandle handle = wal.write(ByteBuffer.wrap(data1.getBytes()), 1234); + WriteAheadLogRecordHandle handle = + wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertTrue(new String(wal.read(handle).array()).equals(data1)); + Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1); - wal.write(ByteBuffer.wrap("data2".getBytes()), 1235); - wal.write(ByteBuffer.wrap("data3".getBytes()), 1236); - wal.write(ByteBuffer.wrap("data4".getBytes()), 1237); + wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235); + wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236); + wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237); wal.clean(1236, false); - java.util.Iterator dataIterator = wal.readAll(); - ArrayList readData = new ArrayList(); + Iterator dataIterator = wal.readAll(); + List readData = new ArrayList<>(); while (dataIterator.hasNext()) { - readData.add(new String(dataIterator.next().array())); + readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8)); } - Assert.assertTrue(readData.equals(Arrays.asList("data3", "data4"))); + Assert.assertEquals(readData, Arrays.asList("data3", "data4")); } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 08faeaa58f41..255376807c95 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase { test("repartition (more partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(5) - val ssc = setupStreams(input, operation, 2) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 5) - assert(second.size === 5) - assert(third.size === 5) - - assert(first.flatten.toSet.equals((1 to 100).toSet) ) - assert(second.flatten.toSet.equals((101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 2)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) 
+ assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("repartition (fewer partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(2) - val ssc = setupStreams(input, operation, 5) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 2) - assert(second.size === 2) - assert(third.size === 2) - - assert(first.flatten.toSet.equals((1 to 100).toSet)) - assert(second.flatten.toSet.equals( (101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 5)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("groupByKey") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d0..1bba7a143edf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming import java.io.File -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag import com.google.common.base.Charsets @@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -191,8 +193,51 @@ class CheckpointSuite extends TestSuiteBase { } } + // This tests if "spark.driver.host" and "spark.driver.port" is set by user, can be recovered + // with correct value. 
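The repartition tests rewritten above in BasicOperationsSuite now run inside `withStreamingContext`, the loan-pattern helper from `TestSuiteBase`, so the `StreamingContext` is stopped even when an assertion fails mid-test. The sketch below only illustrates that pattern; it is not the actual helper, and the plain `stop()` call (which also stops the underlying SparkContext) is an assumption:

```scala
import org.apache.spark.streaming.StreamingContext

// Illustrative loan-pattern sketch only; the real helper lives in TestSuiteBase
// and may differ in how it stops the context.
object StreamingTestHelpers {
  def withStreamingContext[R](ssc: StreamingContext)(body: StreamingContext => R): R = {
    try {
      body(ssc)
    } finally {
      ssc.stop() // assumption: stopping the underlying SparkContext here is acceptable
    }
  }
}

// Usage mirrors the rewritten tests:
//   withStreamingContext(setupStreams(input, operation, 2)) { ssc =>
//     val output = runStreamsWithPartitions(ssc, 3, 3)
//     assert(output.size === 3)
//   }
```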
+ test("get correct spark.driver.[host|port] from checkpoint") { + val conf = Map("spark.driver.host" -> "localhost", "spark.driver.port" -> "9999") + conf.foreach(kv => System.setProperty(kv._1, kv._2)) + ssc = new StreamingContext(master, framework, batchDuration) + val originalConf = ssc.conf + assert(originalConf.get("spark.driver.host") === "localhost") + assert(originalConf.get("spark.driver.port") === "9999") + + val cp = new Checkpoint(ssc, Time(1000)) + ssc.stop() - // This tests whether the systm can recover from a master failure with simple + // Serialize/deserialize to simulate write to storage and reading it back + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + + val newCpConf = newCp.createSparkConf() + assert(newCpConf.contains("spark.driver.host")) + assert(newCpConf.contains("spark.driver.port")) + assert(newCpConf.get("spark.driver.host") === "localhost") + assert(newCpConf.get("spark.driver.port") === "9999") + + // Check if all the parameters have been restored + ssc = new StreamingContext(null, newCp, null) + val restoredConf = ssc.conf + assert(restoredConf.get("spark.driver.host") === "localhost") + assert(restoredConf.get("spark.driver.port") === "9999") + ssc.stop() + + // If spark.driver.host and spark.driver.host is not set in system property, these two + // parameters should not be presented in the newly recovered conf. + conf.foreach(kv => System.clearProperty(kv._1)) + val newCpConf1 = newCp.createSparkConf() + assert(!newCpConf1.contains("spark.driver.host")) + assert(!newCpConf1.contains("spark.driver.port")) + + // Spark itself will dispatch a random, not-used port for spark.driver.port if it is not set + // explicitly. + ssc = new StreamingContext(null, newCp, null) + val restoredConf1 = ssc.conf + assert(restoredConf1.get("spark.driver.host") === "localhost") + assert(restoredConf1.get("spark.driver.port") !== "9999") + } + + // This tests whether the system can recover from a master failure with simple // non-stateful operations. This assumes as reliable, replayable input // source - TestInputDStream. test("recovery with map and reduceByKey operations") { @@ -348,6 +393,30 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } + test("recovery maintains rate controller") { + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir) + + val dstream = new RateTestInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200))) + } + + val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) + output.register() + runStreams(ssc, 5, 5) + + ssc = new StreamingContext(checkpointDir) + ssc.start() + val outputNew = advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(10.seconds)) { + assert(RateTestReceiver.getActive().nonEmpty) + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200) + } + ssc.stop() + } + // This tests whether file input stream remembers what files were seen before // the master failure and uses them again to process a large window operation. 
// It also tests whether batches, whose processing was incomplete due to the @@ -424,11 +493,11 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } + clock.advance(batchDuration.milliseconds) // Wait for a checkpoint to be written eventually(eventuallyTimeout) { assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) @@ -454,9 +523,12 @@ class CheckpointSuite extends TestSuiteBase { // recorded before failure were saved and successfully recovered logInfo("*********** RESTARTING ************") withStreamingContext(new StreamingContext(checkpointDir)) { ssc => - // So that the restarted StreamingContext's clock has gone forward in time since failure - ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.getTimeMillis() + // "batchDuration.milliseconds * 3" has gone before restarting StreamingContext. And because + // the recovery time is read from the checkpoint time but the original clock doesn't align + // with the batch time, we need to add the offset "batchDuration.milliseconds / 2". + ssc.conf.set("spark.streaming.manualClock.jump", + (batchDuration.milliseconds / 2 + batchDuration.milliseconds * 3).toString) + val oldClockTime = clock.getTimeMillis() // 15000ms clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -467,10 +539,10 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.getTimeMillis() === oldClockTime + assert(clock.getTimeMillis() === oldClockTime) } - // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) - val numBatchesAfterRestart = 4 + // There are 5 batches between 6000ms and 15000ms (inclusive). + val numBatchesAfterRestart = 5 eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === numBatchesAfterRestart) } @@ -483,7 +555,6 @@ class CheckpointSuite extends TestSuiteBase { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.advance(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 0c4c06534a69..e82c2fa4e72a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -17,25 +17,32 @@ package org.apache.spark.streaming -import org.apache.spark.Logging +import java.io.File + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkFunSuite, Logging} import org.apache.spark.util.Utils /** * This testsuite tests master failures at random times while the stream is running using * the real clock. 
*/ -class FailureSuite extends TestSuiteBase with Logging { +class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val directory = Utils.createTempDir() - val numBatches = 30 + private val batchDuration: Duration = Milliseconds(1000) + private val numBatches = 30 + private var directory: File = null - override def batchDuration: Duration = Milliseconds(1000) - - override def useManualClock: Boolean = false + before { + directory = Utils.createTempDir() + } - override def afterFunction() { - Utils.deleteRecursively(directory) - super.afterFunction() + after { + if (directory != null) { + Utils.deleteRecursively(directory) + } + StreamingContext.getActive().foreach { _.stop() } } test("multiple failures with map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index b74d67c63a78..047e38ef9099 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -76,6 +76,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { fail("Timeout: cannot finish all batches in 30 seconds") } + // Ensure progress listener has been notified of all events + ssc.scheduler.listenerBus.waitUntilEmpty(500) + // Verify all "InputInfo"s have been reported assert(ssc.progressListener.numTotalReceivedRecords === input.size) assert(ssc.progressListener.numTotalProcessedRecords === input.size) @@ -325,27 +328,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("test track the number of input stream") { - val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - class TestInputDStream extends InputDStream[String](ssc) { - def start() { } - def stop() { } - def compute(validTime: Time): Option[RDD[String]] = None - } + class TestInputDStream extends InputDStream[String](ssc) { + def start() {} - class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { - def getReceiver: Receiver[String] = null - } + def stop() {} - // Register input streams - val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) - val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + def compute(validTime: Time): Option[RDD[String]] = None + } - assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) - assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) - assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) - assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) - assert(receiverInputStreams.map(_.id) === Array(0, 1)) + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } + + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + + assert(ssc.graph.getInputStreams().length == + receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + 
assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } } def testFileStream(newFilesOnly: Boolean) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index e0f14fd95428..0e64b57e0ffd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -43,6 +43,7 @@ object MasterFailureTest extends Logging { @volatile var setupCalled = false def main(args: Array[String]) { + // scalastyle:off println if (args.size < 2) { println( "Usage: MasterFailureTest <# batches> " + @@ -60,6 +61,7 @@ object MasterFailureTest extends Logging { testUpdateStateByKey(directory, numBatches, batchDuration) println("\n\nSUCCESS\n\n") + // scalastyle:on println } def testMap(directory: String, numBatches: Int, batchDuration: Duration) { @@ -242,7 +244,13 @@ object MasterFailureTest extends Logging { } catch { case e: Exception => logError("Error running streaming context", e) } - if (killingThread.isAlive) killingThread.interrupt() + if (killingThread.isAlive) { + killingThread.interrupt() + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is + // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env + // to null after the next test creates the new SparkContext and fail the test. + killingThread.join() + } ssc.stop() logInfo("Has been killed = " + killed) @@ -291,10 +299,12 @@ object MasterFailureTest extends Logging { } // Log the output + // scalastyle:off println println("Expected output, size = " + expectedOutput.size) println(expectedOutput.mkString("[", ",", "]")) println("Output, size = " + output.size) println(output.mkString("[", ",", "]")) + // scalastyle:on println // Match the output with the expected output output.foreach(o => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 6c0c926755c2..13cfe29d7b30 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -47,7 +47,9 @@ class ReceivedBlockHandlerSuite with Matchers with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + val conf = new SparkConf() + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + .set("spark.app.id", "streaming-test") val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) @@ -184,7 +186,7 @@ class ReceivedBlockHandlerSuite } test("Test Block - isFullyConsumed") { - val sparkConf = new SparkConf() + val sparkConf = new SparkConf().set("spark.app.id", "streaming-test") sparkConf.set("spark.storage.unrollMemoryThreshold", "512") // spark.storage.unrollFraction set to 0.4 for BlockManager sparkConf.set("spark.storage.unrollFraction", "0.4") @@ -251,7 
+253,7 @@ class ReceivedBlockHandlerSuite
       maxMem: Long,
       conf: SparkConf,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
-    val transfer = new NioBlockTransferService(conf, securityMgr)
+    val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
     val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf,
       mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
     manager.initialize("app-id")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala
new file mode 100644
index 000000000000..6d388d9624d9
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.rdd.BlockRDD
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
+import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult}
+import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
+import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils}
+import org.apache.spark.{SparkConf, SparkEnv}
+
+class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll {
+
+  override def afterAll(): Unit = {
+    StreamingContext.getActive().map { _.stop() }
+  }
+
+  testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream =>
+    val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty)
+    assert(rdd.isInstanceOf[BlockRDD[_]])
+    assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]])
+    assert(rdd.isEmpty())
+  }
+
+  testWithoutWAL("createBlockRDD creates correct BlockRDD with block info") { receiverStream =>
+    val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = false) }
+    val blockIds = blockInfos.map(_.blockId)
+
+    // Verify that all of the blocks created above are present in the BlockManager
+    require(blockIds.forall(blockId => SparkEnv.get.blockManager.master.contains(blockId)))
+
+    val rdd = receiverStream.createBlockRDD(Time(0), blockInfos)
+    assert(rdd.isInstanceOf[BlockRDD[_]])
+    assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]])
+    val blockRDD = rdd.asInstanceOf[BlockRDD[_]]
+    assert(blockRDD.blockIds.toSeq === blockIds)
+  }
+
+  testWithoutWAL("createBlockRDD filters non-existent blocks before creating BlockRDD") {
+    receiverStream =>
+      val presentBlockInfos =
Seq.fill(2)(createBlockInfo(withWALInfo = false, createBlock = true)) + val absentBlockInfos = Seq.fill(3)(createBlockInfo(withWALInfo = false, createBlock = false)) + val blockInfos = presentBlockInfos ++ absentBlockInfos + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.exists(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + require(blockIds.exists(blockId => !SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === presentBlockInfos.map { _.blockId}) + } + + testWithWAL("createBlockRDD creates empty WALBackedBlockRDD when no block info") { + receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithWAL( + "createBlockRDD creates correct WALBackedBlockRDD with all block info having WAL info") { + receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = true) } + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[WriteAheadLogBackedBlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + assert(blockRDD.walRecordHandles.toSeq === blockInfos.map { _.walRecordHandleOption.get }) + } + + testWithWAL("createBlockRDD creates BlockRDD when some block info dont have WAL info") { + receiverStream => + val blockInfos1 = Seq.fill(2) { createBlockInfo(withWALInfo = true) } + val blockInfos2 = Seq.fill(3) { createBlockInfo(withWALInfo = false) } + val blockInfos = blockInfos1 ++ blockInfos2 + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + + private def testWithoutWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"Without WAL enabled: $msg") { + runTest(enableWAL = false, body) + } + } + + private def testWithWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"With WAL enabled: $msg") { + runTest(enableWAL = true, body) + } + } + + private def runTest(enableWAL: Boolean, body: ReceiverInputDStream[_] => Unit): Unit = { + val conf = new SparkConf() + conf.setMaster("local[4]").setAppName("ReceiverInputDStreamSuite") + conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString) + require(WriteAheadLogUtils.enableReceiverLog(conf) === enableWAL) + val ssc = new StreamingContext(conf, Seconds(1)) + val receiverStream = new ReceiverInputDStream[Int](ssc) { + override def getReceiver(): Receiver[Int] = null + } + withStreamingContext(ssc) { ssc => + body(receiverStream) + } + } + + /** + * Create a block info for input to the ReceiverInputDStream.createBlockRDD + * @param withWALInfo Create block with WAL info in it + * @param createBlock Actually create the block in the BlockManager + * @return + */ + private def createBlockInfo( + withWALInfo: Boolean, + createBlock: Boolean = true): ReceivedBlockInfo = { + val blockId = new StreamBlockId(0, Random.nextLong()) + if (createBlock) { + SparkEnv.get.blockManager.putSingle(blockId, 1, StorageLevel.MEMORY_ONLY, tellMaster 
= true) + require(SparkEnv.get.blockManager.master.contains(blockId)) + } + val storeResult = if (withWALInfo) { + new WriteAheadLogBasedStoreResult(blockId, None, new WriteAheadLogRecordHandle { }) + } else { + new BlockManagerBasedStoreResult(blockId, None) + } + new ReceivedBlockInfo(0, None, None, storeResult) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 5d7127627eea..01279b34f73d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -129,32 +129,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } } - test("block generator") { - val blockGeneratorListener = new FakeBlockGeneratorListener - val blockIntervalMs = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") - val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) - val expectedBlocks = 5 - val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2) - val generatedData = new ArrayBuffer[Int] - - // Generate blocks - val startTime = System.currentTimeMillis() - blockGenerator.start() - var count = 0 - while(System.currentTimeMillis - startTime < waitTime) { - blockGenerator.addData(count) - generatedData += count - count += 1 - Thread.sleep(10) - } - blockGenerator.stop() - - val recordedData = blockGeneratorListener.arrayBuffers.flatten - assert(blockGeneratorListener.arrayBuffers.size > 0) - assert(recordedData.toSet === generatedData.toSet) - } - ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 @@ -346,6 +320,13 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def reportError(message: String, throwable: Throwable) { errors += throwable } + + override protected def onReceiverStart(): Boolean = true + + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + null + } } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 819dd2ccfe91..d26894e88fc2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -20,18 +20,23 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue + import org.apache.commons.io.FileUtils +import org.scalatest.{Assertions, BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.scalatest.{Assertions, BeforeAndAfter} +import org.apache.spark._ +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite} class 
StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { @@ -110,6 +115,15 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } + test("checkPoint from conf") { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.checkpointDir != null) + } + test("state matching") { import StreamingContextState._ assert(INITIALIZED === INITIALIZED) @@ -247,7 +261,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo for (i <- 1 to 4) { logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) - var runningCount = 0 + @volatile var runningCount = 0 TestReceiver.counter.set(1) val input = ssc.receiverStream(new TestReceiver) input.count().foreachRDD { rdd => @@ -256,14 +270,14 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo logInfo("Count = " + count + ", Running count = " + runningCount) } ssc.start() - ssc.awaitTerminationOrTimeout(500) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(runningCount > 0) + } ssc.stop(stopSparkContext = false, stopGracefully = true) logInfo("Running count = " + runningCount) logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) - assert(runningCount > 0) assert( - (TestReceiver.counter.get() == runningCount + 1) || - (TestReceiver.counter.get() == runningCount + 2), + TestReceiver.counter.get() == runningCount + 1, "Received records = " + TestReceiver.counter.get() + ", " + "processed records = " + runningCount ) @@ -271,6 +285,21 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } } + test("stop gracefully even if a receiver misses StopReceiver") { + // This is not a deterministic unit. But if this unit test is flaky, then there is definitely + // something wrong. 
See SPARK-5681 + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + ssc = new StreamingContext(sc, Milliseconds(100)) + val input = ssc.receiverStream(new TestReceiver) + input.foreachRDD(_ => {}) + ssc.start() + // Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver" + failAfter(30000 millis) { + ssc.stop(stopSparkContext = true, stopGracefully = true) + } + } + test("stop slow receiver gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.streaming.gracefulStopTimeout", "20000s") @@ -297,6 +326,25 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo Thread.sleep(100) } + test ("registering and de-registering of streamingSource") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + ssc = new StreamingContext(conf, batchDuration) + assert(ssc.getState() === StreamingContextState.INITIALIZED) + addInputStream(ssc).register() + ssc.start() + + val sources = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSource = StreamingContextSuite.getStreamingSource(ssc) + assert(sources.contains(streamingSource)) + assert(ssc.getState() === StreamingContextState.ACTIVE) + + ssc.stop() + val sourcesAfterStop = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSourceAfterStop = StreamingContextSuite.getStreamingSource(ssc) + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(!sourcesAfterStop.contains(streamingSourceAfterStop)) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -321,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + var t: Thread = null // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() ssc.awaitTermination() } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() } test("awaitTermination after stop") { @@ -382,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.awaitTerminationOrTimeout(500) === false) } + var t: Thread = null // test whether awaitTerminationOrTimeout() return true if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() assert(ssc.awaitTerminationOrTimeout(10000) === true) } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. 
+ t.join() } test("getOrCreate") { @@ -665,6 +725,29 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo transformed.foreachRDD { rdd => rdd.collect() } } } + test("queueStream doesn't support checkpointing") { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + def creatingFunction(): StreamingContext = { + val _ssc = new StreamingContext(conf, batchDuration) + val rdd = _ssc.sparkContext.parallelize(1 to 10) + _ssc.checkpoint(checkpointDirectory) + _ssc.queueStream[Int](Queue(rdd)).register() + _ssc + } + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + val e = intercept[SparkException] { + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) + } + // StreamingContext.validate changes the message, so use "contains" here + assert(e.getCause.getMessage.contains("queueStream doesn't support checkpointing. " + + "Please don't use queueStream when checkpointing is enabled.")) + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) @@ -716,7 +799,8 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging } def onStop() { - // no clean to be done, the receiving thread should stop on it own + // no clean to be done, the receiving thread should stop on it own, so just wait for it. + receivingThreadOption.foreach(_.join()) } } @@ -796,3 +880,18 @@ package object testPackage extends Assertions { } } } + +/** + * Helper methods for testing StreamingContextSuite + * This includes methods to access private methods and fields in StreamingContext and MetricsSystem + */ +private object StreamingContextSuite extends PrivateMethodTester { + private val _sources = PrivateMethod[ArrayBuffer[Source]]('sources) + private def getSources(metricsSystem: MetricsSystem): ArrayBuffer[Source] = { + metricsSystem.invokePrivate(_sources()) + } + private val _streamingSource = PrivateMethod[StreamingSource]('streamingSource) + private def getStreamingSource(streamingContext: StreamingContext): StreamingSource = { + streamingContext.invokePrivate(_streamingSource()) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 1dc8960d6052..d840c349bbbc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // To make sure that the processing start and end times in collected // information are different for successive batches override def batchDuration: Duration = Milliseconds(100) override def actuallyWait: Boolean = true test("batch info reporting") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) @@ -59,7 +68,7 @@ class StreamingListenerSuite extends 
TestSuiteBase with Matchers { batchInfosSubmitted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) @@ -77,7 +86,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosStarted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) @@ -98,7 +107,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosCompleted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) @@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } test("receiver info reporting") { - val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) inputStream.foreachRDD(_.count) @@ -116,7 +125,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 31b1aebf6a8e..0d58a7b54412 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -76,7 +76,7 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], } // Report the input data's information to InputInfoTracker for testing - val inputInfo = InputInfo(id, selectedInput.length.toLong) + val inputInfo = StreamInputInfo(id, selectedInput.length.toLong) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index a08578680cff..068a6cb0e8fa 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -100,8 +100,8 @@ class UISeleniumSuite // Check stat table val statTableHeaders = findAll(cssSelector("#stat-table th")).map(_.text).toSeq statTableHeaders.exists( - _.matches("Timelines \\(Last \\d+ batches, \\d+ active, \\d+ completed\\)")) should be - (true) + _.matches("Timelines \\(Last \\d+ batches, \\d+ active, \\d+ completed\\)") + ) should be (true) statTableHeaders should contain ("Histograms") val statTableCells = findAll(cssSelector("#stat-table td")).map(_.text).toSeq diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala new file mode 100644 index 000000000000..a38cc603f219 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import scala.collection.mutable + +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.util.ManualClock +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} + +class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { + + private val blockIntervalMs = 10 + private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") + @volatile private var blockGenerator: BlockGenerator = null + + after { + if (blockGenerator != null) { + blockGenerator.stop() + } + } + + test("block generation and data callbacks") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + + require(blockIntervalMs > 5) + require(listener.onAddDataCalled === false) + require(listener.onGenerateBlockCalled === false) + require(listener.onPushBlockCalled === false) + + // Verify that creating the generator does not start it + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + assert(blockGenerator.isActive() === false, "block generator active before start()") + assert(blockGenerator.isStopped() === false, "block generator stopped before start()") + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + + // Verify start marks the generator active, but does not call the callbacks + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator active after start()") + assert(blockGenerator.isStopped() === false, "block generator stopped after start()") + withClue("callbacks called before adding data") { + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + + // Verify whether addData() adds data that is present in generated blocks + val data1 = 1 to 10 + data1.foreach { blockGenerator.addData _ } + withClue("callbacks called on adding data without metadata and without block generation") { + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + 
clock.advance(blockIntervalMs) // advance clock to generate blocks + withClue("blocks not generated or pushed") { + eventually(timeout(1 second)) { + assert(listener.onGenerateBlockCalled === true) + assert(listener.onPushBlockCalled === true) + } + } + listener.pushedData should contain theSameElementsInOrderAs (data1) + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + + // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly + val data2 = 11 to 20 + val metadata2 = data2.map { _.toString } + data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } + assert(listener.onAddDataCalled === true) + listener.addedData should contain theSameElementsInOrderAs (data2) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2) + } + + // Verify addMultipleDataWithCallback() add data+metadata and and callbacks are called correctly + val data3 = 21 to 30 + val metadata3 = "metadata" + blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3) + } + + // Stop the block generator by starting the stop on a different thread and + // then advancing the manual clock for the stopping to proceed. + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + clock.advance(blockIntervalMs) + assert(blockGenerator.isStopped() === true) + } + thread.join() + + // Verify that the generator cannot be used any more + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, 1) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), 1) + } + intercept[SparkException] { + blockGenerator.start() + } + blockGenerator.stop() // Calling stop again should be fine + } + + test("stop ensures correct shutdown") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + require(listener.onGenerateBlockCalled === false) + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator") + assert(blockGenerator.isStopped() === false) + + val data = 1 to 1000 + data.foreach { blockGenerator.addData _ } + + // Verify that stop() shutdowns everything in the right order + // - First, stop receiving new data + // - Second, wait for final block with all buffered data to be generated + // - Finally, wait for all blocks to be pushed + clock.advance(1) // to make sure that the timer for another interval to complete + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(blockGenerator.isActive() === false) + } + assert(blockGenerator.isStopped() === false) + + // Verify that data cannot be added + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, null) + } + intercept[SparkException] { + 
blockGenerator.addMultipleDataWithCallback(Iterator(1), null) + } + + // Verify that stop() stays blocked until another block containing all the data is generated + // This intercept always succeeds, as the body either will either throw a timeout exception + // (expected as stop() should never complete) or a SparkException (unexpected as stop() + // completed and thread terminated). + val exception = intercept[Exception] { + failAfter(200 milliseconds) { + thread.join() + throw new SparkException( + "BlockGenerator.stop() completed before generating timer was stopped") + } + } + exception should not be a [SparkException] + + + // Verify that the final data is present in the final generated block and + // pushed before complete stop + assert(blockGenerator.isStopped() === false) // generator has not stopped yet + clock.advance(blockIntervalMs) // force block generation + failAfter(1 second) { + thread.join() + } + assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped + assert(listener.pushedData === data, "All data not pushed by stop()") + } + + test("block push errors are reported") { + val listener = new TestBlockGeneratorListener { + @volatile var errorReported = false + override def onPushBlock( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + throw new SparkException("test") + } + override def onError(message: String, throwable: Throwable): Unit = { + errorReported = true + } + } + blockGenerator = new BlockGenerator(listener, 0, conf) + blockGenerator.start() + assert(listener.errorReported === false) + blockGenerator.addData(1) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(listener.errorReported === true) + } + blockGenerator.stop() + } + + /** + * Helper method to stop the block generator with manual clock in a different thread, + * so that the main thread can advance the clock that allows the stopping to proceed. 
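+   *
+   * Illustrative usage only (a sketch assumed from the tests above, not part of the patch;
+   * `stopperThread` is a hypothetical local name):
+   * {{{
+   *   val stopperThread = stopBlockGenerator(blockGenerator) // stop() waits on the generator's timer
+   *   clock.advance(blockIntervalMs)                         // let the final block be generated
+   *   stopperThread.join()                                   // now stop() can complete
+   * }}}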
+ */ + private def stopBlockGenerator(blockGenerator: BlockGenerator): Thread = { + val thread = new Thread() { + override def run(): Unit = { + blockGenerator.stop() + } + } + thread.start() + thread + } + + /** A listener for BlockGenerator that records the data in the callbacks */ + private class TestBlockGeneratorListener extends BlockGeneratorListener { + val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + @volatile var onGenerateBlockCalled = false + @volatile var onAddDataCalled = false + @volatile var onPushBlockCalled = false + + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + pushedData ++= arrayBuffer + onPushBlockCalled = true + } + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = { + onGenerateBlockCalled = true + } + override def onAddData(data: Any, metadata: Any): Unit = { + addedData += data + addedMetadata += metadata + onAddDataCalled = true + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala new file mode 100644 index 000000000000..c6330eb3673f --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.streaming.receiver
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkFunSuite
+
+/** Testsuite for testing the RateLimiter behavior */
+class RateLimiterSuite extends SparkFunSuite {
+
+  test("rate limiter initializes even without a maxRate set") {
+    val conf = new SparkConf()
+    val rateLimiter = new RateLimiter(conf){}
+    rateLimiter.updateRate(105)
+    assert(rateLimiter.getCurrentLimit == 105)
+  }
+
+  test("rate limiter updates when below maxRate") {
+    val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110")
+    val rateLimiter = new RateLimiter(conf){}
+    rateLimiter.updateRate(105)
+    assert(rateLimiter.getCurrentLimit == 105)
+  }
+
+  test("rate limiter stays below maxRate despite large updates") {
+    val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100")
+    val rateLimiter = new RateLimiter(conf){}
+    rateLimiter.updateRate(105)
+    assert(rateLimiter.getCurrentLimit === 100)
+  }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
index 2e210397fe7c..f5248acf712b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
@@ -46,8 +46,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {
     val streamId1 = 0
     val streamId2 = 1
     val time = Time(0L)
-    val inputInfo1 = InputInfo(streamId1, 100L)
-    val inputInfo2 = InputInfo(streamId2, 300L)
+    val inputInfo1 = StreamInputInfo(streamId1, 100L)
+    val inputInfo2 = StreamInputInfo(streamId2, 300L)
     inputInfoTracker.reportInfo(time, inputInfo1)
     inputInfoTracker.reportInfo(time, inputInfo2)
@@ -63,8 +63,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {
     val inputInfoTracker = new InputInfoTracker(ssc)
     val streamId1 = 0
-    val inputInfo1 = InputInfo(streamId1, 100L)
-    val inputInfo2 = InputInfo(streamId1, 300L)
+    val inputInfo1 = StreamInputInfo(streamId1, 100L)
+    val inputInfo2 = StreamInputInfo(streamId1, 300L)
     inputInfoTracker.reportInfo(Time(0), inputInfo1)
     inputInfoTracker.reportInfo(Time(1), inputInfo2)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
index 7865b06c2e3c..9b6cd4bc4e31 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
@@ -56,7 +56,8 @@ class JobGeneratorSuite extends TestSuiteBase {
   // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata)
   // 5.
verify whether 3rd batch's block metadata still exists // - test("SPARK-6222: Do not clear received block data too soon") { + // TODO: SPARK-7420 enable this test + ignore("SPARK-6222: Do not clear received block data too soon") { import JobGeneratorSuite._ val checkpointDir = Utils.createTempDir() val testConf = conf @@ -76,7 +77,6 @@ class JobGeneratorSuite extends TestSuiteBase { if (time.milliseconds == longBatchTime) { while (waitLatch.getCount() > 0) { waitLatch.await() - println("Await over") } } }) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala new file mode 100644 index 000000000000..1eb52b7029a2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +class RateControllerSuite extends TestSuiteBase { + + override def useManualClock: Boolean = false + + override def batchDuration: Duration = Milliseconds(50) + + test("RateController - rate controller publishes updates after batches complete") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateTestInputDStream(ssc) + dstream.register() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.publishedRates > 0) + } + } + } + + test("ReceiverRateController - published rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val estimator = new ConstantEstimator(100) + val dstream = new RateTestInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, estimator)) + } + dstream.register() + ssc.start() + + // Wait for receiver to start + eventually(timeout(5.seconds)) { + RateTestReceiver.getActive().nonEmpty + } + + // Update rate in the estimator and verify whether the rate was published to the receiver + def updateRateAndVerify(rate: Long): Unit = { + estimator.updateRate(rate) + eventually(timeout(5.seconds)) { + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === rate) + } + } + + // Verify multiple rate update + Seq(100, 200, 300).foreach { rate => + updateRateAndVerify(rate) + } + } + } +} + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } 
+ + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala new file mode 100644 index 000000000000..b2a51d72bac2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite + +class ReceiverSchedulingPolicySuite extends SparkFunSuite { + + val receiverSchedulingPolicy = new ReceiverSchedulingPolicy + + test("rescheduleReceiver: empty executors") { + val scheduledExecutors = + receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty) + assert(scheduledExecutors === Seq.empty) + } + + test("rescheduleReceiver: receiver preferredLocation") { + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2")) + assert(scheduledExecutors.toSet === Set("host1", "host2")) + } + + test("rescheduleReceiver: return all idle executors if there are any idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // host3 is idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1"))) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 1, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) + } + + test("rescheduleReceiver: return all executors that have minimum weight if no idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0, host4 = 0.5, host5 = 0.5 + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), + 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), + 2 -> ReceiverTrackingInfo(2, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None), + 3 -> ReceiverTrackingInfo(4, ReceiverState.SCHEDULED, Some(Seq("host4", "host5")), None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 4, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) + } + + test("scheduleReceivers: " + + "schedule 
receivers evenly when there are more receivers than executors") { + val receivers = (0 until 6).map(new RateTestReceiver(_)) + val executors = (10000 until 10003).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 2 receivers running on each executor and each receiver has one executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1 + } + assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap) + } + + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more executors than receivers") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) + val executors = (10000 until 10006).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has two executors + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 2) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + } + + test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) ++ + (3 until 6).map(new RateTestReceiver(_, Some("localhost"))) + val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ + (10003 until 10006).map(port => s"localhost2:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has 1 executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + // Make sure we schedule the receivers to their preferredLocations + val executorsForReceiversWithPreferredLocation = + scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) + // We can simply check the executor set because we only know each receiver only has 1 executor + assert(executorsForReceiversWithPreferredLocation.toSet === + (10000 until 10003).map(port => s"localhost:${port}").toSet) + } + + test("scheduleReceivers: return empty if no receiver") { + assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty) + } + + test("scheduleReceivers: return empty scheduled executors if no executors") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.isEmpty) + } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala new file mode 100644 index 000000000000..45138b748eca --- /dev/null +++ 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver._ + +/** Testsuite for receiver scheduling */ +class ReceiverTrackerSuite extends TestSuiteBase { + + test("send rate update to receivers") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + ssc.scheduler.listenerBus.start(ssc.sc) + + val newRateLimit = 100L + val inputDStream = new RateTestInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + try { + // we wait until the Receiver has registered with the tracker, + // otherwise our rate update is lost + eventually(timeout(5 seconds)) { + assert(RateTestReceiver.getActive().nonEmpty) + } + + + // Verify that the rate of the block generator in the receiver get updated + val activeReceiver = RateTestReceiver.getActive().get + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + eventually(timeout(5 seconds)) { + assert(activeReceiver.getDefaultBlockGeneratorRateLimit() === newRateLimit, + "default block generator did not receive rate update") + assert(activeReceiver.getCustomBlockGeneratorRateLimit() === newRateLimit, + "other block generator did not receive rate update") + } + } finally { + tracker.stop(false) + } + } + } + + test("should restart receiver after stopping it") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + @volatile var startTimes = 0 + ssc.addStreamingListener(new StreamingListener { + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + startTimes += 1 + } + }) + val input = ssc.receiverStream(new StoppableReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + StoppableReceiver.shouldStop = true + eventually(timeout(10 seconds), interval(10 millis)) { + // The receiver is stopped once, so if it's restarted, it should be started twice. 
+        assert(startTimes === 2)
+      }
+    }
+  }
+}
+
+/** An input DStream for testing rate controlling */
+private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext)
+  extends ReceiverInputDStream[Int](ssc_) {
+
+  override def getReceiver(): Receiver[Int] = new RateTestReceiver(id)
+
+  @volatile
+  var publishedRates = 0
+
+  override val rateController: Option[RateController] = {
+    Some(new RateController(id, new ConstantEstimator(100)) {
+      override def publish(rate: Long): Unit = {
+        publishedRates += 1
+      }
+    })
+  }
+}
+
+/** A receiver implementation for testing rate controlling */
+private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
+  extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+
+  private lazy val customBlockGenerator = supervisor.createBlockGenerator(
+    new BlockGeneratorListener {
+      override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {}
+      override def onError(message: String, throwable: Throwable): Unit = {}
+      override def onGenerateBlock(blockId: StreamBlockId): Unit = {}
+      override def onAddData(data: Any, metadata: Any): Unit = {}
+    }
+  )
+
+  setReceiverId(receiverId)
+
+  override def onStart(): Unit = {
+    customBlockGenerator
+    RateTestReceiver.registerReceiver(this)
+  }
+
+  override def onStop(): Unit = {
+    RateTestReceiver.deregisterReceiver()
+  }
+
+  override def preferredLocation: Option[String] = host
+
+  def getDefaultBlockGeneratorRateLimit(): Long = {
+    supervisor.getCurrentRateLimit
+  }
+
+  def getCustomBlockGeneratorRateLimit(): Long = {
+    customBlockGenerator.getCurrentLimit
+  }
+}
+
+/**
+ * A helper object for RateTestReceiver that gives access to the currently active
+ * RateTestReceiver instance.
+ */
+private[streaming] object RateTestReceiver {
+  @volatile private var activeReceiver: RateTestReceiver = null
+
+  def registerReceiver(receiver: RateTestReceiver): Unit = {
+    activeReceiver = receiver
+  }
+
+  def deregisterReceiver(): Unit = {
+    activeReceiver = null
+  }
+
+  def getActive(): Option[RateTestReceiver] = Option(activeReceiver)
+}
+
+/**
+ * A custom receiver that can be stopped via StoppableReceiver.shouldStop
+ */
+class StoppableReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+
+  var receivingThreadOption: Option[Thread] = None
+
+  def onStart() {
+    val thread = new Thread() {
+      override def run() {
+        while (!StoppableReceiver.shouldStop) {
+          Thread.sleep(10)
+        }
+        StoppableReceiver.this.stop("stop")
+      }
+    }
+    thread.start()
+  }
+
+  def onStop() {
+    StoppableReceiver.shouldStop = true
+    receivingThreadOption.foreach(_.join())
+    // Reset it so as to restart it
+    StoppableReceiver.shouldStop = false
+  }
+}
+
+object StoppableReceiver {
+  @volatile var shouldStop = false
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
new file mode 100644
index 000000000000..a1af95be81c8
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import scala.util.Random + +import org.scalatest.Inspectors.forAll +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.Seconds + +class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { + + test("the right estimator is created") { + val conf = new SparkConf + conf.set("spark.streaming.backpressure.rateEstimator", "pid") + val pid = RateEstimator.create(conf, Seconds(1)) + pid.getClass should equal(classOf[PIDRateEstimator]) + } + + test("estimator checks ranges") { + intercept[IllegalArgumentException] { + new PIDRateEstimator(batchIntervalMillis = 0, 1, 2, 3, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, proportional = -1, 2, 3, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, integral = -1, 3, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, derivative = -1, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = 0) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = -10) + } + } + + test("first estimate is None") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) should equal(None) + } + + test("second estimate is not None") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + // 1000 elements / s + p.compute(10, 10, 10, 0) should equal(Some(1000)) + } + + test("no estimate when no time difference between successive calls") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(time = 10, 10, 10, 0) shouldNot equal(None) + p.compute(time = 10, 10, 10, 0) should equal(None) + } + + test("no estimate when no records in previous batch") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, numElements = 0, 10, 0) should equal(None) + p.compute(20, numElements = -10, 10, 0) should equal(None) + } + + test("no estimate when there is no processing delay") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, 10, processingDelay = 0, 0) should equal(None) + p.compute(20, 10, processingDelay = -10, 0) should equal(None) + } + + test("estimate is never less than min rate") { + val minRate = 5D + val p = new PIDRateEstimator(20, 1D, 1D, 0D, minRate) + // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing + // this might point the estimator to try and decrease the bound, but we test it never + // goes below the min rate, which would be nonsensical. 
+ val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.fill(50)(1) // no processing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(100) // strictly positive accumulation + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.fill(49)(Some(minRate))) + } + + test("with no accumulated or positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) + // prepare a series of batch updates, one every 20ms with an increasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => (x + 1) * 20) // increasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some((x + 1) * 1000D)).tail) + } + + test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) + // prepare a series of batch updates, one every 20ms with an decreasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term, + // asking for less and less elements + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail) + } + + test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { + val minRate = 10D + val p = new PIDRateEstimator(20, 1D, .01D, 0D, minRate) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val rng = new Random() + val elements = List.tabulate(50)(x => rng.nextInt(1000) + 1000) + val procDelayMs = 20 + val proc = List.fill(50)(procDelayMs) // 20ms of processing + val sched = List.tabulate(50)(x => rng.nextInt(19) + 1) // random wait + val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) + + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + forAll(List.range(1, 50)) { (n) => + res(n) should not be None + if (res(n).get > 0 && sched(n) > 0) { + res(n).get should be < speeds(n) + res(n).get should be >= minRate + } + } + } + + private def createDefaultEstimator(): PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D, 10) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index c9175d61b1f4..995f1197ccdf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ 
b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -22,15 +22,24 @@ import java.util.Properties import org.scalatest.Matchers import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + private def createJobStart( batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { val properties = new Properties() @@ -46,13 +55,15 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + "onReceiverStarted, onReceiverError, onReceiverStopped") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -64,7 +75,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -94,7 +105,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoStarted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoStarted.processingDelay) batchUIData.get.totalDelay should be (batchInfoStarted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map(0 -> 300L, 1 -> 300L)) + batchUIData.get.streamIdToInputInfo should be (Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test")))) batchUIData.get.numRecords should be(600) batchUIData.get.outputOpIdSparkJobIdPairs should be Seq(OutputOpIdAndSparkJobId(0, 0), @@ -103,7 +116,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches 
should be (Nil) @@ -115,20 +128,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) @@ -137,13 +150,13 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("Remove the old completed batches when exceeding the limit") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -154,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("out-of-order onJobStart and onBatchXXX") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -182,7 +195,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoSubmitted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoSubmitted.processingDelay) batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map.empty) + batchUIData.get.streamIdToInputInfo should be (Map.empty) batchUIData.get.numRecords should be (0) batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) @@ -205,20 +218,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("detect memory leak") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) for (_ <- 0 until 2 * limit) { - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> 
StreamInputInfo(1, 300L)) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -235,7 +248,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 325ff7c74c39..5e49fd00769a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -20,6 +20,7 @@ import java.io._ import java.nio.ByteBuffer import java.util +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} @@ -417,9 +418,8 @@ object WriteAheadLogSuite { /** Read all the data in the log file in a directory using the WriteAheadLog class. */ def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - import scala.collection.JavaConversions._ val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - val data = wal.readAll().map(byteBufferToString).toSeq + val data = wal.readAll().asScala.map(byteBufferToString).toSeq wal.close() data } diff --git a/tools/pom.xml b/tools/pom.xml index feffde4c857e..1e64f280e5be 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -76,10 +76,6 @@ org.apache.maven.plugins maven-source-plugin - - org.codehaus.mojo - build-helper-maven-plugin - diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 595ded6ae67f..a0524cabff2d 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -15,13 +15,14 @@ * limitations under the License. 
*/ +// scalastyle:off classforname package org.apache.spark.tools import java.io.File import java.util.jar.JarFile import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} import scala.util.Try @@ -92,7 +93,9 @@ object GenerateMIMAIgnore { ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { + // scalastyle:off println case _: Throwable => println("Error instrumenting class:" + className) + // scalastyle:on println } } (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) @@ -108,7 +111,9 @@ object GenerateMIMAIgnore { .filter(_.contains("$$")).map(classSymbol.fullName + "." + _) } catch { case t: Throwable => + // scalastyle:off println println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + // scalastyle:on println Seq.empty[String] } } @@ -128,12 +133,14 @@ object GenerateMIMAIgnore { getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-class-excludes") .writeAll(previousContents + privateClasses.mkString("\n")) + // scalastyle:off println println("Created : .generated-mima-class-excludes in current directory.") val previousMembersContents = Try(File(".generated-mima-member-excludes").lines) .getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-member-excludes").writeAll(previousMembersContents + privateMembers.mkString("\n")) println("Created : .generated-mima-member-excludes in current directory.") + // scalastyle:on println } @@ -154,7 +161,7 @@ object GenerateMIMAIgnore { val path = packageName.replace('.', '/') val resources = classLoader.getResources(path) - val jars = resources.filter(x => x.getProtocol == "jar") + val jars = resources.asScala.filter(_.getProtocol == "jar") .map(_.getFile.split(":")(1).split("!")(0)).toSeq jars.flatMap(getClassesFromJar(_, path)) @@ -168,15 +175,18 @@ object GenerateMIMAIgnore { private def getClassesFromJar(jarPath: String, packageName: String) = { import scala.collection.mutable val jar = new JarFile(new File(jarPath)) - val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName)) + val enums = jar.entries().asScala.map(_.getName).filter(_.startsWith(packageName)) val classes = mutable.HashSet[Class[_]]() for (entry <- enums if entry.endsWith(".class")) { try { classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) } catch { + // scalastyle:off println case _: Throwable => println("Unable to load:" + entry) + // scalastyle:on println } } classes } } +// scalastyle:on classforname diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 583823c90c5c..856ea177a9a1 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -323,11 +323,14 @@ object JavaAPICompletenessChecker { val missingMethods = javaEquivalents -- javaMethods for (method <- missingMethods) { + // scalastyle:off println println(method) + // scalastyle:on println } } def main(args: Array[String]) { + // scalastyle:off println println("Missing RDD methods") printMissingMethods(classOf[RDD[_]], classOf[JavaRDD[_]]) println() @@ -359,5 +362,6 @@ object JavaAPICompletenessChecker { println("Missing PairDStream methods") 
printMissingMethods(classOf[PairDStreamFunctions[_, _]], classOf[JavaPairDStream[_, _]]) println() + // scalastyle:on println } } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index baa97616eaff..0dc2861253f1 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -85,7 +85,9 @@ object StoragePerfTester { latch.countDown() } catch { case e: Exception => + // scalastyle:off println println("Exception in child thread: " + e + " " + e.getMessage) + // scalastyle:on println System.exit(1) } } @@ -97,9 +99,11 @@ object StoragePerfTester { val bytesPerSecond = totalBytes.get() / time val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + // scalastyle:off println System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + // scalastyle:on println executor.shutdown() sc.stop() diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 62c6354f1e20..066abe92e51c 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -67,7 +67,17 @@ org.mockito - mockito-all + mockito-core + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.commons + commons-lang3 test @@ -80,7 +90,7 @@ net.alchim31.maven scala-maven-plugin - + -XDignore.symbol.file diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java new file mode 100644 index 000000000000..5c9d5d9a3831 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe; + +import java.io.IOException; + +public abstract class KVIterator { + + public abstract boolean next() throws IOException; + + public abstract K getKey(); + + public abstract V getValue(); + + public abstract void close(); +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 52% rename from unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java rename to unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 192c6714b240..1c16da982923 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -21,91 +21,107 @@ import sun.misc.Unsafe; -public final class PlatformDependent { +public final class Platform { - /** - * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us avoid accidental use of deprecated methods. - */ - public static final class UNSAFE { + private static final Unsafe _UNSAFE; - private UNSAFE() { } + public static final int BYTE_ARRAY_OFFSET; - public static int getInt(Object object, long offset) { - return _UNSAFE.getInt(object, offset); - } + public static final int INT_ARRAY_OFFSET; - public static void putInt(Object object, long offset, int value) { - _UNSAFE.putInt(object, offset, value); - } + public static final int LONG_ARRAY_OFFSET; - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); - } + public static final int DOUBLE_ARRAY_OFFSET; - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); - } + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } - public static byte getByte(Object object, long offset) { - return _UNSAFE.getByte(object, offset); - } + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); - } + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); - } + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); - } + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } - public static long getLong(Object object, long offset) { - return _UNSAFE.getLong(object, offset); - } + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); - } + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } - public static float getFloat(Object object, long offset) { - return _UNSAFE.getFloat(object, offset); - } + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } - public static void putFloat(Object object, long offset, float 
value) { - _UNSAFE.putFloat(object, offset, value); - } + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } - public static double getDouble(Object object, long offset) { - return _UNSAFE.getDouble(object, offset); - } + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); - } + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } - public static long allocateMemory(long size) { - return _UNSAFE.allocateMemory(size); - } + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } - public static void freeMemory(long address) { - _UNSAFE.freeMemory(address); - } + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); } - private static final Unsafe _UNSAFE; + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } - public static final int BYTE_ARRAY_OFFSET; + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } - public static final int INT_ARRAY_OFFSET; + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } - public static final int LONG_ARRAY_OFFSET; + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } - public static final int DOUBLE_ARRAY_OFFSET; + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to @@ -136,26 +152,4 @@ public static void freeMemory(long address) { DOUBLE_ARRAY_OFFSET = 0; } } - - static public void copyMemory( - Object src, - long srcOffset, - Object dst, - long dstOffset, - long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. 
- */ - public static void throwException(Throwable t) { - _UNSAFE.throwException(t); - } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 53eadf96a6b5..cf42877bf9fd 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class ByteArrayMethods { @@ -25,6 +25,12 @@ private ByteArrayMethods() { // Private constructor, since this class only contains static methods. } + /** Returns the next number greater or equal num that is power of 2. */ + public static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } + public static int roundNumberOfBytesToNearestWord(int numBytes) { int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { @@ -35,21 +41,25 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } /** - * Optimized byte array equality check for 8-byte-word-aligned byte arrays. + * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise */ - public static boolean wordAlignedArrayEquals( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset, - long arrayLengthInBytes) { - for (int i = 0; i < arrayLengthInBytes; i += 8) { - final long left = - PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i); - final long right = - PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i); - if (left != right) return false; + public static boolean arrayEquals( + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + int i = 0; + while (i <= length - 8) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { + return false; + } + i += 8; + } + while (i < length) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { + return false; + } + i += 1; } return true; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 18d1f0d2d7eb..74105050e419 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -64,7 +64,7 @@ public long size() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -73,6 +73,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * 
WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java index 28e23da108eb..7c124173b0bb 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -90,7 +90,7 @@ public boolean isSet(int index) { * To iterate over the true bits in a BitSet, use the following loop: *
        * 
    -   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
    +   *  for (long i = bs.nextSetBit(0); i &gt;= 0; i = bs.nextSetBit(i + 1)) {
        *    // operate on index i here
        *  }
        * 
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    index 0987191c1c63..7857bf66a72a 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    @@ -17,7 +17,7 @@
     
     package org.apache.spark.unsafe.bitset;
     
    -import org.apache.spark.unsafe.PlatformDependent;
    +import org.apache.spark.unsafe.Platform;
     
     /**
      * Methods for working with fixed-size uncompressed bitsets.
    @@ -41,8 +41,8 @@ public static void set(Object baseObject, long baseOffset, int index) {
         assert index >= 0 : "index (" + index + ") should >= 0";
         final long mask = 1L << (index & 0x3f);  // mod 64 and shift
         final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
    -    final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
    -    PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask);
    +    final long word = Platform.getLong(baseObject, wordOffset);
    +    Platform.putLong(baseObject, wordOffset, word | mask);
       }
     
       /**
    @@ -52,8 +52,8 @@ public static void unset(Object baseObject, long baseOffset, int index) {
         assert index >= 0 : "index (" + index + ") should >= 0";
         final long mask = 1L << (index & 0x3f);  // mod 64 and shift
         final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
    -    final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
    -    PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask);
    +    final long word = Platform.getLong(baseObject, wordOffset);
    +    Platform.putLong(baseObject, wordOffset, word & ~mask);
       }
     
       /**
    @@ -63,7 +63,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) {
         assert index >= 0 : "index (" + index + ") should >= 0";
         final long mask = 1L << (index & 0x3f);  // mod 64 and shift
         final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
    -    final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
    +    final long word = Platform.getLong(baseObject, wordOffset);
         return (word & mask) != 0;
       }
     
    @@ -73,7 +73,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) {
       public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) {
         long addr = baseOffset;
         for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) {
    -      if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) {
    +      if (Platform.getLong(baseObject, addr) != 0) {
             return true;
           }
         }
    @@ -87,7 +87,7 @@ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidt
        * To iterate over the true bits in a BitSet, use the following loop:
        * 
        * 
    -   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
    +   *  for (long i = bs.nextSetBit(0, sizeInWords); i &gt;= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
        *    // operate on index i here
        *  }
        * 
    @@ -109,8 +109,7 @@ public static int nextSetBit(
     
         // Try to find the next set bit in the current word
         final int subIndex = fromIndex & 0x3f;
    -    long word =
    -      PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex;
    +    long word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex;
         if (word != 0) {
           return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word);
         }
    @@ -118,7 +117,7 @@ public static int nextSetBit(
         // Find the next set bit in the rest of the words
         wi += 1;
         while (wi < bitsetSizeInWords) {
    -      word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE);
    +      word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE);
           if (word != 0) {
             return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word);
           }
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
    index 85cd02469adb..4276f25c2165 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
    @@ -17,7 +17,7 @@
     
     package org.apache.spark.unsafe.hash;
     
    -import org.apache.spark.unsafe.PlatformDependent;
    +import org.apache.spark.unsafe.Platform;
     
     /**
      * 32-bit Murmur3 hasher.  This is based on Guava's Murmur3_32HashFunction.
    @@ -44,12 +44,16 @@ public int hashInt(int input) {
         return fmix(h1, 4);
       }
     
    -  public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
    +  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
    +    return hashUnsafeWords(base, offset, lengthInBytes, seed);
    +  }
    +
    +  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
         // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
         assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
         int h1 = seed;
    -    for (int offset = 0; offset < lengthInBytes; offset += 4) {
    -      int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
    +    for (int i = 0; i < lengthInBytes; i += 4) {
    +      int halfWord = Platform.getInt(base, offset + i);
           int k1 = mixK1(halfWord);
           h1 = mixH1(h1, k1);
         }
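    The Murmur3 hunk above adds a static hashUnsafeWords overload that takes an explicit seed, so callers no longer need a Murmur3_x86_32 instance just to choose a seed; the assertion still requires the byte length to be a multiple of 8. A minimal usage sketch (not part of the patch; the seed value is arbitrary), using only the signatures shown in the hunk:

        import org.apache.spark.unsafe.Platform;
        import org.apache.spark.unsafe.hash.Murmur3_x86_32;

        public class HashUnsafeWordsSketch {
          public static void main(String[] args) {
            long[] words = {1L, 2L, 3L, 4L};   // 32 bytes, so lengthInBytes is a multiple of 8
            int lengthInBytes = words.length * 8;
            int seed = 42;                     // arbitrary seed for the sketch

            // New static form added in this patch: no hasher instance needed.
            int hash = Murmur3_x86_32.hashUnsafeWords(
                words, Platform.LONG_ARRAY_OFFSET, lengthInBytes, seed);
            System.out.println("murmur3 = " + hash);
          }
        }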
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
    index bbe83d36cf36..6722301df19d 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
    @@ -24,6 +24,9 @@ public class HeapMemoryAllocator implements MemoryAllocator {
     
       @Override
       public MemoryBlock allocate(long size) throws OutOfMemoryError {
    +    if (size % 8 != 0) {
    +      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
    +    }
         long[] array = new long[(int) (size / 8)];
         return MemoryBlock.fromLongArray(array);
       }
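    The HeapMemoryAllocator change rejects sizes that are not multiples of 8, since on-heap blocks are backed by a long[] of size / 8 elements. A small sketch of the new behaviour (not part of the patch; it assumes HeapMemoryAllocator keeps an implicit no-arg constructor):

        import org.apache.spark.unsafe.memory.HeapMemoryAllocator;
        import org.apache.spark.unsafe.memory.MemoryBlock;

        public class HeapAllocatorSketch {
          public static void main(String[] args) {
            HeapMemoryAllocator allocator = new HeapMemoryAllocator();  // assumed no-arg constructor

            MemoryBlock block = allocator.allocate(64);  // multiple of 8: backed by a long[8]
            System.out.println(block.size());            // 64

            try {
              allocator.allocate(12);                    // not word-aligned: rejected by the new check
            } catch (IllegalArgumentException e) {
              System.out.println(e.getMessage());        // Size 12 was not a multiple of 8
            }
          }
        }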
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
    index 3dc82d8c2eb3..dd7582083437 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
    @@ -19,7 +19,7 @@
     
     import javax.annotation.Nullable;
     
    -import org.apache.spark.unsafe.PlatformDependent;
    +import org.apache.spark.unsafe.Platform;
     
     /**
      * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size.
    @@ -34,7 +34,7 @@ public class MemoryBlock extends MemoryLocation {
        */
       int pageNumber = -1;
     
    -  MemoryBlock(@Nullable Object obj, long offset, long length) {
    +  public MemoryBlock(@Nullable Object obj, long offset, long length) {
         super(obj, offset);
         this.length = length;
       }
    @@ -50,6 +50,6 @@ public long size() {
        * Creates a memory block pointing to the memory used by the long array.
        */
       public static MemoryBlock fromLongArray(final long[] array) {
    -    return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8);
    +    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8);
       }
     }
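    Making the MemoryBlock constructor public lets code outside the memory package wrap an arbitrary (base object, offset, length) region, not just a long[] via fromLongArray. An illustrative sketch (not part of the patch) that reads and writes a block through the Platform accessors introduced above:

        import org.apache.spark.unsafe.Platform;
        import org.apache.spark.unsafe.memory.MemoryBlock;

        public class MemoryBlockSketch {
          public static void main(String[] args) {
            long[] backing = new long[4];
            MemoryBlock block = MemoryBlock.fromLongArray(backing);  // offset = Platform.LONG_ARRAY_OFFSET

            // Writes through Platform are visible in the backing array.
            Platform.putLong(block.getBaseObject(), block.getBaseOffset(), 123L);
            System.out.println(backing[0]);      // 123
            System.out.println(block.size());    // 32 (4 longs * 8 bytes)

            // The now-public constructor can wrap other regions directly, e.g. a byte[].
            byte[] bytes = new byte[16];
            MemoryBlock byteBlock = new MemoryBlock(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length);
            System.out.println(byteBlock.size()); // 16
          }
        }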
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
    index 10881969dbc7..97b2c93f0dc3 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
    @@ -58,8 +58,13 @@ public class TaskMemoryManager {
       /** The number of entries in the page table. */
       private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
     
    -  /** Maximum supported data page size */
    -  private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
    +  /**
    +   * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
    +   * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page
    +   * size is limited by the maximum amount of data that can be stored in a long[] array, which is
    +   * (2^31 - 1) * 8 bytes (or about 16 gigabytes). Therefore, we cap this at 16 gigabytes.
    +   */
    +  public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
     
       /** Bit mask for the lower 51 bits of a long. */
       private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
    @@ -110,9 +115,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
        * intended for allocating large blocks of memory that will be shared between operators.
        */
       public MemoryBlock allocatePage(long size) {
    -    if (size > MAXIMUM_PAGE_SIZE) {
    +    if (size > MAXIMUM_PAGE_SIZE_BYTES) {
           throw new IllegalArgumentException(
    -        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
    +        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
         }
     
         final int pageNumber;
    @@ -139,14 +144,16 @@ public MemoryBlock allocatePage(long size) {
       public void freePage(MemoryBlock page) {
         assert (page.pageNumber != -1) :
           "Called freePage() on memory that wasn't allocated with allocatePage()";
    -    executorMemoryManager.free(page);
    +    assert(allocatedPages.get(page.pageNumber));
    +    pageTable[page.pageNumber] = null;
         synchronized (this) {
           allocatedPages.clear(page.pageNumber);
         }
    -    pageTable[page.pageNumber] = null;
         if (logger.isTraceEnabled()) {
           logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
         }
    +    // Cannot access a page once it's freed.
    +    executorMemoryManager.free(page);
       }
     
       /**
    @@ -159,8 +166,11 @@ public void freePage(MemoryBlock page) {
        * top-level Javadoc for more details).
        */
       public MemoryBlock allocate(long size) throws OutOfMemoryError {
    +    assert(size > 0) : "Size must be positive, but got " + size;
         final MemoryBlock memory = executorMemoryManager.allocate(size);
    -    allocatedNonPageMemory.add(memory);
    +    synchronized(allocatedNonPageMemory) {
    +      allocatedNonPageMemory.add(memory);
    +    }
         return memory;
       }
     
    @@ -170,8 +180,10 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
       public void free(MemoryBlock memory) {
         assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()";
         executorMemoryManager.free(memory);
    -    final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
    -    assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
    +    synchronized(allocatedNonPageMemory) {
    +      final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
    +      assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
    +    }
       }
     
       /**
    @@ -217,9 +229,10 @@ public Object getPage(long pagePlusOffsetAddress) {
         if (inHeap) {
           final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
           assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
    -      final Object page = pageTable[pageNumber].getBaseObject();
    +      final MemoryBlock page = pageTable[pageNumber];
           assert (page != null);
    -      return page;
    +      assert (page.getBaseObject() != null);
    +      return page.getBaseObject();
         } else {
           return null;
         }
    @@ -238,7 +251,9 @@ public long getOffsetInPage(long pagePlusOffsetAddress) {
           // converted the absolute address into a relative address. Here, we invert that operation:
           final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
           assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
    -      return pageTable[pageNumber].getBaseOffset() + offsetInPage;
    +      final MemoryBlock page = pageTable[pageNumber];
    +      assert (page != null);
    +      return page.getBaseOffset() + offsetInPage;
         }
       }
     
    @@ -254,14 +269,17 @@ public long cleanUpAllAllocatedMemory() {
             freePage(page);
           }
         }
    -    final Iterator iter = allocatedNonPageMemory.iterator();
    -    while (iter.hasNext()) {
    -      final MemoryBlock memory = iter.next();
    -      freedBytes += memory.size();
    -      // We don't call free() here because that calls Set.remove, which would lead to a
    -      // ConcurrentModificationException here.
    -      executorMemoryManager.free(memory);
    -      iter.remove();
    +
    +    synchronized (allocatedNonPageMemory) {
    +      final Iterator iter = allocatedNonPageMemory.iterator();
    +      while (iter.hasNext()) {
    +        final MemoryBlock memory = iter.next();
    +        freedBytes += memory.size();
    +        // We don't call free() here because that calls Set.remove, which would lead to a
    +        // ConcurrentModificationException here.
    +        executorMemoryManager.free(memory);
    +        iter.remove();
    +      }
         }
         return freedBytes;
       }
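    The renamed MAXIMUM_PAGE_SIZE_BYTES constant encodes the comment's arithmetic: the largest Java array holds 2^31 - 1 elements, and a long[] of that length occupies (2^31 - 1) * 8 bytes, just under 16 GiB, which is why allocatePage now rejects anything larger. A quick arithmetic check (illustration only, not part of the patch):

        public class PageSizeCapSketch {
          public static void main(String[] args) {
            // Mirrors ((1L << 31) - 1) * 8L from the hunk above.
            long maxElements = Integer.MAX_VALUE;        // 2^31 - 1 = 2147483647
            long capInBytes = maxElements * 8L;          // 17179869176 bytes
            double capInGiB = capInBytes / (1024.0 * 1024.0 * 1024.0);

            System.out.println(capInBytes);              // 17179869176
            System.out.printf("%.2f GiB%n", capInGiB);   // ~16.00 GiB
          }
        }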
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
    index 15898771fef2..cda7826c8c99 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
    @@ -17,7 +17,7 @@
     
     package org.apache.spark.unsafe.memory;
     
    -import org.apache.spark.unsafe.PlatformDependent;
    +import org.apache.spark.unsafe.Platform;
     
     /**
      * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory.
    @@ -26,7 +26,10 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
     
       @Override
       public MemoryBlock allocate(long size) throws OutOfMemoryError {
    -    long address = PlatformDependent.UNSAFE.allocateMemory(size);
    +    if (size % 8 != 0) {
    +      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
    +    }
    +    long address = Platform.allocateMemory(size);
         return new MemoryBlock(null, address, size);
       }
     
    @@ -34,6 +37,6 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
       public void free(MemoryBlock memory) {
         assert (memory.obj == null) :
           "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?";
    -    PlatformDependent.UNSAFE.freeMemory(memory.offset);
    +    Platform.freeMemory(memory.offset);
       }
     }
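    The off-heap allocator gets the same word-alignment check as the on-heap one, and its blocks carry a null base object with the raw address in the offset field. A minimal lifecycle sketch (not part of the patch; assumes UnsafeMemoryAllocator keeps an implicit no-arg constructor):

        import org.apache.spark.unsafe.Platform;
        import org.apache.spark.unsafe.memory.MemoryBlock;
        import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator;

        public class OffHeapSketch {
          public static void main(String[] args) {
            UnsafeMemoryAllocator allocator = new UnsafeMemoryAllocator();  // assumed no-arg constructor
            MemoryBlock block = allocator.allocate(64);   // must be a multiple of 8 after this patch

            // Off-heap blocks use a null base object; the offset is the absolute address.
            Platform.putLong(null, block.getBaseOffset(), 42L);
            System.out.println(Platform.getLong(null, block.getBaseOffset()));  // 42

            allocator.free(block);                        // release the off-heap memory
          }
        }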
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
    new file mode 100644
    index 000000000000..c08c9c73d239
    --- /dev/null
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
    @@ -0,0 +1,32 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.unsafe.types;
    +
    +import org.apache.spark.unsafe.Platform;
    +
    +public class ByteArray {
    +
    +  /**
    +   * Writes the content of a byte array into a memory address, identified by an object and an
    +   * offset. The target memory address must already be allocated, and must have enough space to
    +   * hold all the bytes in the source array.
    +   */
    +  public static void writeToMemory(byte[] src, Object target, long targetOffset) {
    +    Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length);
    +  }
    +}
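    ByteArray.writeToMemory is a thin wrapper over Platform.copyMemory: it copies an entire byte[] to any (target object, target offset) pair, and the caller is responsible for the target being large enough. A small sketch (not part of the patch) copying into another byte[]:

        import org.apache.spark.unsafe.Platform;
        import org.apache.spark.unsafe.types.ByteArray;

        public class ByteArraySketch {
          public static void main(String[] args) {
            byte[] src = {1, 2, 3, 4};
            byte[] dst = new byte[8];   // must already have room for all source bytes

            // Copy src into dst starting at dst[2].
            ByteArray.writeToMemory(src, dst, Platform.BYTE_ARRAY_OFFSET + 2);

            for (byte b : dst) {
              System.out.print(b + " ");   // 0 0 1 2 3 4 0 0
            }
            System.out.println();
          }
        }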
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
    new file mode 100644
    index 000000000000..30e175807636
    --- /dev/null
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
    @@ -0,0 +1,310 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.unsafe.types;
    +
    +import java.io.Serializable;
    +import java.util.regex.Matcher;
    +import java.util.regex.Pattern;
    +
    +/**
    + * The internal representation of interval type.
    + */
    +public final class CalendarInterval implements Serializable {
    +  public static final long MICROS_PER_MILLI = 1000L;
    +  public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000;
    +  public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60;
    +  public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60;
    +  public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
    +  public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;
    +
    +  /**
    +   * A function to generate a regex that matches one unit part of an interval string, e.g. "3 years".
    +   *
    +   * First, individual units may be left out of an interval string, and only the value of each
    +   * unit matters, so the actual regex is wrapped in a non-capturing group.
    +   * At the beginning of the actual regex, we match the spaces before the unit part.
    +   * Next is the number part, which starts with an optional "-" to represent a negative value;
    +   * it is wrapped in a capturing group because the value is needed later.
    +   * Finally comes the unit name, which ends with an optional "s".
    +   */
    +  private static String unitRegex(String unit) {
    +    return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
    +  }
    +
    +  private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") +
    +    unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") +
    +    unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"));
    +
    +  private static Pattern yearMonthPattern =
    +    Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$");
    +
    +  private static Pattern dayTimePattern =
    +    Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$");
    +
    +  private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$");
    +
    +  private static long toLong(String s) {
    +    if (s == null) {
    +      return 0;
    +    } else {
    +      return Long.valueOf(s);
    +    }
    +  }
    +
    +  public static CalendarInterval fromString(String s) {
    +    if (s == null) {
    +      return null;
    +    }
    +    s = s.trim();
    +    Matcher m = p.matcher(s);
    +    if (!m.matches() || s.equals("interval")) {
    +      return null;
    +    } else {
    +      long months = toLong(m.group(1)) * 12 + toLong(m.group(2));
    +      long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
    +      microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
    +      microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
    +      microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
    +      microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
    +      microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
    +      microseconds += toLong(m.group(9));
    +      return new CalendarInterval((int) months, microseconds);
    +    }
    +  }
    +
    +  public static long toLongWithRange(String fieldName,
    +      String s, long minValue, long maxValue) throws IllegalArgumentException {
    +    long result = 0;
    +    if (s != null) {
    +      result = Long.valueOf(s);
    +      if (result < minValue || result > maxValue) {
    +        throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]",
    +          fieldName, result, minValue, maxValue));
    +      }
    +    }
    +    return result;
    +  }
    +
    +  /**
    +   * Parse YearMonth string in form: [-]YYYY-MM
    +   *
    +   * adapted from HiveIntervalYearMonth.valueOf
    +   */
    +  public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException {
    +    CalendarInterval result = null;
    +    if (s == null) {
    +      throw new IllegalArgumentException("Interval year-month string was null");
    +    }
    +    s = s.trim();
    +    Matcher m = yearMonthPattern.matcher(s);
    +    if (!m.matches()) {
    +      throw new IllegalArgumentException(
    +        "Interval string does not match year-month format of 'y-m': " + s);
    +    } else {
    +      try {
    +        int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1;
    +        int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE);
    +        int months = (int) toLongWithRange("month", m.group(3), 0, 11);
    +        result = new CalendarInterval(sign * (years * 12 + months), 0);
    +      } catch (Exception e) {
    +        throw new IllegalArgumentException(
    +          "Error parsing interval year-month string: " + e.getMessage(), e);
    +      }
    +    }
    +    return result;
    +  }
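    +
    +  // Illustrative values, mirroring CalendarIntervalSuite: fromYearMonthString("99-10") yields
    +  // new CalendarInterval(99 * 12 + 10, 0), fromYearMonthString("-8-10") yields
    +  // new CalendarInterval(-8 * 12 - 10, 0), and "99-15" is rejected because 15 is outside [0, 11].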
    +
    +  /**
    +   * Parses a day-time interval string of the form [-]d HH:mm:ss.nnnnnnnnn.
    +   *
    +   * Adapted from HiveIntervalDayTime.valueOf.
    +   */
    +  public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException {
    +    CalendarInterval result = null;
    +    if (s == null) {
    +      throw new IllegalArgumentException("Interval day-time string was null");
    +    }
    +    s = s.trim();
    +    Matcher m = dayTimePattern.matcher(s);
    +    if (!m.matches()) {
    +      throw new IllegalArgumentException(
    +        "Interval string does not match day-time format of 'd h:m:s.n': " + s);
    +    } else {
    +      try {
    +        int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1;
    +        long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE);
    +        long hours = toLongWithRange("hour", m.group(3), 0, 23);
    +        long minutes = toLongWithRange("minute", m.group(4), 0, 59);
    +        long seconds = toLongWithRange("second", m.group(5), 0, 59);
    +        // Hive allows nanosecond-precision intervals
    +        long nanos = toLongWithRange("nanosecond", m.group(7), 0L, 999999999L);
    +        result = new CalendarInterval(0, sign * (
    +          days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE +
    +          seconds * MICROS_PER_SECOND + nanos / 1000L));
    +      } catch (Exception e) {
    +        throw new IllegalArgumentException(
    +          "Error parsing interval day-time string: " + e.getMessage(), e);
    +      }
    +    }
    +    return result;
    +  }
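    +
    +  // Illustrative value, mirroring CalendarIntervalSuite: fromDayTimeString("5 12:40:30.999999999")
    +  // yields 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + 40 * MICROS_PER_MINUTE +
    +  // 30 * MICROS_PER_SECOND + 999999 microseconds; the nanosecond fraction is truncated to
    +  // microsecond precision.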
    +
    +  public static CalendarInterval fromSingleUnitString(String unit, String s)
    +      throws IllegalArgumentException {
    +
    +    CalendarInterval result = null;
    +    if (s == null) {
    +      throw new IllegalArgumentException(String.format("Interval %s string was null", unit));
    +    }
    +    s = s.trim();
    +    Matcher m = quoteTrimPattern.matcher(s);
    +    if (!m.matches()) {
    +      throw new IllegalArgumentException(
    +        "Interval " + unit + " string does not match the expected single-unit format: " + s);
    +    } else {
    +      try {
    +        if (unit.equals("year")) {
    +          int year = (int) toLongWithRange("year", m.group(1),
    +            Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12);
    +          result = new CalendarInterval(year * 12, 0L);
    +
    +        } else if (unit.equals("month")) {
    +          int month = (int) toLongWithRange("month", m.group(1),
    +            Integer.MIN_VALUE, Integer.MAX_VALUE);
    +          result = new CalendarInterval(month, 0L);
    +
    +        } else if (unit.equals("day")) {
    +          long day = toLongWithRange("day", m.group(1),
    +            Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
    +          result = new CalendarInterval(0, day * MICROS_PER_DAY);
    +
    +        } else if (unit.equals("hour")) {
    +          long hour = toLongWithRange("hour", m.group(1),
    +            Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR);
    +          result = new CalendarInterval(0, hour * MICROS_PER_HOUR);
    +
    +        } else if (unit.equals("minute")) {
    +          long minute = toLongWithRange("minute", m.group(1),
    +            Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE);
    +          result = new CalendarInterval(0, minute * MICROS_PER_MINUTE);
    +
    +        } else if (unit.equals("second")) {
    +          long micros = parseSecondNano(m.group(1));
    +          result = new CalendarInterval(0, micros);
    +        }
    +      } catch (Exception e) {
    +        throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
    +      }
    +    }
    +    return result;
    +  }
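    +
    +  // Illustrative values, mirroring CalendarIntervalSuite:
    +  //   fromSingleUnitString("year", "12")           => new CalendarInterval(12 * 12, 0)
    +  //   fromSingleUnitString("day", "100")           => new CalendarInterval(0, 100 * MICROS_PER_DAY)
    +  //   fromSingleUnitString("second", "1999.38888") => 1999 * MICROS_PER_SECOND + 38 microseconds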
    +
    +  /**
    +   * Parses a second_nano string in ss.nnnnnnnnn format into microseconds.
    +   */
    +  public static long parseSecondNano(String secondNano) throws IllegalArgumentException {
    +    String[] parts = secondNano.split("\\.");
    +    if (parts.length == 1) {
    +      return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND,
    +        Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND;
    +
    +    } else if (parts.length == 2) {
    +      long seconds = parts[0].equals("") ? 0L : toLongWithRange("second", parts[0],
    +        Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND);
    +      long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L);
    +      return seconds * MICROS_PER_SECOND + nanos / 1000L;
    +
    +    } else {
    +      throw new IllegalArgumentException(
    +        "Interval string does not match second-nano format of ss.nnnnnnnnn");
    +    }
    +  }
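    +
    +  // For example, parseSecondNano("30") returns 30 * MICROS_PER_SECOND and
    +  // parseSecondNano("30.999999999") returns 30 * MICROS_PER_SECOND + 999999; the digits after
    +  // the dot are read as a nanosecond count (not a decimal fraction) and truncated to micros.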
    +
    +  public final int months;
    +  public final long microseconds;
    +
    +  public CalendarInterval(int months, long microseconds) {
    +    this.months = months;
    +    this.microseconds = microseconds;
    +  }
    +
    +  public CalendarInterval add(CalendarInterval that) {
    +    int months = this.months + that.months;
    +    long microseconds = this.microseconds + that.microseconds;
    +    return new CalendarInterval(months, microseconds);
    +  }
    +
    +  public CalendarInterval subtract(CalendarInterval that) {
    +    int months = this.months - that.months;
    +    long microseconds = this.microseconds - that.microseconds;
    +    return new CalendarInterval(months, microseconds);
    +  }
    +
    +  public CalendarInterval negate() {
    +    return new CalendarInterval(-this.months, -this.microseconds);
    +  }
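    +
    +  // Illustrative arithmetic, mirroring CalendarIntervalSuite: months and microseconds are
    +  // combined independently, e.g.
    +  //   fromString("interval 3 months 1 hour").add(fromString("interval 2 months 100 hours"))
    +  //     == new CalendarInterval(5, 101 * MICROS_PER_HOUR)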
    +
    +  @Override
    +  public boolean equals(Object other) {
    +    if (this == other) return true;
    +    if (!(other instanceof CalendarInterval)) return false;
    +
    +    CalendarInterval o = (CalendarInterval) other;
    +    return this.months == o.months && this.microseconds == o.microseconds;
    +  }
    +
    +  @Override
    +  public int hashCode() {
    +    return 31 * months + (int) microseconds;
    +  }
    +
    +  @Override
    +  public String toString() {
    +    StringBuilder sb = new StringBuilder("interval");
    +
    +    if (months != 0) {
    +      appendUnit(sb, months / 12, "year");
    +      appendUnit(sb, months % 12, "month");
    +    }
    +
    +    if (microseconds != 0) {
    +      long rest = microseconds;
    +      appendUnit(sb, rest / MICROS_PER_WEEK, "week");
    +      rest %= MICROS_PER_WEEK;
    +      appendUnit(sb, rest / MICROS_PER_DAY, "day");
    +      rest %= MICROS_PER_DAY;
    +      appendUnit(sb, rest / MICROS_PER_HOUR, "hour");
    +      rest %= MICROS_PER_HOUR;
    +      appendUnit(sb, rest / MICROS_PER_MINUTE, "minute");
    +      rest %= MICROS_PER_MINUTE;
    +      appendUnit(sb, rest / MICROS_PER_SECOND, "second");
    +      rest %= MICROS_PER_SECOND;
    +      appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond");
    +      rest %= MICROS_PER_MILLI;
    +      appendUnit(sb, rest, "microsecond");
    +    }
    +
    +    return sb.toString();
    +  }
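    +
    +  // For example, new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123)
    +  // renders as "interval 2 years 10 months 3 weeks 13 hours 123 microseconds".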
    +
    +  private void appendUnit(StringBuilder sb, long value, String unit) {
    +    if (value != 0) {
    +      sb.append(" " + value + " " + unit + "s");
    +    }
    +  }
    +}
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
    index 9871a70a40e6..216aeea60d1c 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
    @@ -17,12 +17,17 @@
     
     package org.apache.spark.unsafe.types;
     
    -import java.io.Serializable;
    -import java.io.UnsupportedEncodingException;
    -import java.util.Arrays;
     import javax.annotation.Nonnull;
    +import java.io.*;
    +import java.nio.ByteOrder;
    +import java.util.Arrays;
    +import java.util.Map;
    +
    +import org.apache.spark.unsafe.Platform;
    +import org.apache.spark.unsafe.array.ByteArrayMethods;
    +
    +import static org.apache.spark.unsafe.Platform.*;
     
    -import org.apache.spark.unsafe.PlatformDependent;
     
     /**
      * A UTF-8 String for internal Spark use.
    @@ -32,72 +37,189 @@
      * 

    * Note: This is not designed for general use cases, should not be used outside SQL. */ -public final class UTF8String implements Comparable, Serializable { +public final class UTF8String implements Comparable, Externalizable { + // These are only updated by readExternal() @Nonnull - private byte[] bytes; + private Object base; + private long offset; + private int numBytes; + + public Object getBaseObject() { return base; } + public long getBaseOffset() { return offset; } private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, - 6, 6, 6, 6}; + 6, 6}; + private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + + private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); + public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); + + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ public static UTF8String fromBytes(byte[] bytes) { - return (bytes != null) ? new UTF8String().set(bytes) : null; + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + } else { + return null; + } } - public static UTF8String fromString(String str) { - return (str != null) ? new UTF8String().set(str) : null; + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ + public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + } else { + return null; + } + } + + /** + * Creates an UTF8String from given address (base and offset) and length. + */ + public static UTF8String fromAddress(Object base, long offset, int numBytes) { + return new UTF8String(base, offset, numBytes); } /** - * Updates the UTF8String with String. + * Creates an UTF8String from String. */ - protected UTF8String set(final String str) { + public static UTF8String fromString(String str) { + if (str == null) return null; try { - bytes = str.getBytes("utf-8"); + return fromBytes(str.getBytes("utf-8")); } catch (UnsupportedEncodingException e) { // Turn the exception into unchecked so we can find out about it at runtime, but // don't need to add lots of boilerplate code everywhere. - PlatformDependent.throwException(e); + throwException(e); + return null; } - return this; } /** - * Updates the UTF8String with byte[], which should be encoded in UTF-8. + * Creates an UTF8String that contains `length` spaces. + */ + public static UTF8String blankString(int length) { + byte[] spaces = new byte[length]; + Arrays.fill(spaces, (byte) ' '); + return fromBytes(spaces); + } + + protected UTF8String(Object base, long offset, int numBytes) { + this.base = base; + this.offset = offset; + this.numBytes = numBytes; + } + + // for serialization + public UTF8String() { + this(null, 0, 0); + } + + /** + * Writes the content of this string into a memory address, identified by an object and an offset. + * The target memory address must already been allocated, and have enough space to hold all the + * bytes in this string. 
*/ - protected UTF8String set(final byte[] bytes) { - this.bytes = bytes; - return this; + public void writeToMemory(Object target, long targetOffset) { + Platform.copyMemory(base, offset, target, targetOffset, numBytes); } /** * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point */ - public int numBytes(final byte b) { + private static int numBytesForFirstByte(final byte b) { final int offset = (b & 0xFF) - 192; return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; } + /** + * Returns the number of bytes + */ + public int numBytes() { + return numBytes; + } + /** * Returns the number of code points in it. - * - * This is only used by Substring() when `start` is negative. */ - public int length() { + public int numChars() { int len = 0; - for (int i = 0; i < bytes.length; i+= numBytes(bytes[i])) { + for (int i = 0; i < numBytes; i += numBytesForFirstByte(getByte(i))) { len += 1; } return len; } + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. + */ + public long getPrefix() { + // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string. + // If size is 0, just return 0. + // If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and + // use a getInt to fetch the prefix. + // If size is greater than 4, assume we have at least 8 bytes of data to fetch. + // After getting the data, we use a mask to mask out data that is not part of the string. + long p; + long mask = 0; + if (isLittleEndian) { + if (numBytes >= 8) { + p = Platform.getLong(base, offset); + } else if (numBytes > 4) { + p = Platform.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = (long) Platform.getInt(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); + } else { + // byteOrder == ByteOrder.BIG_ENDIAN + if (numBytes >= 8) { + p = Platform.getLong(base, offset); + } else if (numBytes > 4) { + p = Platform.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = ((long) Platform.getInt(base, offset)) << 32; + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + } + p &= ~mask; + return p; + } + + /** + * Returns the underline bytes, will be a copy of it if it's part of another array. + */ public byte[] getBytes() { - return bytes; + // avoid copy if `base` is `byte[]` + if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] + && ((byte[]) base).length == numBytes) { + return (byte[]) base; + } else { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return bytes; + } } /** @@ -106,92 +228,613 @@ public byte[] getBytes() { * @param until the position after last code point, exclusive. 
*/ public UTF8String substring(final int start, final int until) { - if (until <= start || start >= bytes.length) { - return UTF8String.fromBytes(new byte[0]); + if (until <= start || start >= numBytes) { + return EMPTY_UTF8; } int i = 0; int c = 0; - for (; i < bytes.length && c < start; i += numBytes(bytes[i])) { + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); c += 1; } int j = i; - for (; j < bytes.length && c < until; j += numBytes(bytes[i])) { + while (i < numBytes && c < until) { + i += numBytesForFirstByte(getByte(i)); c += 1; } - return UTF8String.fromBytes(Arrays.copyOfRange(bytes, i, j)); + if (i > j) { + byte[] bytes = new byte[i - j]; + copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + return fromBytes(bytes); + } else { + return EMPTY_UTF8; + } + } + + public UTF8String substringSQL(int pos, int length) { + // Information regarding the pos calculation: + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + int len = numChars(); + int start = (pos > 0) ? pos -1 : ((pos < 0) ? len + pos : 0); + int end = (length == Integer.MAX_VALUE) ? len : start + length; + return substring(start, end); } + /** + * Returns whether this contains `substring` or not. + */ public boolean contains(final UTF8String substring) { - final byte[] b = substring.getBytes(); - if (b.length == 0) { + if (substring.numBytes == 0) { return true; } - for (int i = 0; i <= bytes.length - b.length; i++) { - if (bytes[i] == b[0] && startsWith(b, i)) { + byte first = substring.getByte(0); + for (int i = 0; i <= numBytes - substring.numBytes; i++) { + if (getByte(i) == first && matchAt(substring, i)) { return true; } } return false; } - private boolean startsWith(final byte[] prefix, int offsetInBytes) { - if (prefix.length + offsetInBytes > bytes.length || offsetInBytes < 0) { + /** + * Returns the byte at position `i`. 
+ */ + private byte getByte(int i) { + return Platform.getByte(base, offset + i); + } + + private boolean matchAt(final UTF8String s, int pos) { + if (s.numBytes + pos > numBytes || pos < 0) { return false; } - int i = 0; - while (i < prefix.length && prefix[i] == bytes[i + offsetInBytes]) { - i++; - } - return i == prefix.length; + return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); } public boolean startsWith(final UTF8String prefix) { - return startsWith(prefix.getBytes(), 0); + return matchAt(prefix, 0); } public boolean endsWith(final UTF8String suffix) { - return startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length); + return matchAt(suffix, numBytes - suffix.numBytes); } + /** + * Returns the upper case of this string + */ public UTF8String toUpperCase() { - return UTF8String.fromString(toString().toUpperCase()); + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte[] bytes = new byte[numBytes]; + bytes[0] = (byte) Character.toTitleCase(getByte(0)); + for (int i = 0; i < numBytes; i++) { + byte b = getByte(i); + if (numBytesForFirstByte(b) != 1) { + // fallback + return toUpperCaseSlow(); + } + int upper = Character.toUpperCase((int) b); + if (upper > 127) { + // fallback + return toUpperCaseSlow(); + } + bytes[i] = (byte) upper; + } + return fromBytes(bytes); } + private UTF8String toUpperCaseSlow() { + return fromString(toString().toUpperCase()); + } + + /** + * Returns the lower case of this string + */ public UTF8String toLowerCase() { - return UTF8String.fromString(toString().toLowerCase()); + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte[] bytes = new byte[numBytes]; + bytes[0] = (byte) Character.toTitleCase(getByte(0)); + for (int i = 0; i < numBytes; i++) { + byte b = getByte(i); + if (numBytesForFirstByte(b) != 1) { + // fallback + return toLowerCaseSlow(); + } + int lower = Character.toLowerCase((int) b); + if (lower > 127) { + // fallback + return toLowerCaseSlow(); + } + bytes[i] = (byte) lower; + } + return fromBytes(bytes); + } + + private UTF8String toLowerCaseSlow() { + return fromString(toString().toLowerCase()); + } + + /** + * Returns the title case of this string, that could be used as title. + */ + public UTF8String toTitleCase() { + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte[] bytes = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) { + byte b = getByte(i); + if (i == 0 || getByte(i - 1) == ' ') { + if (numBytesForFirstByte(b) != 1) { + // fallback + return toTitleCaseSlow(); + } + int upper = Character.toTitleCase(b); + if (upper > 127) { + // fallback + return toTitleCaseSlow(); + } + bytes[i] = (byte) upper; + } else { + bytes[i] = b; + } + } + return fromBytes(bytes); + } + + private UTF8String toTitleCaseSlow() { + StringBuffer sb = new StringBuffer(); + String s = toString(); + sb.append(s); + sb.setCharAt(0, Character.toTitleCase(sb.charAt(0))); + for (int i = 1; i < s.length(); i++) { + if (sb.charAt(i - 1) == ' ') { + sb.setCharAt(i, Character.toTitleCase(sb.charAt(i))); + } + } + return fromString(sb.toString()); + } + + /* + * Returns the index of the string `match` in this String. This string has to be a comma separated + * list. If `match` contains a comma 0 will be returned. 
If the `match` isn't part of this String, + * 0 will be returned, else the index of match (1-based index) + */ + public int findInSet(UTF8String match) { + if (match.contains(COMMA_UTF8)) { + return 0; + } + + int n = 1, lastComma = -1; + for (int i = 0; i < numBytes; i++) { + if (getByte(i) == (byte) ',') { + if (i - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + lastComma = i; + n++; + } + } + if (numBytes - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + return 0; + } + + /** + * Copy the bytes from the current UTF8String, and make a new UTF8String. + * @param start the start position of the current UTF8String in bytes. + * @param end the end position of the current UTF8String in bytes. + * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. + */ + private UTF8String copyUTF8String(int start, int end) { + int len = end - start + 1; + byte[] newBytes = new byte[len]; + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + return UTF8String.fromBytes(newBytes); + } + + public UTF8String trim() { + int s = 0; + int e = this.numBytes - 1; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + if (s > e) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, e); + } + } + + public UTF8String trimLeft() { + int s = 0; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + if (s == this.numBytes) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + + public UTF8String trimRight() { + int e = numBytes - 1; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + + if (e < 0) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(0, e); + } + } + + public UTF8String reverse() { + byte[] result = new byte[this.numBytes]; + + int i = 0; // position in byte + while (i < numBytes) { + int len = numBytesForFirstByte(getByte(i)); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); + + i += len; + } + + return UTF8String.fromBytes(result); + } + + public UTF8String repeat(int times) { + if (times <= 0) { + return EMPTY_UTF8; + } + + byte[] newBytes = new byte[numBytes * times]; + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + + int copied = 1; + while (copied < times) { + int toCopy = Math.min(copied, times - copied); + System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy); + copied += toCopy; + } + + return UTF8String.fromBytes(newBytes); + } + + /** + * Returns the position of the first occurrence of substr in + * current string from the specified position (0-based index). + * + * @param v the string to be searched + * @param start the start position of the current string for searching + * @return the position of the first occurrence of substr, if not found, -1 returned. 
+ */ + public int indexOf(UTF8String v, int start) { + if (v.numBytes() == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + v.numBytes > numBytes) { + return -1; + } + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while (i < numBytes); + + return -1; + } + + /** + * Find the `str` from left to right. + */ + private int find(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start <= numBytes - str.numBytes) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; + } + start += 1; + } + return -1; + } + + /** + * Find the `str` from right to left. + */ + private int rfind(UTF8String str, int start) { + assert (str.numBytes > 0); + while (start >= 0) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + return start; + } + start -= 1; + } + return -1; + } + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. subStringIndex performs a case-sensitive match when searching for delim. + */ + public UTF8String subStringIndex(UTF8String delim, int count) { + if (delim.numBytes == 0 || count == 0) { + return EMPTY_UTF8; + } + if (count > 0) { + int idx = -1; + while (count > 0) { + idx = find(delim, idx + 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } + } + if (idx == 0) { + return EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + return fromBytes(bytes); + + } else { + int idx = numBytes - delim.numBytes + 1; + count = -count; + while (count > 0) { + idx = rfind(delim, idx - 1); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return this; + } + } + if (idx + delim.numBytes == numBytes) { + return EMPTY_UTF8; + } + int size = numBytes - delim.numBytes - idx; + byte[] bytes = new byte[size]; + copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + return fromBytes(bytes); + } + } + + /** + * Returns str, right-padded with pad to a length of len + * For example: + * ('hi', 5, '??') => 'hi???' 
+ * ('hi', 1, '??') => 'h' + */ + public UTF8String rpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0 || pad.numBytes() == 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + int offset = this.numBytes; + int idx = 0; + while (idx < count) { + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + ++ idx; + offset += pad.numBytes; + } + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + + return UTF8String.fromBytes(data); + } + } + + /** + * Returns str, left-padded with pad to a length of len. + * For example: + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' + */ + public UTF8String lpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0 || pad.numBytes() == 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + + int offset = 0; + int idx = 0; + while (idx < count) { + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + ++ idx; + offset += pad.numBytes; + } + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + offset += remain.numBytes; + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + + return UTF8String.fromBytes(data); + } + } + + /** + * Concatenates input strings together into a single string. Returns null if any input is null. + */ + public static UTF8String concat(UTF8String... inputs) { + // Compute the total length of the result. + int totalLength = 0; + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + totalLength += inputs[i].numBytes; + } else { + return null; + } + } + + // Allocate a new byte array, and copy the inputs one by one into it. + final byte[] result = new byte[totalLength]; + int offset = 0; + for (int i = 0; i < inputs.length; i++) { + int len = inputs[i].numBytes; + copyMemory( + inputs[i].base, inputs[i].offset, + result, BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + } + return fromBytes(result); + } + + /** + * Concatenates input strings together into a single string using the separator. + * A null input is skipped. For example, concat(",", "a", null, "c") would yield "a,c". + */ + public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { + if (separator == null) { + return null; + } + + int numInputBytes = 0; // total number of bytes from the inputs + int numInputs = 0; // number of non-null inputs + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + numInputBytes += inputs[i].numBytes; + numInputs++; + } + } + + if (numInputs == 0) { + // Return an empty string if there is no input, or all the inputs are null. 
+ return fromBytes(new byte[0]); + } + + // Allocate a new byte array, and copy the inputs one by one into it. + // The size of the new array is the size of all inputs, plus the separators. + final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes]; + int offset = 0; + + for (int i = 0, j = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + int len = inputs[i].numBytes; + copyMemory( + inputs[i].base, inputs[i].offset, + result, BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + + j++; + // Add separator if this is not the last input. + if (j < numInputs) { + copyMemory( + separator.base, separator.offset, + result, BYTE_ARRAY_OFFSET + offset, + separator.numBytes); + offset += separator.numBytes; + } + } + } + return fromBytes(result); + } + + public UTF8String[] split(UTF8String pattern, int limit) { + String[] splits = toString().split(pattern.toString(), limit); + UTF8String[] res = new UTF8String[splits.length]; + for (int i = 0; i < res.length; i++) { + res[i] = fromString(splits[i]); + } + return res; + } + + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes + public UTF8String translate(Map dict) { + String srcStr = this.toString(); + + StringBuilder sb = new StringBuilder(); + for(int k = 0; k< srcStr.length(); k++) { + if (null == dict.get(srcStr.charAt(k))) { + sb.append(srcStr.charAt(k)); + } else if ('\0' != dict.get(srcStr.charAt(k))){ + sb.append(dict.get(srcStr.charAt(k))); + } + } + return fromString(sb.toString()); } @Override public String toString() { try { - return new String(bytes, "utf-8"); + return new String(getBytes(), "utf-8"); } catch (UnsupportedEncodingException e) { // Turn the exception into unchecked so we can find out about it at runtime, but // don't need to add lots of boilerplate code everywhere. - PlatformDependent.throwException(e); + throwException(e); return "unknown"; // we will never reach here. } } @Override public UTF8String clone() { - return new UTF8String().set(bytes); + return fromBytes(getBytes()); } @Override - public int compareTo(final UTF8String other) { - final byte[] b = other.getBytes(); - for (int i = 0; i < bytes.length && i < b.length; i++) { - int res = bytes[i] - b[i]; + public int compareTo(@Nonnull final UTF8String other) { + int len = Math.min(numBytes, other.numBytes); + // TODO: compare 8 bytes as unsigned long + for (int i = 0; i < len; i ++) { + // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. + int res = (getByte(i) & 0xFF) - (other.getByte(i) & 0xFF); if (res != 0) { return res; } } - return bytes.length - b.length; + return numBytes - other.numBytes; } public int compare(final UTF8String other) { @@ -201,18 +844,153 @@ public int compare(final UTF8String other) { @Override public boolean equals(final Object other) { if (other instanceof UTF8String) { - return Arrays.equals(bytes, ((UTF8String) other).getBytes()); - } else if (other instanceof String) { - // Used only in unit tests. - String s = (String) other; - return bytes.length >= s.length() && length() == s.length() && toString().equals(s); + UTF8String o = (UTF8String) other; + if (numBytes != o.numBytes) { + return false; + } + return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); } else { return false; } } + /** + * Levenshtein distance is a metric for measuring the distance of two strings. The distance is + * defined by the minimum number of single-character edits (i.e. 
insertions, deletions or + * substitutions) that are required to change one of the strings into the other. + */ + public int levenshteinDistance(UTF8String other) { + // Implementation adopted from org.apache.common.lang3.StringUtils.getLevenshteinDistance + + int n = numChars(); + int m = other.numChars(); + + if (n == 0) { + return m; + } else if (m == 0) { + return n; + } + + UTF8String s, t; + + if (n <= m) { + s = this; + t = other; + } else { + s = other; + t = this; + int swap; + swap = n; + n = m; + m = swap; + } + + int p[] = new int[n + 1]; + int d[] = new int[n + 1]; + int swap[]; + + int i, i_bytes, j, j_bytes, num_bytes_j, cost; + + for (i = 0; i <= n; i++) { + p[i] = i; + } + + for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) { + num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes)); + d[0] = j + 1; + + for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) { + if (s.getByte(i_bytes) != t.getByte(j_bytes) || + num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { + cost = 1; + } else { + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + } + d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); + } + + swap = p; + p = d; + d = swap; + } + + return p[n]; + } + @Override public int hashCode() { - return Arrays.hashCode(bytes); + int result = 1; + for (int i = 0; i < numBytes; i ++) { + result = 31 * result + getByte(i); + } + return result; + } + + /** + * Soundex mapping table + */ + private static final byte[] US_ENGLISH_MAPPING = {'0', '1', '2', '3', '0', '1', '2', '7', + '0', '2', '2', '4', '5', '5', '0', '1', '2', '6', '2', '3', '0', '1', '7', '2', '0', '2'}; + + /** + * Encodes a string into a Soundex value. Soundex is an encoding used to relate similar names, + * but can also be used as a general purpose scheme to find word with similar phonemes. 
+ * https://en.wikipedia.org/wiki/Soundex + */ + public UTF8String soundex() { + if (numBytes == 0) { + return EMPTY_UTF8; + } + + byte b = getByte(0); + if ('a' <= b && b <= 'z') { + b -= 32; + } else if (b < 'A' || 'Z' < b) { + // first character must be a letter + return this; + } + byte sx[] = {'0', '0', '0', '0'}; + sx[0] = b; + int sxi = 1; + int idx = b - 'A'; + byte lastCode = US_ENGLISH_MAPPING[idx]; + + for (int i = 1; i < numBytes; i++) { + b = getByte(i); + if ('a' <= b && b <= 'z') { + b -= 32; + } else if (b < 'A' || 'Z' < b) { + // not a letter, skip it + lastCode = '0'; + continue; + } + idx = b - 'A'; + byte code = US_ENGLISH_MAPPING[idx]; + if (code == '7') { + // ignore it + } else { + if (code != '0' && code != lastCode) { + sx[sxi++] = code; + if (sxi > 3) break; + } + lastCode = code; + } + } + return UTF8String.fromBytes(sx); + } + + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; + numBytes = in.readInt(); + base = new byte[numBytes]; + in.readFully((byte[]) base); } + } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 3b9175835229..2f8cb132ac8b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,7 +22,7 @@ import java.util.Set; import junit.framework.Assert; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.junit.Test; /** @@ -83,11 +83,11 @@ public void randomizedStressTestBytes() { rand.nextBytes(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -106,11 +106,11 @@ public void randomizedStressTestPaddedStrings() { System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java deleted file mode 100644 index 81315f7c9464..000000000000 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ /dev/null @@ -1,383 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.map; - -import java.lang.Exception; -import java.nio.ByteBuffer; -import java.util.*; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import static org.mockito.AdditionalMatchers.geq; -import static org.mockito.Mockito.*; - -import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.memory.*; -import org.apache.spark.unsafe.PlatformDependent; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; - - -public abstract class AbstractBytesToBytesMapSuite { - - private final Random rand = new Random(42); - - private TaskMemoryManager memoryManager; - private TaskMemoryManager sizeLimitedMemoryManager; - - @Before - public void setup() { - memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); - // Mocked memory manager for tests that check the maximum array size, since actually allocating - // such large arrays will cause us to run out of memory in our tests. - sizeLimitedMemoryManager = spy(memoryManager); - when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer() { - @Override - public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { - if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { - throw new OutOfMemoryError("Requested array size exceeds VM limit"); - } - return memoryManager.allocate(1L << 20); - } - }); - } - - @After - public void tearDown() { - if (memoryManager != null) { - memoryManager.cleanUpAllAllocatedMemory(); - memoryManager = null; - } - } - - protected abstract MemoryAllocator getMemoryAllocator(); - - private static byte[] getByteArray(MemoryLocation loc, int size) { - final byte[] arr = new byte[size]; - PlatformDependent.copyMemory( - loc.getBaseObject(), - loc.getBaseOffset(), - arr, - BYTE_ARRAY_OFFSET, - size - ); - return arr; - } - - private byte[] getRandomByteArray(int numWords) { - Assert.assertTrue(numWords > 0); - final int lengthInBytes = numWords * 8; - final byte[] bytes = new byte[lengthInBytes]; - rand.nextBytes(bytes); - return bytes; - } - - /** - * Fast equality checking for byte arrays, since these comparisons are a bottleneck - * in our stress tests. 
- */ - private static boolean arrayEquals( - byte[] expected, - MemoryLocation actualAddr, - long actualLengthBytes) { - return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals( - expected, - BYTE_ARRAY_OFFSET, - actualAddr.getBaseObject(), - actualAddr.getBaseOffset(), - expected.length - ); - } - - @Test - public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); - try { - Assert.assertEquals(0, map.size()); - final int keyLengthInWords = 10; - final int keyLengthInBytes = keyLengthInWords * 8; - final byte[] key = getRandomByteArray(keyLengthInWords); - Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); - Assert.assertFalse(map.iterator().hasNext()); - } finally { - map.free(); - } - } - - @Test - public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); - final int recordLengthWords = 10; - final int recordLengthBytes = recordLengthWords * 8; - final byte[] keyData = getRandomByteArray(recordLengthWords); - final byte[] valueData = getRandomByteArray(recordLengthWords); - try { - final BytesToBytesMap.Location loc = - map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); - Assert.assertFalse(loc.isDefined()); - loc.putNewKey( - keyData, - BYTE_ARRAY_OFFSET, - recordLengthBytes, - valueData, - BYTE_ARRAY_OFFSET, - recordLengthBytes - ); - // After storing the key and value, the other location methods should return results that - // reflect the result of this store without us having to call lookup() again on the same key. - Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); - Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); - - // After calling lookup() the location should still point to the correct data. - Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); - Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); - Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); - - try { - loc.putNewKey( - keyData, - BYTE_ARRAY_OFFSET, - recordLengthBytes, - valueData, - BYTE_ARRAY_OFFSET, - recordLengthBytes - ); - Assert.fail("Should not be able to set a new value for a key"); - } catch (AssertionError e) { - // Expected exception; do nothing. 
- } - } finally { - map.free(); - } - } - - @Test - public void iteratorTest() throws Exception { - final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); - try { - for (long i = 0; i < size; i++) { - final long[] value = new long[] { i }; - final BytesToBytesMap.Location loc = - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); - Assert.assertFalse(loc.isDefined()); - // Ensure that we store some zero-length keys - if (i % 5 == 0) { - loc.putNewKey( - null, - PlatformDependent.LONG_ARRAY_OFFSET, - 0, - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8 - ); - } else { - loc.putNewKey( - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8, - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8 - ); - } - } - final java.util.BitSet valuesSeen = new java.util.BitSet(size); - final Iterator iter = map.iterator(); - while (iter.hasNext()) { - final BytesToBytesMap.Location loc = iter.next(); - Assert.assertTrue(loc.isDefined()); - final MemoryLocation keyAddress = loc.getKeyAddress(); - final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = PlatformDependent.UNSAFE.getLong( - valueAddress.getBaseObject(), valueAddress.getBaseOffset()); - final long keyLength = loc.getKeyLength(); - if (keyLength == 0) { - Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); - } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); - Assert.assertEquals(value, key); - } - valuesSeen.set((int) value); - } - Assert.assertEquals(size, valuesSeen.cardinality()); - } finally { - map.free(); - } - } - - @Test - public void iteratingOverDataPagesWithWastedSpace() throws Exception { - final int NUM_ENTRIES = 1000 * 1000; - final int KEY_LENGTH = 16; - final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES); - // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte - // pages won't be evenly-divisible by records of this size, which will cause us to waste some - // space at the end of the page. This is necessary in order for us to take the end-of-record - // handling branch in iterator(). 
- try { - for (int i = 0; i < NUM_ENTRIES; i++) { - final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes - final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes - final BytesToBytesMap.Location loc = map.lookup( - key, - LONG_ARRAY_OFFSET, - KEY_LENGTH - ); - Assert.assertFalse(loc.isDefined()); - loc.putNewKey( - key, - LONG_ARRAY_OFFSET, - KEY_LENGTH, - value, - LONG_ARRAY_OFFSET, - VALUE_LENGTH - ); - } - Assert.assertEquals(2, map.getNumDataPages()); - - final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); - final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; - while (iter.hasNext()) { - final BytesToBytesMap.Location loc = iter.next(); - Assert.assertTrue(loc.isDefined()); - Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); - Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); - PlatformDependent.copyMemory( - loc.getKeyAddress().getBaseObject(), - loc.getKeyAddress().getBaseOffset(), - key, - LONG_ARRAY_OFFSET, - KEY_LENGTH - ); - PlatformDependent.copyMemory( - loc.getValueAddress().getBaseObject(), - loc.getValueAddress().getBaseOffset(), - value, - LONG_ARRAY_OFFSET, - VALUE_LENGTH - ); - for (long j : key) { - Assert.assertEquals(key[0], j); - } - for (long j : value) { - Assert.assertEquals(key[0], j); - } - valuesSeen.set((int) key[0]); - } - Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality()); - } finally { - map.free(); - } - } - - @Test - public void randomizedStressTest() { - final int size = 65536; - // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays - // into ByteBuffers in order to use them as keys here. - final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); - - try { - // Fill the map to 90% full so that we can trigger probing - for (int i = 0; i < size * 0.9; i++) { - final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); - final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); - if (!expected.containsKey(ByteBuffer.wrap(key))) { - expected.put(ByteBuffer.wrap(key), value); - final BytesToBytesMap.Location loc = map.lookup( - key, - BYTE_ARRAY_OFFSET, - key.length - ); - Assert.assertFalse(loc.isDefined()); - loc.putNewKey( - key, - BYTE_ARRAY_OFFSET, - key.length, - value, - BYTE_ARRAY_OFFSET, - value.length - ); - // After calling putNewKey, the following should be true, even before calling - // lookup(): - Assert.assertTrue(loc.isDefined()); - Assert.assertEquals(key.length, loc.getKeyLength()); - Assert.assertEquals(value.length, loc.getValueLength()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); - } - } - - for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); - final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); - Assert.assertTrue(loc.isDefined()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); - } - } finally { - map.free(); - } - } - - @Test - public void initialCapacityBoundsChecking() { - try { - new BytesToBytesMap(sizeLimitedMemoryManager, 0); - Assert.fail("Expected IllegalArgumentException to be thrown"); - } catch (IllegalArgumentException e) { - // expected exception - } - - 
try { - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1); - Assert.fail("Expected IllegalArgumentException to be thrown"); - } catch (IllegalArgumentException e) { - // expected exception - } - - // Can allocate _at_ the max capacity - BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY); - map.free(); - } - - @Test - public void resizingLargeMap() { - // As long as a map's capacity is below the max, we should be able to resize up to the max - BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64); - map.growAndRehash(); - map.free(); - } -} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java new file mode 100644 index 000000000000..80d4982c4b57 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -0,0 +1,240 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.unsafe.types; + +import org.junit.Test; + +import static junit.framework.Assert.*; +import static org.apache.spark.unsafe.types.CalendarInterval.*; + +public class CalendarIntervalSuite { + + @Test + public void equalsTest() { + CalendarInterval i1 = new CalendarInterval(3, 123); + CalendarInterval i2 = new CalendarInterval(3, 321); + CalendarInterval i3 = new CalendarInterval(1, 123); + CalendarInterval i4 = new CalendarInterval(3, 123); + + assertNotSame(i1, i2); + assertNotSame(i1, i3); + assertNotSame(i2, i3); + assertEquals(i1, i4); + } + + @Test + public void toStringTest() { + CalendarInterval i; + + i = new CalendarInterval(34, 0); + assertEquals(i.toString(), "interval 2 years 10 months"); + + i = new CalendarInterval(-34, 0); + assertEquals(i.toString(), "interval -2 years -10 months"); + + i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); + + i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); + + i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); + } + + @Test + public void fromStringTest() { + testSingleUnit("year", 3, 36, 0); + testSingleUnit("month", 3, 3, 0); + testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK); + testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY); + testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR); + testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE); + testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND); + testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI); + testSingleUnit("microsecond", 3, 0, 3); + + String input; + + input = "interval -5 years 23 month"; + CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); + assertEquals(CalendarInterval.fromString(input), result); + + input = "interval -5 years 23 month "; + assertEquals(CalendarInterval.fromString(input), result); + + input = " interval -5 years 23 month "; + assertEquals(CalendarInterval.fromString(input), result); + + // Error cases + input = "interval 3month 1 hour"; + assertEquals(CalendarInterval.fromString(input), null); + + input = "interval 3 moth 1 hour"; + assertEquals(CalendarInterval.fromString(input), null); + + input = "interval"; + assertEquals(CalendarInterval.fromString(input), null); + + input = "int"; + assertEquals(CalendarInterval.fromString(input), null); + + input = ""; + assertEquals(CalendarInterval.fromString(input), null); + + input = null; + assertEquals(CalendarInterval.fromString(input), null); + } + + @Test + public void fromYearMonthStringTest() { + String input; + CalendarInterval i; + + input = "99-10"; + i = new CalendarInterval(99 * 12 + 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + input = "-8-10"; + i = new CalendarInterval(-8 * 12 - 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + try { + input = "99-15"; + CalendarInterval.fromYearMonthString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("month 15 outside range")); + } + } + + @Test + public void fromDayTimeStringTest() { + String input; + CalendarInterval i; + + input = "5 12:40:30.999999999"; + i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + + 
40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "10 0:12:0.888"; + i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "-3 0:0:0"; + i = new CalendarInterval(0, -3 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + try { + input = "5 30:12:20"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("hour 30 outside range")); + } + + try { + input = "5 30-12"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("not match day-time format")); + } + } + + @Test + public void fromSingleUnitStringTest() { + String input; + CalendarInterval i; + + input = "12"; + i = new CalendarInterval(12 * 12, 0L); + assertEquals(CalendarInterval.fromSingleUnitString("year", input), i); + + input = "100"; + i = new CalendarInterval(0, 100 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromSingleUnitString("day", input), i); + + input = "1999.38888"; + i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); + assertEquals(CalendarInterval.fromSingleUnitString("second", input), i); + + try { + input = String.valueOf(Integer.MAX_VALUE); + CalendarInterval.fromSingleUnitString("year", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + + try { + input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); + CalendarInterval.fromSingleUnitString("hour", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + } + + @Test + public void addTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); + } + + @Test + public void subtractTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); + + assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); + } + + private void testSingleUnit(String unit, int number, int months, long microseconds) { + String input1 = "interval " + number + " " + unit; + String input2 = "interval " + number + " " + unit + "s"; + CalendarInterval 
result = new CalendarInterval(months, microseconds); + assertEquals(CalendarInterval.fromString(input1), result); + assertEquals(CalendarInterval.fromString(input2), result); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 80c179a1b5e7..98aa8a2469a7 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -18,76 +18,475 @@ package org.apache.spark.unsafe.types; import java.io.UnsupportedEncodingException; +import java.util.Arrays; +import java.util.HashMap; -import junit.framework.Assert; +import com.google.common.collect.ImmutableMap; import org.junit.Test; +import static junit.framework.Assert.*; + +import static org.apache.spark.unsafe.types.UTF8String.*; + public class UTF8StringSuite { private void checkBasic(String str, int len) throws UnsupportedEncodingException { - Assert.assertEquals(UTF8String.fromString(str).length(), len); - Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len); + UTF8String s1 = fromString(str); + UTF8String s2 = fromBytes(str.getBytes("utf8")); + assertEquals(s1.numChars(), len); + assertEquals(s2.numChars(), len); + + assertEquals(s1.toString(), str); + assertEquals(s2.toString(), str); + assertEquals(s1, s2); - Assert.assertEquals(UTF8String.fromString(str), str); - Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), str); - Assert.assertEquals(UTF8String.fromString(str).toString(), str); - Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str); - Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str)); + assertEquals(s1.hashCode(), s2.hashCode()); - Assert.assertEquals(UTF8String.fromString(str).hashCode(), - UTF8String.fromBytes(str.getBytes("utf8")).hashCode()); + assertEquals(s1.compareTo(s2), 0); + + assertEquals(s1.contains(s2), true); + assertEquals(s2.contains(s1), true); + assertEquals(s1.startsWith(s1), true); + assertEquals(s1.endsWith(s1), true); } @Test public void basicTest() throws UnsupportedEncodingException { + checkBasic("", 0); checkBasic("hello", 5); - checkBasic("世 界", 3); + checkBasic("大 千 世 界", 7); + } + + @Test + public void emptyStringTest() { + assertEquals(fromString(""), EMPTY_UTF8); + assertEquals(fromBytes(new byte[0]), EMPTY_UTF8); + assertEquals(0, EMPTY_UTF8.numChars()); + assertEquals(0, EMPTY_UTF8.numBytes()); + } + + @Test + public void prefix() { + assertTrue(fromString("a").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue(fromString("ab").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue( + fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); + assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); + assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + + byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + byte[] buf2 = {1, 2, 3}; + UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); + UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8); + UTF8String str3 = UTF8String.fromBytes(buf2); + assertTrue(str1.getPrefix() - str2.getPrefix() < 0); + assertEquals(str1.getPrefix(), str3.getPrefix()); + } + + @Test + public void compareTo() { + assertTrue(fromString("").compareTo(fromString("a")) < 0); + assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); + 
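A brief aside on CalendarIntervalSuite above: a CalendarInterval carries just two fields, months (years are folded into months) and microseconds (weeks, days, hours, minutes, seconds and fractions are folded into microseconds), which is why toString renders 34 months as "2 years 10 months". A minimal Scala usage sketch, assuming the spark-unsafe artifact is on the classpath:

```
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.CalendarInterval._

object CalendarIntervalSketch {
  def main(args: Array[String]): Unit = {
    val i = CalendarInterval.fromString("interval 2 years 10 months 3 weeks 13 hours")
    assert(i.months == 2 * 12 + 10)                                  // years fold into months
    assert(i.microseconds == 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR)

    // fromString returns null rather than throwing when the input cannot be parsed
    assert(CalendarInterval.fromString("not an interval") == null)
  }
}
```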
assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); + assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); + assertTrue(fromString("aBcabcabc").compareTo(fromString("Abcabcabc")) > 0); + assertTrue(fromString("Abcabcabc").compareTo(fromString("abcabcabC")) < 0); + assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabC")) > 0); + + assertTrue(fromString("abc").compareTo(fromString("世界")) < 0); + assertTrue(fromString("你好").compareTo(fromString("世界")) > 0); + assertTrue(fromString("你好123").compareTo(fromString("你好122")) > 0); + } + + protected void testUpperandLower(String upper, String lower) { + UTF8String us = fromString(upper); + UTF8String ls = fromString(lower); + assertEquals(ls, us.toLowerCase()); + assertEquals(us, ls.toUpperCase()); + assertEquals(us, us.toUpperCase()); + assertEquals(ls, ls.toLowerCase()); + } + + @Test + public void upperAndLower() { + testUpperandLower("", ""); + testUpperandLower("0123456", "0123456"); + testUpperandLower("ABCXYZ", "abcxyz"); + testUpperandLower("ЀЁЂѺΏỀ", "ѐёђѻώề"); + testUpperandLower("大千世界 数据砖头", "大千世界 数据砖头"); + } + + @Test + public void titleCase() { + assertEquals(fromString(""), fromString("").toTitleCase()); + assertEquals(fromString("Ab Bc Cd"), fromString("ab bc cd").toTitleCase()); + assertEquals(fromString("Ѐ Ё Ђ Ѻ Ώ Ề"), fromString("ѐ ё ђ ѻ ώ ề").toTitleCase()); + assertEquals(fromString("大千世界 数据砖头"), fromString("大千世界 数据砖头").toTitleCase()); + } + + @Test + public void concatTest() { + assertEquals(EMPTY_UTF8, concat()); + assertEquals(null, concat((UTF8String) null)); + assertEquals(EMPTY_UTF8, concat(EMPTY_UTF8)); + assertEquals(fromString("ab"), concat(fromString("ab"))); + assertEquals(fromString("ab"), concat(fromString("a"), fromString("b"))); + assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c"))); + assertEquals(null, concat(fromString("a"), null, fromString("c"))); + assertEquals(null, concat(fromString("a"), null, null)); + assertEquals(null, concat(null, null, null)); + assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头"))); + } + + @Test + public void concatWsTest() { + // Returns null if the separator is null + assertEquals(null, concatWs(null, (UTF8String)null)); + assertEquals(null, concatWs(null, fromString("a"))); + + // If separator is null, concatWs should skip all null inputs and never return null. 
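(A note on the comment above: the intent is the non-null separator case. concat is null-propagating, while concatWs only returns null when the separator itself is null and otherwise skips null inputs.) A short sketch of that contract, again assuming spark-unsafe on the classpath:

```
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.UTF8String.fromString

object ConcatSketch {
  def main(args: Array[String]): Unit = {
    // concat: any null input makes the whole result null
    assert(UTF8String.concat(fromString("a"), null, fromString("c")) == null)

    // concatWs with a non-null separator: null inputs are skipped, never propagated
    val sep = fromString(",")
    assert(UTF8String.concatWs(sep, fromString("a"), null, fromString("c")) == fromString("a,c"))

    // concatWs with a null separator: the result is null regardless of the inputs
    assert(UTF8String.concatWs(null, fromString("a"), fromString("b")) == null)
  }
}
```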
+ UTF8String sep = fromString("哈哈"); + assertEquals( + EMPTY_UTF8, + concatWs(sep, EMPTY_UTF8)); + assertEquals( + fromString("ab"), + concatWs(sep, fromString("ab"))); + assertEquals( + fromString("a哈哈b"), + concatWs(sep, fromString("a"), fromString("b"))); + assertEquals( + fromString("a哈哈b哈哈c"), + concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); + assertEquals( + fromString("a哈哈c"), + concatWs(sep, fromString("a"), null, fromString("c"))); + assertEquals( + fromString("a"), + concatWs(sep, fromString("a"), null, null)); + assertEquals( + EMPTY_UTF8, + concatWs(sep, null, null, null)); + assertEquals( + fromString("数据哈哈砖头"), + concatWs(sep, fromString("数据"), fromString("砖头"))); } @Test public void contains() { - Assert.assertTrue(UTF8String.fromString("hello").contains(UTF8String.fromString("ello"))); - Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("vello"))); - Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("hellooo"))); - Assert.assertTrue(UTF8String.fromString("大千世界").contains(UTF8String.fromString("千世"))); - Assert.assertFalse(UTF8String.fromString("大千世界").contains(UTF8String.fromString("世千"))); - Assert.assertFalse( - UTF8String.fromString("大千世界").contains(UTF8String.fromString("大千世界好"))); + assertTrue(EMPTY_UTF8.contains(EMPTY_UTF8)); + assertTrue(fromString("hello").contains(fromString("ello"))); + assertFalse(fromString("hello").contains(fromString("vello"))); + assertFalse(fromString("hello").contains(fromString("hellooo"))); + assertTrue(fromString("大千世界").contains(fromString("千世界"))); + assertFalse(fromString("大千世界").contains(fromString("世千"))); + assertFalse(fromString("大千世界").contains(fromString("大千世界好"))); } @Test public void startsWith() { - Assert.assertTrue(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hell"))); - Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("ell"))); - Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hellooo"))); - Assert.assertTrue(UTF8String.fromString("数据砖头").startsWith(UTF8String.fromString("数据"))); - Assert.assertFalse(UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("千"))); - Assert.assertFalse( - UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("大千世界好"))); + assertTrue(EMPTY_UTF8.startsWith(EMPTY_UTF8)); + assertTrue(fromString("hello").startsWith(fromString("hell"))); + assertFalse(fromString("hello").startsWith(fromString("ell"))); + assertFalse(fromString("hello").startsWith(fromString("hellooo"))); + assertTrue(fromString("数据砖头").startsWith(fromString("数据"))); + assertFalse(fromString("大千世界").startsWith(fromString("千"))); + assertFalse(fromString("大千世界").startsWith(fromString("大千世界好"))); } @Test public void endsWith() { - Assert.assertTrue(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ello"))); - Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ellov"))); - Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("hhhello"))); - Assert.assertTrue(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世界"))); - Assert.assertFalse(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世"))); - Assert.assertFalse( - UTF8String.fromString("数据砖头").endsWith(UTF8String.fromString("我的数据砖头"))); + assertTrue(EMPTY_UTF8.endsWith(EMPTY_UTF8)); + assertTrue(fromString("hello").endsWith(fromString("ello"))); + 
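Worth remembering while reading these assertions: a UTF8String is backed by raw UTF-8 bytes, so for non-ASCII text the character count and the byte count differ. A tiny sketch, assuming spark-unsafe on the classpath:

```
import org.apache.spark.unsafe.types.UTF8String.fromString

object LengthSketch {
  def main(args: Array[String]): Unit = {
    val cjk = fromString("大千世界")       // 4 characters, 3 bytes each in UTF-8
    assert(cjk.numChars() == 4)
    assert(cjk.numBytes() == 12)

    val ascii = fromString("hello")        // ASCII: one byte per character
    assert(ascii.numChars() == 5)
    assert(ascii.numBytes() == 5)
  }
}
```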
assertFalse(fromString("hello").endsWith(fromString("ellov"))); + assertFalse(fromString("hello").endsWith(fromString("hhhello"))); + assertTrue(fromString("大千世界").endsWith(fromString("世界"))); + assertFalse(fromString("大千世界").endsWith(fromString("世"))); + assertFalse(fromString("数据砖头").endsWith(fromString("我的数据砖头"))); } @Test public void substring() { - Assert.assertEquals( - UTF8String.fromString("hello").substring(0, 0), UTF8String.fromString("")); - Assert.assertEquals( - UTF8String.fromString("hello").substring(1, 3), UTF8String.fromString("el")); - Assert.assertEquals( - UTF8String.fromString("数据砖头").substring(0, 1), UTF8String.fromString("数")); - Assert.assertEquals( - UTF8String.fromString("数据砖头").substring(1, 3), UTF8String.fromString("据砖")); - Assert.assertEquals( - UTF8String.fromString("数据砖头").substring(3, 5), UTF8String.fromString("头")); + assertEquals(EMPTY_UTF8, fromString("hello").substring(0, 0)); + assertEquals(fromString("el"), fromString("hello").substring(1, 3)); + assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); + assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + } + + @Test + public void trims() { + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + + assertEquals(EMPTY_UTF8, fromString(" ").trim()); + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); + assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + } + + @Test + public void indexOf() { + assertEquals(0, EMPTY_UTF8.indexOf(EMPTY_UTF8, 0)); + assertEquals(-1, EMPTY_UTF8.indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(EMPTY_UTF8, 0)); + assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0)); + assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4)); + assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0)); + assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); + } + + @Test + public void substring_index() { + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), 3)); + assertEquals(fromString("www.apache"), + fromString("www.apache.org").subStringIndex(fromString("."), 2)); + assertEquals(fromString("www"), + fromString("www.apache.org").subStringIndex(fromString("."), 1)); + assertEquals(fromString(""), + fromString("www.apache.org").subStringIndex(fromString("."), 0)); + assertEquals(fromString("org"), + fromString("www.apache.org").subStringIndex(fromString("."), -1)); + assertEquals(fromString("apache.org"), + 
fromString("www.apache.org").subStringIndex(fromString("."), -2)); + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("."), -3)); + // str is empty string + assertEquals(fromString(""), + fromString("").subStringIndex(fromString("."), 1)); + // empty string delim + assertEquals(fromString(""), + fromString("www.apache.org").subStringIndex(fromString(""), 1)); + // delim does not exist in str + assertEquals(fromString("www.apache.org"), + fromString("www.apache.org").subStringIndex(fromString("#"), 2)); + // delim is 2 chars + assertEquals(fromString("www||apache"), + fromString("www||apache||org").subStringIndex(fromString("||"), 2)); + assertEquals(fromString("apache||org"), + fromString("www||apache||org").subStringIndex(fromString("||"), -2)); + // non ascii chars + assertEquals(fromString("大千世界大"), + fromString("大千世界大千世界").subStringIndex(fromString("千"), 2)); + // overlapped delim + assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3)); + assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4)); + } + + @Test + public void reverse() { + assertEquals(fromString("olleh"), fromString("hello").reverse()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse()); + assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); + assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + } + + @Test + public void repeat() { + assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); + assertEquals(fromString("数d"), fromString("数d").repeat(1)); + assertEquals(EMPTY_UTF8, fromString("数d").repeat(-1)); + } + + @Test + public void pad() { + assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????"))); + assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); + assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); + assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.lpad(7, fromString("?????"))); + + assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); + assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); + assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); + assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); + assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????"))); + assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); + assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); + assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者"))); + assertEquals( + fromString("孙行者孙行者孙行数据砖头"), + fromString("数据砖头").lpad(12, fromString("孙行者"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); + assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); + assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); + 
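All of the pad assertions here reduce to one rule: if the target length is at most the input's character count, the input is truncated to that length; otherwise the pad string is repeated, plus a partial copy, until exactly the target length is reached (on the left for lpad, on the right for rpad), and an empty pad can only truncate or leave the input unchanged. A plain-String Scala sketch of that rule, mirroring the reference helper used later in UTF8StringPropertyCheckSuite:

```
object PadSketch {
  // reference semantics only: operates on java.lang.String characters, not UTF-8 bytes
  def pad(s: String, p: String, len: Int, left: Boolean): String = {
    if (len <= 0) ""
    else if (len <= s.length) s.substring(0, len)     // truncate
    else if (p.isEmpty) s                             // nothing to pad with
    else {
      val toPad = len - s.length
      val whole = p * (toPad / p.length)
      val part = p.substring(0, toPad % p.length)
      if (left) whole + part + s else s + whole + part
    }
  }

  def main(args: Array[String]): Unit = {
    assert(pad("hello", "????", 3, left = true) == "hel")
    assert(pad("hello", "????", 12, left = true) == "???????hello")
    assert(pad("hello", "????", 12, left = false) == "hello???????")
  }
}
```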
assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); + assertEquals( + fromString("数据砖头孙行者孙行者孙行"), + fromString("数据砖头").rpad(12, fromString("孙行者"))); + + assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, fromString("孙行者"))); + assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, EMPTY_UTF8)); + assertEquals(fromString("数据砖头"), fromString("数据砖头").lpad(5, EMPTY_UTF8)); + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, EMPTY_UTF8)); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.lpad(3, EMPTY_UTF8)); + + assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, fromString("孙行者"))); + assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, EMPTY_UTF8)); + assertEquals(fromString("数据砖头"), fromString("数据砖头").rpad(5, EMPTY_UTF8)); + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, EMPTY_UTF8)); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.rpad(3, EMPTY_UTF8)); + } + + @Test + public void substringSQL() { + UTF8String e = fromString("example"); + assertEquals(e.substringSQL(0, 2), fromString("ex")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 7), fromString("example")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 100), fromString("example")); + assertEquals(e.substringSQL(1, 100), fromString("example")); + assertEquals(e.substringSQL(2, 2), fromString("xa")); + assertEquals(e.substringSQL(1, 6), fromString("exampl")); + assertEquals(e.substringSQL(2, 100), fromString("xample")); + assertEquals(e.substringSQL(0, 0), fromString("")); + assertEquals(e.substringSQL(100, 4), EMPTY_UTF8); + assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); + } + + @Test + public void split() { + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), + new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + } + + @Test + public void levenshteinDistance() { + assertEquals(EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8), 0); + assertEquals(EMPTY_UTF8.levenshteinDistance(fromString("a")), 1); + assertEquals(fromString("aaapppp").levenshteinDistance(EMPTY_UTF8), 7); + assertEquals(fromString("frog").levenshteinDistance(fromString("fog")), 1); + assertEquals(fromString("fly").levenshteinDistance(fromString("ant")),3); + assertEquals(fromString("elephant").levenshteinDistance(fromString("hippo")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("elephant")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("zzzzzzzz")), 8); + assertEquals(fromString("hello").levenshteinDistance(fromString("hallo")),1); + assertEquals(fromString("世界千世").levenshteinDistance(fromString("千a世b")),4); + } + + @Test + public void translate() { + assertEquals( + fromString("1a2s3ae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '1', + 'n', '2', + 'l', '3', + 't', '\0' + ))); + assertEquals( + fromString("translate"), + fromString("translate").translate(new HashMap())); + assertEquals( + 
fromString("asae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '\0', + 'n', '\0', + 'l', '\0', + 't', '\0' + ))); + assertEquals( + fromString("aa世b"), + fromString("花花世界").translate(ImmutableMap.of( + '花', 'a', + '界', 'b' + ))); + } + + @Test + public void createBlankString() { + assertEquals(fromString(" "), blankString(1)); + assertEquals(fromString(" "), blankString(2)); + assertEquals(fromString(" "), blankString(3)); + assertEquals(fromString(""), blankString(0)); + } + + @Test + public void findInSet() { + assertEquals(fromString("ab").findInSet(fromString("ab")), 1); + assertEquals(fromString("a,b").findInSet(fromString("b")), 2); + assertEquals(fromString("abc,b,ab,c,def").findInSet(fromString("ab")), 3); + assertEquals(fromString("ab,abc,b,ab,c,def").findInSet(fromString("ab")), 1); + assertEquals(fromString(",,,ab,abc,b,ab,c,def").findInSet(fromString("ab")), 4); + assertEquals(fromString(",ab,abc,b,ab,c,def").findInSet(fromString("")), 1); + assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("ab")), 4); + assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("def")), 6); + } + + @Test + public void soundex() { + assertEquals(fromString("Robert").soundex(), fromString("R163")); + assertEquals(fromString("Rupert").soundex(), fromString("R163")); + assertEquals(fromString("Rubin").soundex(), fromString("R150")); + assertEquals(fromString("Ashcraft").soundex(), fromString("A261")); + assertEquals(fromString("Ashcroft").soundex(), fromString("A261")); + assertEquals(fromString("Burroughs").soundex(), fromString("B620")); + assertEquals(fromString("Burrows").soundex(), fromString("B620")); + assertEquals(fromString("Ekzampul").soundex(), fromString("E251")); + assertEquals(fromString("Example").soundex(), fromString("E251")); + assertEquals(fromString("Ellery").soundex(), fromString("E460")); + assertEquals(fromString("Euler").soundex(), fromString("E460")); + assertEquals(fromString("Ghosh").soundex(), fromString("G200")); + assertEquals(fromString("Gauss").soundex(), fromString("G200")); + assertEquals(fromString("Gutierrez").soundex(), fromString("G362")); + assertEquals(fromString("Heilbronn").soundex(), fromString("H416")); + assertEquals(fromString("Hilbert").soundex(), fromString("H416")); + assertEquals(fromString("Jackson").soundex(), fromString("J250")); + assertEquals(fromString("Kant").soundex(), fromString("K530")); + assertEquals(fromString("Knuth").soundex(), fromString("K530")); + assertEquals(fromString("Lee").soundex(), fromString("L000")); + assertEquals(fromString("Lukasiewicz").soundex(), fromString("L222")); + assertEquals(fromString("Lissajous").soundex(), fromString("L222")); + assertEquals(fromString("Ladd").soundex(), fromString("L300")); + assertEquals(fromString("Lloyd").soundex(), fromString("L300")); + assertEquals(fromString("Moses").soundex(), fromString("M220")); + assertEquals(fromString("O'Hara").soundex(), fromString("O600")); + assertEquals(fromString("Pfister").soundex(), fromString("P236")); + assertEquals(fromString("Rubin").soundex(), fromString("R150")); + assertEquals(fromString("Robert").soundex(), fromString("R163")); + assertEquals(fromString("Rupert").soundex(), fromString("R163")); + assertEquals(fromString("Soundex").soundex(), fromString("S532")); + assertEquals(fromString("Sownteks").soundex(), fromString("S532")); + assertEquals(fromString("Tymczak").soundex(), fromString("T522")); + assertEquals(fromString("VanDeusen").soundex(), fromString("V532")); + 
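The findInSet cases above follow the usual FIND_IN_SET contract: the receiver is a comma-separated list of fields, the result is the 1-based index of the first field equal to the argument, and 0 when nothing matches. A reference sketch over plain Strings (the comma-in-argument rule is an assumption borrowed from SQL's FIND_IN_SET; that case is not exercised above):

```
object FindInSetSketch {
  // reference semantics only; the real implementation works directly on UTF-8 bytes
  def findInSet(set: String, word: String): Int = {
    if (word.contains(",")) return 0            // assumption: mirrors SQL FIND_IN_SET
    val idx = set.split(",", -1).indexOf(word)  // limit -1 keeps trailing empty fields
    if (idx >= 0) idx + 1 else 0
  }

  def main(args: Array[String]): Unit = {
    assert(findInSet("abc,b,ab,c,def", "ab") == 3)
    assert(findInSet(",ab,abc,b,ab,c,def", "") == 1)   // empty string matches an empty field
    assert(findInSet("数据砖头,abc,b,ab,c,def", "def") == 6)
  }
}
```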
assertEquals(fromString("Washington").soundex(), fromString("W252")); + assertEquals(fromString("Wheaton").soundex(), fromString("W350")); + + assertEquals(fromString("a").soundex(), fromString("A000")); + assertEquals(fromString("ab").soundex(), fromString("A100")); + assertEquals(fromString("abc").soundex(), fromString("A120")); + assertEquals(fromString("abcd").soundex(), fromString("A123")); + assertEquals(fromString("").soundex(), fromString("")); + assertEquals(fromString("123").soundex(), fromString("123")); + assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } } diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala new file mode 100644 index 000000000000..12a002befa0a --- /dev/null +++ b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types + +import org.apache.commons.lang3.StringUtils + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.prop.GeneratorDrivenPropertyChecks +// scalastyle:off +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8} + +/** + * This TestSuite utilize ScalaCheck to generate randomized inputs for UTF8String testing. 
+ */ +class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenPropertyChecks with Matchers { +// scalastyle:on + + test("toString") { + forAll { (s: String) => + assert(toUTF8(s).toString() === s) + } + } + + test("numChars") { + forAll { (s: String) => + assert(toUTF8(s).numChars() === s.length) + } + } + + test("startsWith") { + forAll { (s: String) => + val utf8 = toUTF8(s) + assert(utf8.startsWith(utf8)) + for (i <- 1 to s.length) { + assert(utf8.startsWith(toUTF8(s.dropRight(i)))) + } + } + } + + test("endsWith") { + forAll { (s: String) => + val utf8 = toUTF8(s) + assert(utf8.endsWith(utf8)) + for (i <- 1 to s.length) { + assert(utf8.endsWith(toUTF8(s.drop(i)))) + } + } + } + + test("toUpperCase") { + forAll { (s: String) => + assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase)) + } + } + + test("toLowerCase") { + forAll { (s: String) => + assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase)) + } + } + + test("compare") { + forAll { (s1: String, s2: String) => + assert(Math.signum(toUTF8(s1).compareTo(toUTF8(s2))) === Math.signum(s1.compareTo(s2))) + } + } + + test("substring") { + forAll { (s: String) => + for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) { + assert(toUTF8(s).substring(start, end).toString === s.substring(start, end)) + } + } + } + + test("contains") { + forAll { (s: String) => + for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) { + val substring = s.substring(start, end) + assert(toUTF8(s).contains(toUTF8(substring)) === s.contains(substring)) + } + } + } + + val whitespaceChar: Gen[Char] = Gen.choose(0x00, 0x20).map(_.toChar) + val whitespaceString: Gen[String] = Gen.listOf(whitespaceChar).map(_.mkString) + val randomString: Gen[String] = Arbitrary.arbString.arbitrary + + test("trim, trimLeft, trimRight") { + // lTrim and rTrim are both modified from java.lang.String.trim + def lTrim(s: String): String = { + var st = 0 + val array: Array[Char] = s.toCharArray + while ((st < s.length) && (array(st) <= ' ')) { + st += 1 + } + if (st > 0) s.substring(st, s.length) else s + } + def rTrim(s: String): String = { + var len = s.length + val array: Array[Char] = s.toCharArray + while ((len > 0) && (array(len - 1) <= ' ')) { + len -= 1 + } + if (len < s.length) s.substring(0, len) else s + } + + forAll( + whitespaceString, + randomString, + whitespaceString + ) { (start: String, middle: String, end: String) => + val s = start + middle + end + assert(toUTF8(s).trim() === toUTF8(s.trim())) + assert(toUTF8(s).trimLeft() === toUTF8(lTrim(s))) + assert(toUTF8(s).trimRight() === toUTF8(rTrim(s))) + } + } + + test("reverse") { + forAll { (s: String) => + assert(toUTF8(s).reverse === toUTF8(s.reverse)) + } + } + + test("indexOf") { + forAll { (s: String) => + for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) { + val substring = s.substring(start, end) + assert(toUTF8(s).indexOf(toUTF8(substring), 0) === s.indexOf(substring)) + } + } + } + + val randomInt = Gen.choose(-100, 100) + + test("repeat") { + def repeat(str: String, times: Int): String = { + if (times > 0) str * times else "" + } + // ScalaCheck always generating too large repeat times which might hang the test forever. 
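A note on the lTrim/rTrim oracles above: like java.lang.String.trim, they strip every character whose code is at or below U+0020 (including controls such as NUL), which is a different notion of whitespace than Character.isWhitespace and matches the whitespaceChar generator (Gen.choose(0x00, 0x20)). A quick standalone illustration:

```
object TrimSketch {
  def main(args: Array[String]): Unit = {
    // NUL and BEL sit below U+0020, so String.trim removes them along with spaces and tabs
    val s = "\u0000\u0007\t hello \n\u0001"
    assert(s.trim == "hello")
  }
}
```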
+ forAll(randomString, randomInt) { (s: String, times: Int) => + assert(toUTF8(s).repeat(times) === toUTF8(repeat(s, times))) + } + } + + test("lpad, rpad") { + def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = { + if (length <= 0) return "" + if (length <= origin.length) { + if (length <= 0) "" else origin.substring(0, length) + } else { + if (pad.length == 0) return origin + val toPad = length - origin.length + val partPad = if (toPad % pad.length == 0) "" else pad.substring(0, toPad % pad.length) + if (isLPad) { + pad * (toPad / pad.length) + partPad + origin + } else { + origin + pad * (toPad / pad.length) + partPad + } + } + } + + forAll ( + randomString, + randomString, + randomInt + ) { (s: String, pad: String, length: Int) => + assert(toUTF8(s).lpad(length, toUTF8(pad)) === + toUTF8(padding(s, pad, length, true))) + assert(toUTF8(s).rpad(length, toUTF8(pad)) === + toUTF8(padding(s, pad, length, false))) + } + } + + val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString)) + + test("concat") { + def concat(orgin: Seq[String]): String = + if (orgin.exists(_ == null)) null else orgin.mkString + + forAll { (inputs: Seq[String]) => + assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString)) + } + forAll (nullalbeSeq) { (inputs: Seq[String]) => + assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(concat(inputs))) + } + } + + test("concatWs") { + def concatWs(sep: String, inputs: Seq[String]): String = { + if (sep == null) return null + inputs.filter(_ != null).mkString(sep) + } + + forAll { (sep: String, inputs: Seq[String]) => + assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) === + toUTF8(inputs.mkString(sep))) + } + forAll(randomString, nullalbeSeq) {(sep: String, inputs: Seq[String]) => + assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) === + toUTF8(concatWs(sep, inputs))) + } + } + + // TODO: enable this when we find a proper way to generate valid patterns + ignore("split") { + forAll { (s: String, pattern: String, limit: Int) => + assert(toUTF8(s).split(toUTF8(pattern), limit) === + s.split(pattern, limit).map(toUTF8(_))) + } + } + + test("levenshteinDistance") { + forAll { (one: String, another: String) => + assert(toUTF8(one).levenshteinDistance(toUTF8(another)) === + StringUtils.getLevenshteinDistance(one, another)) + } + } + + test("hashCode") { + forAll { (s: String) => + assert(toUTF8(s).hashCode() === toUTF8(s).hashCode()) + } + } + + test("equals") { + forAll { (one: String, another: String) => + assert(toUTF8(one).equals(toUTF8(another)) === one.equals(another)) + } + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 644def7501dc..d8e4a4bbead8 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -30,7 +30,6 @@ Spark Project YARN yarn - 1.9 @@ -39,6 +38,12 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-network-yarn_${scala.binary.version} + ${project.version} + test + org.apache.spark spark-core_${scala.binary.version} @@ -93,12 +98,28 @@ jetty-servlet - + + + + org.eclipse.jetty.orbit + javax.servlet.jsp + 2.2.0.v201112011158 + test + + + org.eclipse.jetty.orbit + javax.servlet.jsp.jstl + 1.2.0.v201105211821 + test + + - + org.apache.hadoop hadoop-yarn-server-tests @@ -107,7 +128,7 @@ org.mockito - mockito-all + mockito-core test @@ -125,29 +146,20 @@ com.sun.jersey jersey-core - ${jersey.version} test com.sun.jersey jersey-json - 
${jersey.version} test - - - stax - stax-api - - com.sun.jersey jersey-server - ${jersey.version} test - + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index 77af46c192cc..56e4741b9387 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -65,6 +65,8 @@ private[yarn] class AMDelegationTokenRenewer( sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val freshHadoopConf = + hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -123,7 +125,7 @@ private[yarn] class AMDelegationTokenRenewer( private def cleanupOldFiles(): Unit = { import scala.concurrent.duration._ try { - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( @@ -169,13 +171,13 @@ private[yarn] class AMDelegationTokenRenewer( // Get a copy of the credentials override def run(): Void = { val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst - hadoopUtil.obtainTokensForNamenodes(nns, hadoopConf, tempCreds) + hadoopUtil.obtainTokensForNamenodes(nns, freshHadoopConf, tempCreds) null } }) // Add the temp credentials back to the original ones. UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file // and update the lastCredentialsFileSuffix. @@ -194,7 +196,7 @@ private[yarn] class AMDelegationTokenRenewer( val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, hadoopConf) + credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) logInfo(s"Delegation Tokens written out successfully. 
Renaming file to $tokenPathStr") remoteFs.rename(tempTokenPath, tokenPath) logInfo("Delegation token file rename complete.") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 83dafa4a125d..93621b44c918 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,8 +30,8 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} -import org.apache.spark.SparkException +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv, + SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} @@ -64,7 +64,8 @@ private[spark] class ApplicationMaster( // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + sparkConf.getInt("spark.yarn.max.worker.failures", + math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -111,7 +112,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) // This shutdown hook should run *after* the SparkContext is shut down. - Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => + val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 + ShutdownHookManager.addShutdownHook(priority) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts @@ -198,7 +200,7 @@ private[spark] class ApplicationMaster( final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { synchronized { if (!finished) { - val inShutdown = Utils.inShutdown() + val inShutdown = ShutdownHookManager.inShutdown() logInfo(s"Final app status: $status, exitCode: $code" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) exitCode = code @@ -229,7 +231,11 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM( + _rpcEnv: RpcEnv, + driverRef: RpcEndpointRef, + uiAddress: String, + securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -246,6 +252,7 @@ private[spark] class ApplicationMaster( RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) allocator = client.register(driverUrl, + driverRef, yarnConf, _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), @@ -262,17 +269,20 @@ private[spark] class ApplicationMaster( * * In cluster mode, the AM and the driver belong to same process * so the AMEndpoint need not monitor lifecycle of the driver. + * + * @return A reference to the driver's RPC endpoint. 
*/ private def runAMEndpoint( host: String, port: String, - isClusterMode: Boolean): Unit = { + isClusterMode: Boolean): RpcEndpointRef = { val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) + driverEndpoint } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -290,11 +300,11 @@ private[spark] class ApplicationMaster( "Timed out waiting for SparkContext.") } else { rpcEnv = sc.env.rpcEnv - runAMEndpoint( + val driverRef = runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -302,9 +312,9 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) - waitForSparkDriver() + val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -428,7 +438,7 @@ private[spark] class ApplicationMaster( } } - private def waitForSparkDriver(): Unit = { + private def waitForSparkDriver(): RpcEndpointRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) @@ -485,7 +495,6 @@ private[spark] class ApplicationMaster( */ private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) val classpath = Client.getUserClasspath(sparkConf) val urls = classpath.map { entry => @@ -521,6 +530,10 @@ private[spark] class ApplicationMaster( e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class + case SparkUserAppException(exitCode) => + val msg = s"User application exited with status $exitCode" + logError(msg) + finish(FinalApplicationStatus.FAILED, exitCode, msg) case cause: Throwable => logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, @@ -555,11 +568,12 @@ private[spark] class ApplicationMaster( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal) => + case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) => Option(allocator) match { case Some(a) => allocatorLock.synchronized { - if (a.requestTotalExecutors(requestedTotal)) { + if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, + localityAwareTasks, hostToLocalTaskCount)) { allocatorLock.notifyAll() } } @@ -576,6 +590,13 @@ private[spark] class ApplicationMaster( case None => logWarning("Container allocator is not ready to kill executors yet.") } context.reply(true) + + case GetExecutorLossReason(eid) => + Option(allocator) match { + case Some(a) => a.enqueueGetLossReasonRequest(eid, context) + case None => logWarning(s"Container allocator is not ready to find" + + s" executor loss reasons yet.") + } } 
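With the changes above (and the removal of --num-executors from ApplicationMasterArguments just below), the AM no longer receives the executor count on its command line or via a system property; it reads spark.executor.instances from the SparkConf, and spark-submit's --num-executors is translated into that setting for backwards compatibility in Client.main further down. A small sketch of the knock-on effect on the failure-tolerance default, assuming spark-core on the classpath and ignoring the deprecated spark.yarn.max.worker.failures fallback:

```
import org.apache.spark.SparkConf

object ExecutorCountSketch {
  def main(args: Array[String]): Unit = {
    // illustrative value only
    val conf = new SparkConf(false).set("spark.executor.instances", "4")

    // same default as ApplicationMaster.maxNumExecutorFailures above:
    // twice the requested executors, but never fewer than 3
    val maxFailures = conf.getInt("spark.yarn.max.executor.failures",
      math.max(conf.getInt("spark.executor.instances", 0) * 2, 3))
    assert(maxFailures == 8)
  }
}
```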
override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 68e9f6b4db7f..17d9943c795e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS var propertiesFile: String = null parseArgs(args.toList) @@ -63,10 +62,6 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - numExecutors = value - args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail @@ -85,7 +80,9 @@ class ApplicationMasterArguments(val args: Array[String]) { } if (primaryPyFile != null && primaryRFile != null) { + // scalastyle:off println System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + // scalastyle:on println System.exit(-1) } @@ -93,6 +90,7 @@ class ApplicationMasterArguments(val args: Array[String]) { } def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + // scalastyle:off println if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) } @@ -107,10 +105,11 @@ class ApplicationMasterArguments(val args: Array[String]) { | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + | --properties-file FILE Path to a custom Spark properties file. 
""".stripMargin) + // scalastyle:on println System.exit(exitCode) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index da1ec2a0fe2e..a2c4bc2f5480 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -25,7 +25,7 @@ import java.security.PrivilegedExceptionAction import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} @@ -80,10 +80,12 @@ private[spark] class Client( private val isClusterMode = args.isClusterMode private var loginFromKeytab = false + private var principal: String = null + private var keytab: String = null + private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) - def stop(): Unit = yarnClient.stop() /** @@ -161,6 +163,23 @@ private[spark] class Client( appContext.setQueue(args.amQueue) appContext.setAMContainerSpec(containerContext) appContext.setApplicationType("SPARK") + sparkConf.getOption(CONF_SPARK_YARN_APPLICATION_TAGS) + .map(StringUtils.getTrimmedStringCollection(_)) + .filter(!_.isEmpty()) + .foreach { tagCollection => + try { + // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use + // reflection to set it, printing a warning if a tag was specified but the YARN version + // doesn't support it. + val method = appContext.getClass().getMethod( + "setApplicationTags", classOf[java.util.Set[String]]) + method.invoke(appContext, new java.util.HashSet[String](tagCollection)) + } catch { + case e: NoSuchMethodException => + logWarning(s"Ignoring $CONF_SPARK_YARN_APPLICATION_TAGS because this version of " + + "YARN does not support it") + } + } sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) match { case Some(v) => appContext.setMaxAppAttempts(v) case None => logDebug("spark.yarn.maxAppAttempts is not set. " + @@ -201,12 +220,14 @@ private[spark] class Client( val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + - s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, @@ -264,8 +285,8 @@ private[spark] class Client( // multiple times, YARN will fail to launch containers for the app with an internal // error. 
val distributedUris = new HashSet[String] - obtainTokenForHiveMetastore(hadoopConf, credentials) - obtainTokenForHBase(hadoopConf, credentials) + obtainTokenForHiveMetastore(sparkConf, hadoopConf, credentials) + obtainTokenForHBase(sparkConf, hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -321,8 +342,9 @@ private[spark] class Client( val linkname = targetDir.map(_ + "/").getOrElse("") + destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName()) val destPath = copyFileToRemote(dst, localPath, replication) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache, + destFs, hadoopConf, destPath, localResources, resType, linkname, statCache, appMasterOnly = appMasterOnly) (false, linkname) } else { @@ -338,7 +360,7 @@ private[spark] class Client( if (loginFromKeytab) { logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + " via the YARN Secure Distributed Cache.") - val (_, localizedPath) = distribute(args.keytab, + val (_, localizedPath) = distribute(keytab, destName = Some(sparkConf.get("spark.yarn.keytab")), appMasterOnly = true) require(localizedPath != null, "Keytab file already distributed.") @@ -409,7 +431,7 @@ private[spark] class Client( } // Distribute an archive with Hadoop and Spark configuration for the AM. - val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + val (_, confLocalizedPath) = distribute(createConfArchive().toURI().getPath(), resType = LocalResourceType.ARCHIVE, destName = Some(LOCALIZED_CONF_DIR), appMasterOnly = true) @@ -489,7 +511,7 @@ private[spark] class Client( val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath YarnSparkHadoopUtil.get.obtainTokensForNamenodes( nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) - val t = creds.getAllTokens + val t = creds.getAllTokens.asScala .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .head val newExpiration = t.renew(hadoopConf) @@ -615,7 +637,7 @@ private[spark] class Client( val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) val pySparkArchives = - if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) { + if (sparkConf.getBoolean("spark.yarn.isPython", false)) { findPySparkArchives() } else { Nil @@ -628,8 +650,8 @@ private[spark] class Client( distCacheMgr.setDistArchivesEnv(launchEnv) val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) - amContainer.setLocalResources(localResources) - amContainer.setEnvironment(launchEnv) + amContainer.setLocalResources(localResources.asJava) + amContainer.setEnvironment(launchEnv.asJava) val javaOpts = ListBuffer[String]() @@ -676,7 +698,7 @@ private[spark] class Client( val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") @@ -698,7 +720,7 @@ private[spark] class Client( } sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => - prefixEnv = 
Some(Utils.libraryPathEnvPrefix(Seq(paths))) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -731,9 +753,9 @@ private[spark] class Client( } val amClass = if (isClusterMode) { - Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName + Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { - Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName + Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs @@ -746,7 +768,6 @@ private[spark] class Client( userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -761,10 +782,10 @@ private[spark] class Client( // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList - amContainer.setCommands(printableCommands) + amContainer.setCommands(printableCommands.asJava) logDebug("===============================================================================") - logDebug("Yarn AM launch context:") + logDebug("YARN AM launch context:") logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") logDebug(" env:") launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } @@ -776,7 +797,8 @@ private[spark] class Client( // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) - amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + amContainer.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) setupSecurityToken(amContainer) UserGroupInformation.getCurrentUser().addCredentials(credentials) @@ -784,19 +806,27 @@ private[spark] class Client( } def setupCredentials(): Unit = { - if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when principal is specified.") + loginFromKeytab = args.principal != null || sparkConf.contains("spark.yarn.principal") + if (loginFromKeytab) { + principal = + if (args.principal != null) args.principal else sparkConf.get("spark.yarn.principal") + keytab = { + if (args.keytab != null) { + args.keytab + } else { + sparkConf.getOption("spark.yarn.keytab").orNull + } + } + + require(keytab != null, "Keytab must be specified when principal is specified.") logInfo("Attempting to login to the Kerberos" + - s" using principal: ${args.principal} and keytab: ${args.keytab}") - val f = new File(args.keytab) + s" using principal: $principal and keytab: $keytab") + val f = new File(keytab) // Generate a file name that can be used for the keytab file, that does not conflict // with any user file. 
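setupCredentials above now resolves the principal and keytab from either the command line or the configuration, with the command-line values winning. A condensed sketch of that resolution order, assuming spark-core on the classpath (the principal and keytab values are placeholders):

```
import org.apache.spark.SparkConf

object KerberosConfSketch {
  def resolve(conf: SparkConf, argPrincipal: Option[String], argKeytab: Option[String])
    : Option[(String, String)] = {
    val principal = argPrincipal.orElse(conf.getOption("spark.yarn.principal"))
    val keytab = argKeytab.orElse(conf.getOption("spark.yarn.keytab"))
    principal.map { p =>
      require(keytab.isDefined, "Keytab must be specified when principal is specified.")
      (p, keytab.get)
    }
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.yarn.principal", "alice@EXAMPLE.COM")                 // placeholder
      .set("spark.yarn.keytab", "/etc/security/keytabs/alice.keytab")   // placeholder

    // --principal/--keytab take precedence over the configuration when both are present
    assert(resolve(conf, Some("bob@EXAMPLE.COM"), Some("/tmp/bob.keytab")).get._1 == "bob@EXAMPLE.COM")
    assert(resolve(conf, None, None).get._1 == "alice@EXAMPLE.COM")
  }
}
```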
val keytabFileName = f.getName + "-" + UUID.randomUUID().toString - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - loginFromKeytab = true sparkConf.set("spark.yarn.keytab", keytabFileName) - sparkConf.set("spark.yarn.principal", args.principal) - logInfo("Successfully logged into the KDC.") + sparkConf.set("spark.yarn.principal", principal) } credentials = UserGroupInformation.getCurrentUser.getCredentials } @@ -937,7 +967,7 @@ private[spark] class Client( object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { - println("WARNING: This client is deprecated and will be removed in a " + + logWarning("WARNING: This client is deprecated and will be removed in a " + "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") } @@ -947,6 +977,10 @@ object Client extends Logging { val sparkConf = new SparkConf val args = new ClientArguments(argStrings, sparkConf) + // to maintain backwards-compatibility + if (!Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) + } new Client(args, sparkConf).run() } @@ -971,6 +1005,10 @@ object Client extends Logging { // of the executors val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" + // Comma-separated list of strings to pass through as YARN application tags appearing + // in YARN ApplicationReports, which can be used for filtering when querying YARN. + val CONF_SPARK_YARN_APPLICATION_TAGS = "spark.yarn.tags" + // Staging directory is private! -> rwx-------- val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) @@ -1007,7 +1045,10 @@ object Client extends Logging { s"in favor of the $CONF_SPARK_JAR configuration variable.") System.getenv(ENV_SPARK_JAR) } else { - SparkContext.jarOfClass(this.getClass).head + SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not " + + "find jar containing Spark classes. The jar can be defined using the " + + "spark.yarn.jar configuration option. If testing Spark, either set that option or " + + "make sure SPARK_PREPEND_CLASSES is not set.")) } } @@ -1061,20 +1102,10 @@ object Client extends Logging { triedDefault.toOption } - /** - * In Hadoop 0.23, the MR application classpath comes with the YARN application - * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. - * So we need to use reflection to retrieve it. 
- */ private[yarn] def getDefaultMRApplicationClasspath: Option[Seq[String]] = { val triedDefault = Try[Seq[String]] { val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH") - val value = if (field.getType == classOf[String]) { - StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray - } else { - field.get(null).asInstanceOf[Array[String]] - } - value.toSeq + StringUtils.getStrings(field.get(null).asInstanceOf[String]).toSeq } recoverWith { case e: NoSuchFieldException => Success(Seq.empty[String]) } @@ -1106,10 +1137,10 @@ object Client extends Logging { env: HashMap[String, String], isAM: Boolean, extraClassPath: Option[String] = None): Unit = { - extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env - ) + extraClassPath.foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) if (isAM) { addClasspathEntry( @@ -1125,12 +1156,14 @@ object Client extends Logging { getUserClasspath(sparkConf) } userClassPath.foreach { x => - addFileToClasspath(x, null, env) + addFileToClasspath(sparkConf, x, null, env) } } - addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) - sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) + sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } } /** @@ -1159,16 +1192,18 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * + * @param conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ private def addFileToClasspath( + conf: SparkConf, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { if (uri != null && uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) + addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) @@ -1182,11 +1217,37 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) + /** + * Returns the path to be sent to the NM for a path that is valid on the gateway. + * + * This method uses two configuration values: + * + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. + * + * If either config is not available, the input path is returned. + */ + def getClusterPath(conf: SparkConf, path: String): String = { + val localPath = conf.get("spark.yarn.config.gatewayPath", null) + val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + if (localPath != null && clusterPath != null) { + path.replace(localPath, clusterPath) + } else { + path + } + } + /** * Obtains token for the Hive metastore and adds them to the credentials. 
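Since `getClusterPath` is central to several hunks in this file, here is a stand-alone sketch of the substitution it performs; a plain `Map` stands in for `SparkConf` and the example paths are made up:

```scala
// Both spark.yarn.config.gatewayPath and spark.yarn.config.replacementPath must be set
// for any rewriting to happen; otherwise the path is passed through unchanged.
def getClusterPath(conf: Map[String, String], path: String): String =
  (conf.get("spark.yarn.config.gatewayPath"), conf.get("spark.yarn.config.replacementPath")) match {
    case (Some(gateway), Some(replacement)) => path.replace(gateway, replacement)
    case _ => path
  }

// Example: a gateway-only install prefix is rewritten for the NodeManagers.
// getClusterPath(Map(
//   "spark.yarn.config.gatewayPath"     -> "/opt/gateway/spark",
//   "spark.yarn.config.replacementPath" -> "/opt/cluster/spark"),
//   "/opt/gateway/spark/lib/foo.jar")   // => "/opt/cluster/spark/lib/foo.jar"
```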
*/ - private def obtainTokenForHiveMetastore(conf: Configuration, credentials: Credentials) { - if (UserGroupInformation.isSecurityEnabled) { + private def obtainTokenForHiveMetastore( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials) { + if (shouldGetTokens(sparkConf, "hive") && UserGroupInformation.isSecurityEnabled) { val mirror = universe.runtimeMirror(getClass.getClassLoader) try { @@ -1243,8 +1304,11 @@ object Client extends Logging { /** * Obtain security token for HBase. */ - def obtainTokenForHBase(conf: Configuration, credentials: Credentials): Unit = { - if (UserGroupInformation.isSecurityEnabled) { + def obtainTokenForHBase( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials): Unit = { + if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { val mirror = universe.runtimeMirror(getClass.getClassLoader) try { @@ -1257,11 +1321,12 @@ object Client extends Logging { logDebug("Attempting to fetch HBase security token.") - val hbaseConf = confCreate.invoke(null, conf) - val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] - credentials.addToken(token.getService, token) - - logInfo("Added HBase security token to credentials.") + val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] + if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { + val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] + credentials.addToken(token.getService, token) + logInfo("Added HBase security token to credentials.") + } } catch { case e: java.lang.NoSuchMethodException => logInfo("HBase Method not found: " + e) @@ -1339,4 +1404,13 @@ object Client extends Logging { components.mkString(Path.SEPARATOR) } + /** + * Return whether delegation tokens should be retrieved for the given service when security is + * enabled. By default, tokens are retrieved, but that behavior can be changed by setting + * a service-specific configuration. 
+ */ + def shouldGetTokens(conf: SparkConf, service: String): Boolean = { + conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 35e990602a6c..54f62e6b723a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -46,15 +46,14 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var keytab: String = null def isClusterMode: Boolean = userClass != null - private var driverMemory: Int = 512 // MB + private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB private var driverCores: Int = 1 private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead" private val amMemKey = "spark.yarn.am.memory" private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" private val driverCoresKey = "spark.driver.cores" private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) loadEnvironmentArgs() @@ -97,6 +96,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) } numExecutors = initialNumExecutors + } else { + val numExecutorsConf = "spark.executor.instances" + numExecutors = sparkConf.getInt(numExecutorsConf, numExecutors) } principal = Option(principal) .orElse(sparkConf.getOption("spark.yarn.principal")) @@ -123,6 +125,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new SparkException("Executor cores must not be less than " + "spark.task.cpus.") } + // scalastyle:off println if (isClusterMode) { for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { if (sparkConf.contains(key)) { @@ -144,11 +147,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .map(_.toInt) .foreach { cores => amCores = cores } } + // scalastyle:on println } private def parseArgs(inputArgs: List[String]): Unit = { var args = inputArgs + // scalastyle:off println while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => @@ -193,11 +198,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. 
Use --num-executors instead.") } - // Dynamic allocation is not compatible with this option - if (isDynamicAllocationEnabled) { - throw new IllegalArgumentException("Explicitly setting the number " + - "of executors is not compatible with spark.dynamicAllocation.enabled!") - } numExecutors = value args = tail @@ -253,6 +253,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + // scalastyle:on println if (primaryPyFile != null && primaryRFile != null) { throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + @@ -262,8 +263,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB message + - """ + s""" |Usage: org.apache.spark.deploy.yarn.Client [options] |Options: | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster @@ -275,7 +277,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores per executor (Default: 1). - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb) | --driver-cores NUM Number of cores used by the driver (Default: 1). | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) | --name NAME The name of your application (Default: Spark) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala index 229c2c4d5eb3..94feb6393fd6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -35,6 +35,9 @@ private[spark] class ExecutorDelegationTokenUpdater( @volatile private var lastCredentialsFileSuffix = 0 private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val freshHadoopConf = + SparkHadoopUtil.get.getConfBypassingFSCache( + hadoopConf, new Path(credentialsFile).toUri.getScheme) private val delegationTokenRenewer = Executors.newSingleThreadScheduledExecutor( @@ -49,7 +52,7 @@ private[spark] class ExecutorDelegationTokenUpdater( def updateCredentialsIfRequired(): Unit = { try { val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) SparkHadoopUtil.get.listFilesSorted( remoteFs, credentialsFilePath.getParent, credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index b0937083bc53..9abd09b3cc7a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -20,14 +20,13 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI import java.nio.ByteBuffer +import java.util.Collections -import 
org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation @@ -40,6 +39,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils class ExecutorRunnable( container: Container, @@ -74,9 +74,9 @@ class ExecutorRunnable( .asInstanceOf[ContainerLaunchContext] val localResources = prepareLocalResources - ctx.setLocalResources(localResources) + ctx.setLocalResources(localResources.asJava) - ctx.setEnvironment(env) + ctx.setEnvironment(env.asJava) val credentials = UserGroupInformation.getCurrentUser().getCredentials() val dob = new DataOutputBuffer() @@ -86,11 +86,19 @@ class ExecutorRunnable( val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, appId, localResources) - logInfo(s"Setting up executor with environment: $env") - logInfo("Setting up executor with commands: " + commands) - ctx.setCommands(commands) + logInfo(s""" + |=============================================================================== + |YARN executor launch context: + | env: + |${env.map { case (k, v) => s" $k -> $v\n" }.mkString} + | command: + | ${commands.mkString(" ")} + |=============================================================================== + """.stripMargin) - ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + ctx.setCommands(commands.asJava) + ctx.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr).asJava) // If external shuffle service is enabled, register with the Yarn shuffle service already // started on the NodeManager and, if authentication is enabled, provide it with our secret @@ -105,7 +113,7 @@ class ExecutorRunnable( // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) } - ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + ctx.setServiceData(Collections.singletonMap("spark_shuffle", secretBytes)) } // Send the start request to the ContainerManager @@ -146,7 +154,7 @@ class ExecutorRunnable( javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } javaOpts += "-Djava.io.tmpdir=" + @@ -195,7 +203,7 @@ class ExecutorRunnable( val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = if (new File(uri.getPath()).isAbsolute()) { - uri.getPath() + Client.getClusterPath(sparkConf, uri.getPath()) } else { Client.buildPath(Environment.PWD.$(), uri.getPath()) } @@ -210,7 +218,7 @@ class ExecutorRunnable( // an inconsistent state. // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? 
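The executor launch logging above is switched to a single multi-line message; a small self-contained sketch of that `stripMargin` pattern with placeholder values:

```scala
// env and commands are toy stand-ins for the real container environment and command line.
val env = Map("SPARK_LOG_URL_STDERR" -> "http://nm-host:8042/node/containerlogs/.../stderr?start=-4096")
val commands = Seq("java", "-server", "org.apache.spark.executor.CoarseGrainedExecutorBackend")

val launchContextMessage = s"""
  |===============================================================================
  |YARN executor launch context:
  |  env:
  |${env.map { case (k, v) => s"    $k -> $v\n" }.mkString}
  |  command:
  |    ${commands.mkString(" ")}
  |===============================================================================
  """.stripMargin

println(launchContextMessage)
```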
- "-XX:OnOutOfMemoryError='kill %p'") ++ + YarnSparkHadoopUtil.getOutOfMemoryErrorArgument) ++ javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", "--driver-url", masterAddress.toString, @@ -307,7 +315,8 @@ class ExecutorRunnable( env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" } - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } env } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala new file mode 100644 index 000000000000..081780204e42 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.mutable.{ArrayBuffer, HashMap, Set} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} +import org.apache.hadoop.yarn.util.RackResolver + +import org.apache.spark.SparkConf + +private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String]) + +/** + * This strategy is calculating the optimal locality preferences of YARN containers by considering + * the node ratio of pending tasks, number of required cores/containers and and locality of current + * existing containers. The target of this algorithm is to maximize the number of tasks that + * would run locally. + * + * Consider a situation in which we have 20 tasks that require (host1, host2, host3) + * and 10 tasks that require (host1, host2, host4), besides each container has 2 cores + * and cpus per task is 1, so the required container number is 15, + * and host ratio is (host1: 30, host2: 30, host3: 20, host4: 10). + * + * 1. If requested container number (18) is more than the required container number (15): + * + * requests for 5 containers with nodes: (host1, host2, host3, host4) + * requests for 5 containers with nodes: (host1, host2, host3) + * requests for 5 containers with nodes: (host1, host2) + * requests for 3 containers with no locality preferences. + * + * The placement ratio is 3 : 3 : 2 : 1, and set the additional containers with no locality + * preferences. + * + * 2. 
If requested container number (10) is less than or equal to the required container number + * (15): + * + * requests for 4 containers with nodes: (host1, host2, host3, host4) + * requests for 3 containers with nodes: (host1, host2, host3) + * requests for 3 containers with nodes: (host1, host2) + * + * The placement ratio is 10 : 10 : 7 : 4, close to expected ratio (3 : 3 : 2 : 1) + * + * 3. If containers exist but none of them can match the requested localities, + * follow the method of 1 and 2. + * + * 4. If containers exist and some of them can match the requested localities. + * For example if we have 1 containers on each node (host1: 1, host2: 1: host3: 1, host4: 1), + * and the expected containers on each node would be (host1: 5, host2: 5, host3: 4, host4: 2), + * so the newly requested containers on each node would be updated to (host1: 4, host2: 4, + * host3: 3, host4: 1), 12 containers by total. + * + * 4.1 If requested container number (18) is more than newly required containers (12). Follow + * method 1 with updated ratio 4 : 4 : 3 : 1. + * + * 4.2 If request container number (10) is more than newly required containers (12). Follow + * method 2 with updated ratio 4 : 4 : 3 : 1. + * + * 5. If containers exist and existing localities can fully cover the requested localities. + * For example if we have 5 containers on each node (host1: 5, host2: 5, host3: 5, host4: 5), + * which could cover the current requested localities. This algorithm will allocate all the + * requested containers with no localities. + */ +private[yarn] class LocalityPreferredContainerPlacementStrategy( + val sparkConf: SparkConf, + val yarnConf: Configuration, + val resource: Resource) { + + // Number of CPUs per task + private val CPUS_PER_TASK = sparkConf.getInt("spark.task.cpus", 1) + + /** + * Calculate each container's node locality and rack locality + * @param numContainer number of containers to calculate + * @param numLocalityAwareTasks number of locality required tasks + * @param hostToLocalTaskCount a map to store the preferred hostname and possible task + * numbers running on it, used as hints for container allocation + * @return node localities and rack localities, each locality is an array of string, + * the length of localities is the same as number of containers + */ + def localityOfRequestedContainers( + numContainer: Int, + numLocalityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Array[ContainerLocalityPreferences] = { + val updatedHostToContainerCount = expectedHostToContainerCount( + numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap) + val updatedLocalityAwareContainerNum = updatedHostToContainerCount.values.sum + + // The number of containers to allocate, divided into two groups, one with preferred locality, + // and the other without locality preference. 
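To make the numbers in the class comment above concrete, a toy version of the sizing arithmetic (2 cores per executor, 1 cpu per task, 30 pending tasks, as in the example; the helper name is illustrative):

```scala
val cpusPerTask = 1
val coresPerExecutor = 2

// Ceiling division: executors needed to cover the pending tasks' cpu demand.
def numExecutorsPending(numTasksPending: Int): Int =
  (numTasksPending * cpusPerTask + coresPerExecutor - 1) / coresPerExecutor

val required  = numExecutorsPending(30)                  // 15, as in the comment above
val requested = 18
val localityFree  = math.max(0, requested - required)    // 3 containers with no preference
val localityAware = requested - localityFree             // 15 locality-aware requests
```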
+ val requiredLocalityFreeContainerNum = + math.max(0, numContainer - updatedLocalityAwareContainerNum) + val requiredLocalityAwareContainerNum = numContainer - requiredLocalityFreeContainerNum + + val containerLocalityPreferences = ArrayBuffer[ContainerLocalityPreferences]() + if (requiredLocalityFreeContainerNum > 0) { + for (i <- 0 until requiredLocalityFreeContainerNum) { + containerLocalityPreferences += ContainerLocalityPreferences( + null.asInstanceOf[Array[String]], null.asInstanceOf[Array[String]]) + } + } + + if (requiredLocalityAwareContainerNum > 0) { + val largestRatio = updatedHostToContainerCount.values.max + // Round the ratio of preferred locality to the number of locality required container + // number, which is used for locality preferred host calculating. + var preferredLocalityRatio = updatedHostToContainerCount.mapValues { ratio => + val adjustedRatio = ratio.toDouble * requiredLocalityAwareContainerNum / largestRatio + adjustedRatio.ceil.toInt + } + + for (i <- 0 until requiredLocalityAwareContainerNum) { + // Only filter out the ratio which is larger than 0, which means the current host can + // still be allocated with new container request. + val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray + val racks = hosts.map { h => + RackResolver.resolve(yarnConf, h).getNetworkLocation + }.toSet + containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) + + // Minus 1 each time when the host is used. When the current ratio is 0, + // which means all the required ratio is satisfied, this host will not be allocated again. + preferredLocalityRatio = preferredLocalityRatio.mapValues(_ - 1) + } + } + + containerLocalityPreferences.toArray + } + + /** + * Calculate the number of executors need to satisfy the given number of pending tasks. + */ + private def numExecutorsPending(numTasksPending: Int): Int = { + val coresPerExecutor = resource.getVirtualCores + (numTasksPending * CPUS_PER_TASK + coresPerExecutor - 1) / coresPerExecutor + } + + /** + * Calculate the expected host to number of containers by considering with allocated containers. + * @param localityAwareTasks number of locality aware tasks + * @param hostToLocalTaskCount a map to store the preferred hostname and possible task + * numbers running on it, used as hints for container allocation + * @return a map with hostname as key and required number of containers on this host as value + */ + private def expectedHostToContainerCount( + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Map[String, Int] = { + val totalLocalTaskNum = hostToLocalTaskCount.values.sum + hostToLocalTaskCount.map { case (host, count) => + val expectedCount = + count.toDouble * numExecutorsPending(localityAwareTasks) / totalLocalTaskNum + val existedCount = allocatedHostToContainersMap.get(host) + .map(_.size) + .getOrElse(0) + + // If existing container can not fully satisfy the expected number of container, + // the required container number is expected count minus existed count. Otherwise the + // required container number is 0. 
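A worked, stand-alone version of the expected-count calculation just above, using the host ratios from the class comment and assuming one container already allocated per host (all inputs are toy values):

```scala
val hostToLocalTaskCount = Map("host1" -> 30, "host2" -> 30, "host3" -> 20, "host4" -> 10)
val allocatedPerHost     = Map("host1" -> 1, "host2" -> 1, "host3" -> 1, "host4" -> 1)
val pendingExecutors     = 15
val totalLocalTaskNum    = hostToLocalTaskCount.values.sum   // 90

// Each host's share is proportional to its pending-task count; containers already
// running there are subtracted, never going below zero.
val newlyRequired = hostToLocalTaskCount.map { case (host, count) =>
  val expected = count.toDouble * pendingExecutors / totalLocalTaskNum
  host -> math.max(0, (expected - allocatedPerHost.getOrElse(host, 0)).ceil.toInt)
}
// newlyRequired == Map(host1 -> 4, host2 -> 4, host3 -> 3, host4 -> 1): 12 in total,
// matching case 4 in the class comment.
```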
+ (host, math.max(0, (expectedCount - existedCount).ceil.toInt)) + } + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 940873fbd046..fd88b8b7fe3b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,8 +21,9 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConverters._ import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -36,6 +37,10 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor +import org.apache.spark.util.Utils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -52,6 +57,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ */ private[yarn] class YarnAllocator( driverUrl: String, + driverRef: RpcEndpointRef, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -82,12 +88,25 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var targetNumExecutors = args.numExecutors + @volatile private var targetNumExecutors = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.getInt("spark.dynamicAllocation.initialExecutors", 0) + } else { + sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) + } + + // Executor loss reason requests that are pending - maps from executor ID for inquiry to a + // list of requesters that should be responded to once we find out why the given executor + // was lost. + private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] // Keep track of which container is running which executor to remove the executors later // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] + private var numUnexpectedContainerRelease = 0L + private val containerIdToExecutorId = new HashMap[ContainerId, String] + // Executor memory in MB. protected val executorMemory = args.executorMemory // Additional memory overhead. @@ -96,7 +115,7 @@ private[yarn] class YarnAllocator( // Number of cores per executor. protected val executorCores = args.executorCores // Resource capability requested for each executors - private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue @@ -127,6 +146,16 @@ private[yarn] class YarnAllocator( } } + // A map to store preferred hostname and possible task numbers running on it. 
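A hedged sketch of how the initial `targetNumExecutors` above is chosen; a plain `Map` stands in for `SparkConf`, and the default of 2 is an assumption mirroring what `YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS` is taken to be:

```scala
val DefaultNumberExecutors = 2  // assumed default

def initialTargetExecutors(conf: Map[String, String]): Int = {
  val dynamicAllocation = conf.getOrElse("spark.dynamicAllocation.enabled", "false").toBoolean
  if (dynamicAllocation) {
    // Start from the configured initial size; the allocator grows/shrinks it later.
    conf.getOrElse("spark.dynamicAllocation.initialExecutors", "0").toInt
  } else {
    conf.getOrElse("spark.executor.instances", DefaultNumberExecutors.toString).toInt
  }
}
```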
+ private var hostToLocalTaskCounts: Map[String, Int] = Map.empty + + // Number of tasks that have locality preferences in active stages + private var numLocalityAwareTasks: Int = 0 + + // A container placement strategy based on pending tasks' locality preference + private[yarn] val containerPlacementStrategy = + new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource) + def getNumExecutorsRunning: Int = numExecutorsRunning def getNumExecutorsFailed: Int = numExecutorsFailed @@ -140,16 +169,25 @@ private[yarn] class YarnAllocator( * Number of container requests at the given location that have not yet been fulfilled. */ private def getNumPendingAtLocation(location: String): Int = - amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala.map(_.size).sum /** * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. - * + * @param requestedTotal total number of containers requested + * @param localityAwareTasks number of locality aware tasks to be used as container placement hint + * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as + * container placement hint. * @return Whether the new requested total is different than the old value. */ - def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { + def requestTotalExecutorsWithPreferredLocalities( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean = synchronized { + this.numLocalityAwareTasks = localityAwareTasks + this.hostToLocalTaskCounts = hostToLocalTaskCount + if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal @@ -165,6 +203,7 @@ private[yarn] class YarnAllocator( def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get + containerIdToExecutorId.remove(container.getId) internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -197,15 +236,13 @@ private[yarn] class YarnAllocator( numExecutorsRunning, allocateResponse.getAvailableResources)) - handleAllocatedContainers(allocatedContainers) + handleAllocatedContainers(allocatedContainers.asScala) } val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - - processCompletedContainers(completedContainers) - + processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." .format(completedContainers.size, numExecutorsRunning)) } @@ -221,15 +258,23 @@ private[yarn] class YarnAllocator( val numPendingAllocate = getNumPendingAllocate val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning + // TODO. Consider locality preferences of pending container requests. + // Since the last time we made container requests, stages have completed and been submitted, + // and that the localities at which we requested our pending executors + // no longer apply to our current needs. 
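A minimal model of the contract of `requestTotalExecutorsWithPreferredLocalities` described above: locality hints are always recorded, but the caller is only told the target changed when the requested total differs. Class and method names here are illustrative only:

```scala
class ExecutorTargetTracker {
  private var targetNumExecutors = 0
  private var numLocalityAwareTasks = 0
  private var hostToLocalTaskCounts: Map[String, Int] = Map.empty

  def requestTotal(
      requestedTotal: Int,
      localityAwareTasks: Int,
      hostToLocalTaskCount: Map[String, Int]): Boolean = synchronized {
    // Hints are stored unconditionally; they feed the placement strategy on the next allocation.
    numLocalityAwareTasks = localityAwareTasks
    hostToLocalTaskCounts = hostToLocalTaskCount
    if (requestedTotal != targetNumExecutors) {
      targetNumExecutors = requestedTotal
      true
    } else {
      false
    }
  }
}
```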
We should consider to remove all outstanding + // container requests and add requests anew each time to avoid this. if (missing > 0) { logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") - for (i <- 0 until missing) { - val request = createContainerRequest(resource) + val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( + missing, numLocalityAwareTasks, hostToLocalTaskCounts, allocatedHostToContainersMap) + + for (locality <- containerLocalityPreferences) { + val request = createContainerRequest(resource, locality.nodes, locality.racks) amClient.addContainerRequest(request) val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.asScala.last logInfo(s"Container request (host: $hostStr, capability: $resource)") } } else if (missing < 0) { @@ -238,7 +283,8 @@ private[yarn] class YarnAllocator( val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) if (!matchingRequests.isEmpty) { - matchingRequests.head.take(numToCancel).foreach(amClient.removeContainerRequest) + matchingRequests.iterator().next().asScala + .take(numToCancel).foreach(amClient.removeContainerRequest) } else { logWarning("Expected to find pending requests, but found none.") } @@ -249,11 +295,14 @@ private[yarn] class YarnAllocator( * Creates a container request, handling the reflection required to use YARN features that were * added in recent versions. */ - private def createContainerRequest(resource: Resource): ContainerRequest = { + protected def createContainerRequest( + resource: Resource, + nodes: Array[String], + racks: Array[String]): ContainerRequest = { nodeLabelConstructor.map { constructor => - constructor.newInstance(resource, null, null, RM_REQUEST_PRIORITY, true: java.lang.Boolean, + constructor.newInstance(resource, nodes, racks, RM_REQUEST_PRIORITY, true: java.lang.Boolean, labelExpression.orNull) - }.getOrElse(new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY)) + }.getOrElse(new ContainerRequest(resource, nodes, racks, RM_REQUEST_PRIORITY)) } /** @@ -353,6 +402,7 @@ private[yarn] class YarnAllocator( logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) executorIdToContainer(executorId) = container + containerIdToExecutorId(container.getId) = executorId val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -383,12 +433,8 @@ private[yarn] class YarnAllocator( private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId - - if (releasedContainers.contains(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { + val alreadyReleased = releasedContainers.remove(containerId) + val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. 
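The request/cancel decision above boils down to one signed quantity; a toy illustration with placeholder numbers:

```scala
val targetNumExecutors  = 10
val numPendingAllocate  = 3   // outstanding container requests
val numExecutorsRunning = 5

val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning  // 2

if (missing > 0) {
  println(s"Will request $missing executor containers")        // new, locality-aware requests
} else if (missing < 0) {
  val numToCancel = math.min(numPendingAllocate, -missing)
  println(s"Canceling requests for $numToCancel executor containers")
}
```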
numExecutorsRunning -= 1 @@ -399,25 +445,45 @@ private[yarn] class YarnAllocator( // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == ContainerExitStatus.PREEMPTED) { - logInfo("Container preempted: " + containerId) - } else if (completedContainer.getExitStatus == -103) { // vmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus != 0) { - logInfo("Container marked as failed: " + containerId + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) - numExecutorsFailed += 1 + val exitStatus = completedContainer.getExitStatus + val (isNormalExit, containerExitReason) = exitStatus match { + case ContainerExitStatus.SUCCESS => + (true, s"Executor for container $containerId exited normally.") + case ContainerExitStatus.PREEMPTED => + // Preemption should count as a normal exit, since YARN preempts containers merely + // to do resource sharing, and tasks that fail due to preempted executors could + // just as easily finish on any other executor. See SPARK-8167. + (true, s"Container $containerId was preempted.") + // Should probably still count memory exceeded exit codes towards task failures + case VMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + case PMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + case unknown => + numExecutorsFailed += 1 + (false, "Container marked as failed: " + containerId + + ". Exit status: " + completedContainer.getExitStatus + + ". 
Diagnostics: " + completedContainer.getDiagnostics) + } + if (isNormalExit) { + logInfo(containerExitReason) + } else { + logWarning(containerExitReason) + } + ExecutorExited(0, isNormalExit, containerExitReason) + } else { + // If we have already released this container, then it must mean + // that the driver has explicitly requested it to be killed + ExecutorExited(completedContainer.getExitStatus, isNormalExit = true, + s"Container $containerId exited from explicit termination request.") } - if (allocatedContainerToHostMap.containsKey(containerId)) { + if (allocatedContainerToHostMap.contains(containerId)) { val host = allocatedContainerToHostMap.get(containerId).get val containerSet = allocatedHostToContainersMap.get(host).get @@ -430,6 +496,35 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.remove(containerId) } + + containerIdToExecutorId.remove(containerId).foreach { eid => + executorIdToContainer.remove(eid) + pendingLossReasonRequests.remove(eid).foreach { pendingRequests => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + } + if (!alreadyReleased) { + // The executor could have gone away (like no route to host, node failure, etc) + // Notify backend about the failure of the executor + numUnexpectedContainerRelease += 1 + driverRef.send(RemoveExecutor(eid, exitReason)) + } + } + } + } + + /** + * Register that some RpcCallContext has asked the AM why the executor was lost. Note that + * we can only find the loss reason to send back in the next call to allocateResources(). + */ + private[yarn] def enqueueGetLossReasonRequest( + eid: String, + context: RpcCallContext): Unit = synchronized { + if (executorIdToContainer.contains(eid)) { + pendingLossReasonRequests + .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else { + logWarning(s"Tried to get the loss reason for non-existent executor $eid") } } @@ -438,6 +533,8 @@ private[yarn] class YarnAllocator( amClient.releaseAssignedContainer(container.getId()) } + private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + } private object YarnAllocator { @@ -446,6 +543,8 @@ private object YarnAllocator { Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") val VMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") + val VMEM_EXCEEDED_EXIT_CODE = -103 + val PMEM_EXCEEDED_EXIT_CODE = -104 def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 7f533ee55e8b..df042bf291de 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,20 +19,19 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.ConverterUtils import 
org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -56,6 +55,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg */ def register( driverUrl: String, + driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -73,7 +73,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, + securityMgr) } /** @@ -105,8 +106,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter", classOf[Configuration]) val proxies = method.invoke(null, conf).asInstanceOf[JList[String]] - val hosts = proxies.map { proxy => proxy.split(":")(0) } - val uriBases = proxies.map { proxy => prefix + proxy + proxyBase } + val hosts = proxies.asScala.map { proxy => proxy.split(":")(0) } + val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) } catch { case e: NoSuchMethodException => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 68d01c17ef72..445d3dcd266d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -37,6 +37,7 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, P import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils @@ -219,26 +220,61 @@ object YarnSparkHadoopUtil { } } + /** + * The handler if an OOM Exception is thrown by the JVM must be configured on Windows + * differently: the 'taskkill' command should be used, whereas Unix-based systems use 'kill'. + * + * As the JVM interprets both %p and %%p as the same, we can use either of them. However, + * some tests on Windows computers suggest, that the JVM only accepts '%%p'. + * + * Furthermore, the behavior of the character '%' on the Windows command line differs from + * the behavior of '%' in a .cmd file: it gets interpreted as an incomplete environment + * variable. Windows .cmd files escape a '%' by '%%'. Thus, the correct way of writing + * '%%p' in an escaped way is '%%%%p'. + * + * @return The correct OOM Error handler JVM option, platform dependent. + */ + def getOutOfMemoryErrorArgument : String = { + if (Utils.isWindows) { + escapeForShell("-XX:OnOutOfMemoryError=taskkill /F /PID %%%%p") + } else { + "-XX:OnOutOfMemoryError='kill %p'" + } + } + /** * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands - * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. 
The - * argument is enclosed in single quotes and some key characters are escaped. + * using either + * + * (Unix-based) `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. + * The argument is enclosed in single quotes and some key characters are escaped. + * + * (Windows-based) part of a .cmd file in which case windows escaping for each argument must be + * applied. Windows is quite lenient, however it is usually Java that causes trouble, needing to + * distinguish between arguments starting with '-' and class names. If arguments are surrounded + * by ' java takes the following string as is, hence an argument is mistakenly taken as a class + * name which happens to start with a '-'. The way to avoid this, is to surround nothing with + * a ', but instead with a ". * * @param arg A single argument. * @return Argument quoted for execution via Yarn's generated shell script. */ def escapeForShell(arg: String): String = { if (arg != null) { - val escaped = new StringBuilder("'") - for (i <- 0 to arg.length() - 1) { - arg.charAt(i) match { - case '$' => escaped.append("\\$") - case '"' => escaped.append("\\\"") - case '\'' => escaped.append("'\\''") - case c => escaped.append(c) + if (Utils.isWindows) { + YarnCommandBuilderUtils.quoteForBatchScript(arg) + } else { + val escaped = new StringBuilder("'") + for (i <- 0 to arg.length() - 1) { + arg.charAt(i) match { + case '$' => escaped.append("\\$") + case '"' => escaped.append("\\\"") + case '\'' => escaped.append("'\\''") + case c => escaped.append(c) + } } + escaped.append("'").toString() } - escaped.append("'").toString() } else { arg } diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala new file mode 100644 index 000000000000..3ac36ef0a1c3 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.launcher + +/** + * Exposes needed methods + */ +private[spark] object YarnCommandBuilderUtils { + def quoteForBatchScript(arg: String) : String = { + CommandBuilderUtils.quoteForBatchScript(arg) + } +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 1c8d7ec57635..d06d95140438 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,10 +20,9 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} -import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -34,14 +33,13 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - private var monitorThread: Thread = null + private var monitorThread: MonitorThread = null /** * Create a Yarn client to submit an application to the ResourceManager. * This waits until the application is running. */ override def start() { - super.start() val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort @@ -56,7 +54,20 @@ private[spark] class YarnClientSchedulerBackend( totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.submitApplication() + + // SPARK-8687: Ensure all necessary properties have already been set before + // we initialize our driver scheduler backend, which serves these properties + // to the executors + super.start() + waitForApplication() + + // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver + // reads the credentials from HDFS, just like the executors and updates its own credentials + // cache. 
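A stand-alone copy of the Unix branch of `escapeForShell` from the YarnSparkHadoopUtil hunk above, for quick experimentation; the Windows branch delegates to `quoteForBatchScript` and is not reproduced here:

```scala
// Wrap the argument in single quotes and escape the characters that the generated
// launch script would otherwise interpret.
def escapeForShellUnix(arg: String): String =
  if (arg == null) {
    arg
  } else {
    val escaped = new StringBuilder("'")
    arg.foreach {
      case '$'  => escaped.append("\\$")
      case '"'  => escaped.append("\\\"")
      case '\'' => escaped.append("'\\''")
      case c    => escaped.append(c)
    }
    escaped.append("'").toString()
  }
```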
+ if (conf.contains("spark.yarn.credentials.file")) { + YarnSparkHadoopUtil.get.startExecutorDelegationTokenRenewer(conf) + } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -70,8 +81,6 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), @@ -81,7 +90,6 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => @@ -120,24 +128,42 @@ private[spark] class YarnClientSchedulerBackend( } } + /** + * We create this class for SPARK-9519. Basically when we interrupt the monitor thread it's + * because the SparkContext is being shut down(sc.stop() called by user code), but if + * monitorApplication return, it means the Yarn application finished before sc.stop() was called, + * which means we should call sc.stop() here, and we don't allow the monitor to be interrupted + * before SparkContext stops successfully. + */ + private class MonitorThread extends Thread { + private var allowInterrupt = true + + override def run() { + try { + val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + allowInterrupt = false + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") + } + } + + def stopMonitor(): Unit = { + if (allowInterrupt) { + this.interrupt() + } + } + } + /** * Monitor the application state in a separate thread. * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. 
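A simplified model of the SPARK-9519 guard implemented by `MonitorThread` above: once the monitored application has exited on its own, the thread blocks interruption so the `sc.stop()` it triggers can finish. The callbacks and names are placeholders:

```scala
class ApplicationMonitor(waitForExit: () => Unit, onExit: () => Unit) extends Thread {
  @volatile private var allowInterrupt = true

  override def run(): Unit = {
    try {
      waitForExit()          // e.g. block until the YARN application reaches a terminal state
      allowInterrupt = false // the app is already gone; don't let stop() interrupt us now
      onExit()               // e.g. sc.stop()
    } catch {
      case _: InterruptedException => // normal shutdown path: SparkContext is stopping first
    }
  }

  def stopMonitor(): Unit = if (allowInterrupt) interrupt()
}
```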
*/ - private def asyncMonitorApplication(): Thread = { + private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId != null, "Application has not been submitted yet!") - val t = new Thread { - override def run() { - try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") - sc.stop() - } catch { - case e: InterruptedException => logInfo("Interrupting monitor thread") - } - } - } + val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) t @@ -148,9 +174,12 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - monitorThread.interrupt() + if (monitorThread != null) { + monitorThread.stopMonitor() + } super.stop() client.stop() + YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() logInfo("Stopped") } @@ -160,5 +189,4 @@ private[spark] class YarnClientSchedulerBackend( super.applicationId } } - } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 33f580aaebdc..1aed5a167507 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler.cluster import java.net.NetworkInterface +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + import scala.collection.JavaConverters._ import org.apache.hadoop.yarn.api.records.NodeState @@ -64,68 +66,29 @@ private[spark] class YarnClusterSchedulerBackend( } override def getDriverLogUrls: Option[Map[String, String]] = { - var yarnClientOpt: Option[YarnClient] = None var driverLogs: Option[Map[String, String]] = None try { val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) val containerId = YarnSparkHadoopUtil.get.getContainerId - yarnClientOpt = Some(YarnClient.createYarnClient()) - yarnClientOpt.foreach { yarnClient => - yarnClient.init(yarnConf) - yarnClient.start() - - // For newer versions of YARN, we can find the HTTP address for a given node by getting a - // container report for a given container. But container reports came only in Hadoop 2.4, - // so we basically have to get the node reports for all nodes and find the one which runs - // this container. For that we have to compare the node's host against the current host. - // Since the host can have multiple addresses, we need to compare against all of them to - // find out if one matches. - - // Get all the addresses of this node. - val addresses = - NetworkInterface.getNetworkInterfaces.asScala - .flatMap(_.getInetAddresses.asScala) - .toSeq - - // Find a node report that matches one of the addresses - val nodeReport = - yarnClient.getNodeReports(NodeState.RUNNING).asScala.find { x => - val host = x.getNodeId.getHost - addresses.exists { address => - address.getHostAddress == host || - address.getHostName == host || - address.getCanonicalHostName == host - } - } - // Now that we have found the report for the Node Manager that the AM is running on, we - // can get the base HTTP address for the Node manager from the report. - // The format used for the logs for each container is well-known and can be constructed - // using the NM's HTTP address and the container ID. 
- // The NM may be running several containers, but we can build the URL for the AM using - // the AM's container ID, which we already know. - nodeReport.foreach { report => - val httpAddress = report.getHttpAddress - // lookup appropriate http scheme for container log urls - val yarnHttpPolicy = yarnConf.get( - YarnConfiguration.YARN_HTTP_POLICY_KEY, - YarnConfiguration.YARN_HTTP_POLICY_DEFAULT - ) - val user = Utils.getCurrentUserName() - val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" - val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" - logDebug(s"Base URL for logs: $baseUrl") - driverLogs = Some(Map( - "stderr" -> s"$baseUrl/stderr?start=-4096", - "stdout" -> s"$baseUrl/stdout?start=-4096")) - } - } + val httpAddress = System.getenv(Environment.NM_HOST.name()) + + ":" + System.getenv(Environment.NM_HTTP_PORT.name()) + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val user = Utils.getCurrentUserName() + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" + logDebug(s"Base URL for logs: $baseUrl") + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) } catch { case e: Exception => - logInfo("Node Report API is not available in the version of YARN being used, so AM" + + logInfo("Error while building AM log links, so AM" + " logs link will not appear in application UI", e) - } finally { - yarnClientOpt.foreach(_.close()) } driverLogs } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala new file mode 100644 index 000000000000..17c59ff06e0c --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, Matchers} + +import org.apache.spark._ +import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.util.Utils + +abstract class BaseYarnClusterSuite + extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { + + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. + protected val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + |log4j.logger.org.apache.hadoop=WARN + |log4j.logger.org.eclipse.jetty=WARN + |log4j.logger.org.spark-project.jetty=WARN + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + protected var tempDir: File = _ + private var fakeSparkJar: File = _ + private var hadoopConfDir: File = _ + private var logConfDir: File = _ + + def newYarnConfig(): YarnConfiguration + + override def beforeAll() { + super.beforeAll() + + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, UTF_8) + + // Disable the disk utilization check to avoid the test hanging when people's disks are + // getting full. + val yarnConf = newYarnConfig() + yarnConf.set("yarn.nodemanager.disk-health-checker.max-disk-utilization-per-disk-percentage", + "100.0") + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(yarnConf) + yarnCluster.start() + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). 
+ val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) + } + + override def afterAll() { + yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") + super.afterAll() + } + + protected def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[String] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map(), + extraEnv: Map[String, String] = Map()): Unit = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val props = new Properties() + + props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val testClasspath = new TestClasspathBuilder() + .buildClassPath( + logConfDir.getAbsolutePath() + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator)) + .asScala + .mkString(File.pathSeparator) + + props.setProperty("spark.driver.extraClassPath", testClasspath) + props.setProperty("spark.executor.extraClassPath", testClasspath) + + // SPARK-4267: make sure java options are propagated correctly. + props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig.asScala.foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } + } + + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + + val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil + val mainArgs = + if (klass.endsWith(".py")) { + Seq(klass) + } else { + Seq("--class", klass, fakeSparkJar.getAbsolutePath()) + } + val argv = + Seq( + new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), + "--master", master, + "--num-executors", "1", + "--properties-file", propsFile.getAbsolutePath()) ++ + extraJarArgs ++ + sparkArgs ++ + mainArgs ++ + appArgs + + Utils.executeAndGetOutput(argv, + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv) + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. 
+ */ + protected def checkResult(result: File): Unit = { + checkResult(result, "success") + } + + protected def checkResult(result: File, expected: String): Unit = { + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 4ec976aa3138..e7f2501e7899 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap => MutableHashMap} import scala.reflect.ClassTag import scala.util.Try @@ -29,13 +29,16 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.YarnClientApplication import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.util.Records import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { @@ -151,6 +154,58 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { } } + test("Cluster path translation") { + val conf = new Configuration() + val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") + .set("spark.yarn.config.gatewayPath", "/localPath") + .set("spark.yarn.config.replacementPath", "/remotePath") + + Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + "/remotePath/1:/remotePath/2") + + val env = new MutableHashMap[String, String]() + Client.populateClasspath(null, conf, sparkConf, env, false, + extraClassPath = Some("/localPath/my1.jar")) + val cp = classpath(env) + cp should contain ("/remotePath/spark.jar") + cp should contain ("/remotePath/my1.jar") + } + + test("configuration and args propagate through createApplicationSubmissionContext") { + val conf = new Configuration() + // When parsing tags, duplicates and leading/trailing whitespace should be removed. + // Spaces between non-comma strings should be preserved as single tags. Empty strings may or + // may not be removed depending on the version of Hadoop being used. 
+ val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_YARN_APPLICATION_TAGS, ",tag1, dup,tag2 , ,multi word , dup") + .set("spark.yarn.maxAppAttempts", "42") + val args = new ClientArguments(Array( + "--name", "foo-test-app", + "--queue", "staging-queue"), sparkConf) + + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) + val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext]) + + val client = new Client(args, conf, sparkConf) + client.createApplicationSubmissionContext( + new YarnClientApplication(getNewApplicationResponse, appContext), + containerLaunchContext) + + appContext.getApplicationName should be ("foo-test-app") + appContext.getQueue should be ("staging-queue") + appContext.getAMContainerSpec should be (containerLaunchContext) + appContext.getApplicationType should be ("SPARK") + appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => + val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] + tags should contain allOf ("tag1", "dup", "tag2", "multi word") + tags.asScala.filter(_.nonEmpty).size should be (4) + } + appContext.getMaxAppAttempts should be (42) + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala new file mode 100644 index 000000000000..b7fe4ccc67a3 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite + +class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + + private val yarnAllocatorSuite = new YarnAllocatorSuite + import yarnAllocatorSuite._ + + override def beforeEach() { + yarnAllocatorSuite.beforeEach() + } + + override def afterEach() { + yarnAllocatorSuite.afterEach() + } + + test("allocate locality preferred containers with enough resource and no matched existed " + + "containers") { + // 1. All the locations of current containers cannot satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. 
+ + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array( + Array("host3", "host4", "host5"), + Array("host3", "host4", "host5"), + Array("host3", "host4"))) + } + + test("allocate locality preferred containers with enough resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. + + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === + Array(null, Array("host2", "host3"), Array("host2", "host3"))) + } + + test("allocate locality preferred containers with limited resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number cannot fully satisfy the pending tasks. + + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(Array("host2", "host3"))) + } + + test("allocate locality preferred containers with fully matched containers") { + // Current containers' locations can fully satisfy the new requirements + + val handler = createAllocator(5) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2"), + createContainer("host2"), + createContainer("host3") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null, null, null)) + } + + test("allocate containers with no locality preference") { + // Request new container without locality preference + + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 0, Map.empty, handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null)) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 7509000771d9..5d05f514adde 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -25,15 +25,18 @@ import 
org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.mockito.Mockito._ import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, Matchers} - class MockResolver extends DNSToSwitchMapping { override def resolve(names: JList[String]): JList[String] = { @@ -84,15 +87,17 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--num-executors", s"$maxExecutors", "--executor-cores", "5", "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") + val sparkConfClone = sparkConf.clone() + sparkConfClone.set("spark.executor.instances", maxExecutors.toString) new YarnAllocator( "not used", + mock(classOf[RpcEndpointRef]), conf, - sparkConf, + sparkConfClone, rmClient, appAttemptId, new ApplicationMasterArguments(args), @@ -171,7 +176,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) - handler.requestTotalExecutors(3) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (3) @@ -182,7 +187,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - handler.requestTotalExecutors(2) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (1) } @@ -193,7 +198,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) - handler.requestTotalExecutors(3) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (3) @@ -203,7 +208,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (2) - handler.requestTotalExecutors(1) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (0) handler.getNumExecutorsRunning should be (2) @@ -219,7 +224,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val container2 = createContainer("host2") handler.handleAllocatedContainers(Array(container1, container2)) - handler.requestTotalExecutors(1) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) } val statuses = Seq(container1, container2).map { c => @@ -231,6 +236,30 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumPendingAllocate 
should be (1) } + test("lost executor removed from backend") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map()) + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (2) + handler.getNumExecutorsFailed should be (2) + handler.getNumUnexpectedContainerRelease should be (2) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + @@ -241,5 +270,4 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) } - } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 335e966519c7..b5a42fd6afd9 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -17,25 +17,21 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.io.File import java.net.URL -import java.util.Properties -import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ import scala.collection.mutable import com.google.common.base.Charsets.UTF_8 -import com.google.common.io.ByteStreams -import com.google.common.io.Files +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.Matchers import org.apache.spark._ -import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.launcher.TestClasspathBuilder import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils /** @@ -43,17 +39,9 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { - - // log4j configuration for the YARN containers, so that their output is collected - // by YARN instead of trying to overwrite unit-tests.log. 
- private val LOG4J_CONF = """ - |log4j.rootCategory=DEBUG, console - |log4j.appender.console=org.apache.log4j.ConsoleAppender - |log4j.appender.console.target=System.err - |log4j.appender.console.layout=org.apache.log4j.PatternLayout - |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin +class YarnClusterSuite extends BaseYarnClusterSuite { + + override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ |import mod1, mod2 @@ -82,65 +70,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | return 42 """.stripMargin - private var yarnCluster: MiniYARNCluster = _ - private var tempDir: File = _ - private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ - private var logConfDir: File = _ - - override def beforeAll() { - super.beforeAll() - - tempDir = Utils.createTempDir() - logConfDir = new File(tempDir, "log4j") - logConfDir.mkdir() - System.setProperty("SPARK_YARN_MODE", "true") - - val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, UTF_8) - - yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(new YarnConfiguration()) - yarnCluster.start() - - // There's a race in MiniYARNCluster in which start() may return before the RM has updated - // its address in the configuration. You can see this in the logs by noticing that when - // MiniYARNCluster prints the address, it still has port "0" assigned, although later the - // test works sometimes: - // - // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 - // - // That log message prints the contents of the RM_ADDRESS config variable. If you check it - // later on, it looks something like this: - // - // INFO YarnClusterSuite: RM address in configuration is blah:42631 - // - // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't - // done so in a timely manner (defined to be 10 seconds). 
- val config = yarnCluster.getConfig() - val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) - while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { - if (System.currentTimeMillis() > deadline) { - throw new IllegalStateException("Timed out waiting for RM to come up.") - } - logDebug("RM address still not set in configuration, waiting...") - TimeUnit.MILLISECONDS.sleep(100) - } - - logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - - fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) - assert(hadoopConfDir.mkdir()) - File.createTempFile("token", ".txt", hadoopConfDir) - } - - override def afterAll() { - yarnCluster.stop() - System.clearProperty("SPARK_YARN_MODE") - super.afterAll() - } - test("run Spark in yarn-client mode") { testBasicYarnApp(true) } @@ -174,7 +103,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher } private def testBasicYarnApp(clientMode: Boolean): Unit = { - var result = File.createTempFile("result", null, tempDir) + val result = File.createTempFile("result", null, tempDir) runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) checkResult(result) @@ -184,6 +113,17 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + // When running tests, let's not assume the user has built the assembly module, which also + // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the + // needed locations. + val sparkHome = sys.props("spark.test.home"); + val pythonPath = Seq( + s"$sparkHome/python/lib/py4j-0.8.2.1-src.zip", + s"$sparkHome/python") + val extraEnv = Map( + "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), + "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) + val moduleDir = if (clientMode) { // In client-mode, .py files added with --py-files are not visible in the driver. @@ -203,7 +143,8 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher runSpark(clientMode, primaryPyFile.getAbsolutePath(), sparkArgs = Seq("--py-files", pyFiles), - appArgs = Seq(result.getAbsolutePath())) + appArgs = Seq(result.getAbsolutePath()), + extraEnv = extraEnv) checkResult(result) } @@ -224,89 +165,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher checkResult(executorResult, "OVERRIDDEN") } - private def runSpark( - clientMode: Boolean, - klass: String, - appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, - extraClassPath: Seq[String] = Nil, - extraJars: Seq[String] = Nil, - extraConf: Map[String, String] = Map()): Unit = { - val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() - - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) - - val childClasspath = logConfDir.getAbsolutePath() + - File.pathSeparator + - sys.props("java.class.path") + - File.pathSeparator + - extraClassPath.mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", childClasspath) - props.setProperty("spark.executor.extraClassPath", childClasspath) - - // SPARK-4267: make sure java options are propagated correctly. 
- props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") - props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - - yarnCluster.getConfig().foreach { e => - props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) - } - - sys.props.foreach { case (k, v) => - if (k.startsWith("spark.")) { - props.setProperty(k, v) - } - } - - extraConf.foreach { case (k, v) => props.setProperty(k, v) } - - val propsFile = File.createTempFile("spark", ".properties", tempDir) - val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) - props.store(writer, "Spark properties.") - writer.close() - - val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - private def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - private def checkResult(result: File, expected: String): Unit = { - var resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - private def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") - } - } private[spark] class SaveExecutorInfo extends SparkListener { @@ -328,12 +186,14 @@ private object YarnClusterDriver extends Logging with Matchers { def main(args: Array[String]): Unit = { if (args.length != 1) { + // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClusterDriver [result file] """.stripMargin) + // scalastyle:on println System.exit(1) } @@ -369,8 +229,8 @@ private object YarnClusterDriver extends Logging with Matchers { assert(listener.driverLogs.nonEmpty) val driverLogs = listener.driverLogs.get assert(driverLogs.size === 2) - assert(driverLogs.containsKey("stderr")) - assert(driverLogs.containsKey("stdout")) + assert(driverLogs.contains("stderr")) + assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") // Ensure that this is a valid URL, else this will throw an exception new URL(urlStr) @@ -382,17 +242,29 @@ private object YarnClusterDriver extends Logging with Matchers { } -private object YarnClasspathTest { +private object YarnClasspathTest extends Logging { + + var exitCode = 0 + + def error(m: String, ex: Throwable = null): Unit = { + logError(m, ex) + // scalastyle:off println + System.out.println(m) + if (ex != null) { + ex.printStackTrace(System.out) + } + // scalastyle:on println + } def main(args: Array[String]): Unit = { if (args.length != 2) { - System.err.println( + error( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClasspathTest [driver result file] [executor result file] """.stripMargin) - System.exit(1) + // scalastyle:on println } readResource(args(0)) @@ -402,6 
+274,7 @@ private object YarnClasspathTest { } finally { sc.stop() } + System.exit(exitCode) } private def readResource(resultPath: String): Unit = { @@ -411,6 +284,11 @@ private object YarnClasspathTest { val resource = ccl.getResourceAsStream("test.resource") val bytes = ByteStreams.toByteArray(resource) result = new String(bytes, 0, bytes.length, UTF_8) + } catch { + case t: Throwable => + error(s"loading test.resource to $resultPath", t) + // set the exit code if not yet set + exitCode = 2 } finally { Files.write(result, new File(resultPath), UTF_8) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala new file mode 100644 index 000000000000..8d9c9b3004ed --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -0,0 +1,109 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.scalatest.Matchers + +import org.apache.spark._ +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} + +/** + * Integration test for the external shuffle service with a yarn mini-cluster + */ +class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { + + override def newYarnConfig(): YarnConfiguration = { + val yarnConfig = new YarnConfiguration() + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig + } + + test("external shuffle service") { + val shuffleServicePort = YarnTestAccessor.getShuffleServicePort + val shuffleService = YarnTestAccessor.getShuffleServiceInstance + + val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService) + + logInfo("Shuffle service port = " + shuffleServicePort) + val result = File.createTempFile("result", null, tempDir) + runSpark( + false, + mainClassName(YarnExternalShuffleDriver.getClass), + appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), + extraConf = Map( + "spark.shuffle.service.enabled" -> "true", + "spark.shuffle.service.port" -> shuffleServicePort.toString + ) + ) + checkResult(result) + assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) + } +} + +private object YarnExternalShuffleDriver extends Logging with Matchers { + + val WAIT_TIMEOUT_MILLIS = 
10000 + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: ExternalShuffleDriver [result file] [registed exec file] + """.stripMargin) + // scalastyle:on println + System.exit(1) + } + + val sc = new SparkContext(new SparkConf() + .setAppName("External Shuffle Test")) + val conf = sc.getConf + val status = new File(args(0)) + val registeredExecFile = new File(args(1)) + logInfo("shuffle service executor file = " + registeredExecFile) + var result = "failure" + val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup") + try { + val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }. + collect().toSet + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet) + result = "success" + // only one process can open a leveldb file at a time, so we copy the files + FileUtils.copyDirectory(registeredExecFile, execStateCopy) + assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty) + } finally { + sc.stop() + FileUtils.deleteDirectory(execStateCopy) + Files.write(result, status, UTF_8) + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala new file mode 100644 index 000000000000..da9e8e21a26a --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.util.{List => JList, Map => JMap} + +/** + * Exposes AbstractCommandBuilder to the YARN tests, so that they can build classpaths the same + * way other cluster managers do. + */ +private[spark] class TestClasspathBuilder extends AbstractCommandBuilder { + + childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sys.props("spark.test.home")) + + override def buildClassPath(extraCp: String): JList[String] = super.buildClassPath(extraCp) + + /** Not used by the YARN tests. */ + override def buildCommand(env: JMap[String, String]): JList[String] = + throw new UnsupportedOperationException() + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala new file mode 100644 index 000000000000..aa46ec5100f0 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle + +import java.io.{IOException, File} +import java.util.concurrent.ConcurrentMap + +import com.google.common.annotations.VisibleForTesting +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.fusesource.leveldbjni.JniDBFactory +import org.iq80.leveldb.{DB, Options} + +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +/** + * just a cheat to get package-visible members in tests + */ +object ShuffleTestAccessor { + + def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = { + handler.blockManager + } + + def getExecutorInfo( + appId: ApplicationId, + execId: String, + resolver: ExternalShuffleBlockResolver + ): Option[ExecutorShuffleInfo] = { + val id = new AppExecId(appId.toString, execId) + Option(resolver.executors.get(id)) + } + + def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = { + resolver.registeredExecutorFile + } + + def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = { + resolver.db + } + + def reloadRegisteredExecutors( + file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + val options: Options = new Options + options.createIfMissing(true) + val factory = new JniDBFactory + val db = factory.open(file, options) + val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + db.close() + result + } + + def reloadRegisteredExecutors( + db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala new file mode 100644 index 000000000000..6aa8c814cd4f --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.annotation.tailrec + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext} +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration + + override def beforeEach(): Unit = { + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.setInt("spark.shuffle.service.port", 0) + + yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => + val d = new File(dir) + if (d.exists()) { + FileUtils.deleteDirectory(d) + } + FileUtils.forceMkdir(d) + logInfo(s"creating yarn.nodemanager.local-dirs: $d") + } + } + + var s1: YarnShuffleService = null + var s2: YarnShuffleService = null + var s3: YarnShuffleService = null + + override def afterEach(): Unit = { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } + + test("executor state kept across NM restart") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should + be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should + be (Some(shuffleInfo2)) + + if (!execStateFile.exists()) { + @tailrec def findExistingParent(file: File): File = { + if (file == null) file + else if (file.exists()) file + else findExistingParent(file.getParentFile()) + } + val existingParent = findExistingParent(execStateFile) + assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent") + } + assert(execStateFile.exists(), s"$execStateFile did not exist") + + // now we pretend the shuffle service goes down, and comes back up + s1.stop() + s2 = new 
YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped + // during the restart + s2.initializeApplication(app1Data) + s2.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None) + + // Act like the NM restarts one more time + s2.stop() + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + // app1 is still running + s3.initializeApplication(app1Data) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None) + s3.stop() + } + + test("removed applications should not be in registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + + val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + + s1.stopApplication(new ApplicationTerminationContext(app1Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + s1.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty + } + + test("shuffle service should be robust to corrupt registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + + val execStateFile = s1.registeredExecutorFile + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + + // now we pretend the shuffle service goes down, and comes back up. 
But we'll also + // make a corrupt registeredExecutor File + s1.stop() + + execStateFile.listFiles().foreach{_.delete()} + + val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT")) + out.writeInt(42) + out.close() + + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ... + s2.initializeApplication(app1Data) + // however, when we initialize a totally new app2, everything is still happy + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s2.initializeApplication(app2Data) + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2)) + s2.stop() + + // another stop & restart should be fine though (eg., we recover from previous corruption) + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + s3.initializeApplication(app2Data) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2)) + s3.stop() + + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala rename to yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala index 12c2eed0d6b7..db322cd18e15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala @@ -14,22 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.network.yarn -package org.apache.spark.sql.execution.expressions - -import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{LeafExpression, InternalRow} -import org.apache.spark.sql.types.{IntegerType, DataType} - +import java.io.File /** - * Expression that returns the current partition id of the Spark task. + * just a cheat to get package-visible members in tests */ -private[sql] case object SparkPartitionID extends LeafExpression { +object YarnTestAccessor { + def getShuffleServicePort: Int = { + YarnShuffleService.boundPort + } - override def nullable: Boolean = false + def getShuffleServiceInstance: YarnShuffleService = { + YarnShuffleService.instance + } - override def dataType: DataType = IntegerType + def getRegisteredExecutorFile(service: YarnShuffleService): File = { + service.registeredExecutorFile + } - override def eval(input: InternalRow): Int = TaskContext.get().partitionId() }