From 3d567a357c40836dc0e8da67308719a48cc51193 Mon Sep 17 00:00:00 2001 From: wuyi Date: Mon, 21 Oct 2019 21:10:21 -0500 Subject: [PATCH 01/58] [MINOR][SQL] Avoid unnecessary invocation on checkAndGlobPathIfNecessary ### What changes were proposed in this pull request? Only invoke `checkAndGlobPathIfNecessary()` when we have to use `InMemoryFileIndex`. ### Why are the changes needed? Avoid unnecessary function invocation. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Pass Jenkins. Closes #26196 from Ngone51/dev-avoid-unnecessary-invocation-on-globpath. Authored-by: wuyi Signed-off-by: Sean Owen --- .../apache/spark/sql/execution/datasources/DataSource.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 0f5f1591623a..e9b8fae7cd73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -378,8 +378,6 @@ case class DataSource( // This is a non-streaming file based datasource. case (format: FileFormat, _) => - val globbedPaths = - checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog && catalogTable.get.partitionColumnNames.nonEmpty @@ -391,6 +389,8 @@ case class DataSource( catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema) } else { + val globbedPaths = checkAndGlobPathIfNecessary( + checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) val index = createInMemoryFileIndex(globbedPaths) val (resultDataSchema, resultPartitionSchema) = getOrInferFileFormatSchema(format, () => index) From 484f93e25506f84d1548504783be9ce940149bb7 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 22 Oct 2019 10:38:06 +0800 Subject: [PATCH 02/58] [SPARK-29530][SQL] Make SQLConf in SQL parse process thread safe ### What changes were proposed in this pull request? As I have comment in [SPARK-29516](https://github.com/apache/spark/pull/26172#issuecomment-544364977) SparkSession.sql() method parse process not under current sparksession's conf, so some configuration about parser is not valid in multi-thread situation. In this pr, we add a SQLConf parameter to AbstractSqlParser and initial it with SessionState's conf. Then for each SparkSession's parser process. It will use's it's own SessionState's SQLConf and to be thread safe ### Why are the changes needed? Fix bug ### Does this PR introduce any user-facing change? NO ### How was this patch tested? NO Closes #26187 from AngersZhuuuu/SPARK-29530. 
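The sketch below is an editorial illustration of the design change described above, written in plain Scala with invented names (it is not Spark's actual parser or `SQLConf` API): the session's conf is injected through the constructor instead of being read from a process-wide `SQLConf.get`, so two sessions parsing on different threads each see their own settings.

```scala
// Illustrative only: constructor-injected conf vs. a global lookup.
class Conf(val ansiEnabled: Boolean)

abstract class Parser(conf: Conf) {
  // Every parse call reads the conf captured at construction time,
  // so concurrent sessions cannot observe each other's settings.
  def parse(sql: String): String =
    if (conf.ansiEnabled) s"ANSI:$sql" else s"LEGACY:$sql"
}

class SessionParser(conf: Conf) extends Parser(conf)

object Demo extends App {
  val ansiSession   = new SessionParser(new Conf(ansiEnabled = true))
  val legacySession = new SessionParser(new Conf(ansiEnabled = false))
  println(ansiSession.parse("SELECT 1"))   // ANSI:SELECT 1
  println(legacySession.parse("SELECT 1")) // LEGACY:SELECT 1
}
```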
Authored-by: angerszhu Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/ParseDriver.scala | 14 +++++++------- .../spark/sql/execution/SparkSqlParser.scala | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 85998e33140d..a84d29b71ac4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DataType, StructType} /** * Base SQL parsing infrastructure. */ -abstract class AbstractSqlParser extends ParserInterface with Logging { +abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging { /** Creates/Resolves DataType for a given SQL string. */ override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => @@ -91,16 +91,16 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) - lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced - lexer.ansi = SQLConf.get.ansiEnabled + lexer.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced + lexer.ansi = conf.ansiEnabled val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) parser.addParseListener(PostProcessor) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) - parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced - parser.ansi = SQLConf.get.ansiEnabled + parser.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced + parser.ansi = conf.ansiEnabled try { try { @@ -134,12 +134,12 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ -class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { +class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) { val astBuilder = new AstBuilder(conf) } /** For test-only. */ -object CatalystSqlParser extends AbstractSqlParser { +object CatalystSqlParser extends AbstractSqlParser(SQLConf.get) { val astBuilder = new AstBuilder(SQLConf.get) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3e7a54877cae..cdee11781324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. */ -class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { +class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) { val astBuilder = new SparkSqlAstBuilder(conf) private val substitutor = new VariableSubstitution(conf) From 467c3f610f4f83f16405a3110870d57e59059435 Mon Sep 17 00:00:00 2001 From: denglingang Date: Tue, 22 Oct 2019 14:49:23 +0900 Subject: [PATCH 03/58] [SPARK-29529][DOCS] Remove unnecessary orc version and hive version in doc ### What changes were proposed in this pull request? This PR remove unnecessary orc version and hive version in doc. 
### Does this PR introduce any user-facing change? No. ### How was this patch tested? N/A. Closes #26146 from denglingang/SPARK-24576. Lead-authored-by: denglingang Co-authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- docs/sql-data-sources-orc.md | 2 +- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-data-sources-orc.md b/docs/sql-data-sources-orc.md index 45bff17c6cf2..bddffe02602e 100644 --- a/docs/sql-data-sources-orc.md +++ b/docs/sql-data-sources-orc.md @@ -31,7 +31,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also spark.sql.orc.impl native - The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. + The name of ORC implementation. It can be one of native and hive. native means the native ORC support. hive means the ORC library in Hive. spark.sql.orc.enableVectorizedReader diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4944099fcc0d..75db52e334b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -626,8 +626,8 @@ object SQLConf { .createWithDefault("snappy") val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") - .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default prior to Spark 2.4.") + .doc("When native, use the native version of ORC support instead of the ORC library in Hive." + + "It is 'hive' by default prior to Spark 2.4.") .internal() .stringConf .checkValues(Set("hive", "native")) From 811d563fbf60203377e8462e4fad271c1140b4fa Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 22 Oct 2019 16:18:34 +0900 Subject: [PATCH 04/58] [SPARK-29536][PYTHON] Upgrade cloudpickle to 1.1.1 to support Python 3.8 ### What changes were proposed in this pull request? Inline cloudpickle in PySpark to cloudpickle 1.1.1. See https://github.com/cloudpipe/cloudpickle/blob/v1.1.1/cloudpickle/cloudpickle.py https://github.com/cloudpipe/cloudpickle/pull/269 was added for Python 3.8 support (fixed from 1.1.0). Using 1.2.2 seems breaking PyPy 2 due to cloudpipe/cloudpickle#278 so this PR currently uses 1.1.1. Once we drop Python 2, we can switch to the highest version. ### Why are the changes needed? positional-only arguments was newly introduced from Python 3.8 (see https://docs.python.org/3/whatsnew/3.8.html#positional-only-parameters) Particularly the newly added argument to `types.CodeType` was the problem (https://docs.python.org/3/whatsnew/3.8.html#changes-in-the-python-api): > `types.CodeType` has a new parameter in the second position of the constructor (posonlyargcount) to support positional-only arguments defined in **PEP 570**. The first argument (argcount) now represents the total number of positional arguments (including positional-only arguments). The new `replace()` method of `types.CodeType` can be used to make the code future-proof. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manually tested. Note that the optional dependency PyArrow looks not yet supporting Python 3.8; therefore, it was not tested. See "Details" below.

```bash cd python ./run-tests --python-executables=python3.8 ``` ``` Running PySpark tests. Output is in /Users/hyukjin.kwon/workspace/forked/spark/python/unit-tests.log Will test against the following Python executables: ['python3.8'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Starting test(python3.8): pyspark.ml.tests.test_algorithms Starting test(python3.8): pyspark.ml.tests.test_feature Starting test(python3.8): pyspark.ml.tests.test_base Starting test(python3.8): pyspark.ml.tests.test_evaluation Finished test(python3.8): pyspark.ml.tests.test_base (12s) Starting test(python3.8): pyspark.ml.tests.test_image Finished test(python3.8): pyspark.ml.tests.test_evaluation (14s) Starting test(python3.8): pyspark.ml.tests.test_linalg Finished test(python3.8): pyspark.ml.tests.test_feature (23s) Starting test(python3.8): pyspark.ml.tests.test_param Finished test(python3.8): pyspark.ml.tests.test_image (22s) Starting test(python3.8): pyspark.ml.tests.test_persistence Finished test(python3.8): pyspark.ml.tests.test_param (25s) Starting test(python3.8): pyspark.ml.tests.test_pipeline Finished test(python3.8): pyspark.ml.tests.test_linalg (37s) Starting test(python3.8): pyspark.ml.tests.test_stat Finished test(python3.8): pyspark.ml.tests.test_pipeline (7s) Starting test(python3.8): pyspark.ml.tests.test_training_summary Finished test(python3.8): pyspark.ml.tests.test_stat (21s) Starting test(python3.8): pyspark.ml.tests.test_tuning Finished test(python3.8): pyspark.ml.tests.test_persistence (45s) Starting test(python3.8): pyspark.ml.tests.test_wrapper Finished test(python3.8): pyspark.ml.tests.test_algorithms (83s) Starting test(python3.8): pyspark.mllib.tests.test_algorithms Finished test(python3.8): pyspark.ml.tests.test_training_summary (32s) Starting test(python3.8): pyspark.mllib.tests.test_feature Finished test(python3.8): pyspark.ml.tests.test_wrapper (20s) Starting test(python3.8): pyspark.mllib.tests.test_linalg Finished test(python3.8): pyspark.mllib.tests.test_feature (32s) Starting test(python3.8): pyspark.mllib.tests.test_stat Finished test(python3.8): pyspark.mllib.tests.test_algorithms (70s) Starting test(python3.8): pyspark.mllib.tests.test_streaming_algorithms Finished test(python3.8): pyspark.mllib.tests.test_stat (37s) Starting test(python3.8): pyspark.mllib.tests.test_util Finished test(python3.8): pyspark.mllib.tests.test_linalg (70s) Starting test(python3.8): pyspark.sql.tests.test_arrow Finished test(python3.8): pyspark.sql.tests.test_arrow (1s) ... 
53 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_catalog Finished test(python3.8): pyspark.mllib.tests.test_util (15s) Starting test(python3.8): pyspark.sql.tests.test_column Finished test(python3.8): pyspark.sql.tests.test_catalog (24s) Starting test(python3.8): pyspark.sql.tests.test_conf Finished test(python3.8): pyspark.sql.tests.test_column (21s) Starting test(python3.8): pyspark.sql.tests.test_context Finished test(python3.8): pyspark.ml.tests.test_tuning (125s) Starting test(python3.8): pyspark.sql.tests.test_dataframe Finished test(python3.8): pyspark.sql.tests.test_conf (9s) Starting test(python3.8): pyspark.sql.tests.test_datasources Finished test(python3.8): pyspark.sql.tests.test_context (29s) Starting test(python3.8): pyspark.sql.tests.test_functions Finished test(python3.8): pyspark.sql.tests.test_datasources (32s) Starting test(python3.8): pyspark.sql.tests.test_group Finished test(python3.8): pyspark.sql.tests.test_dataframe (39s) ... 3 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_pandas_udf Finished test(python3.8): pyspark.sql.tests.test_pandas_udf (1s) ... 6 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_pandas_udf_cogrouped_map Finished test(python3.8): pyspark.sql.tests.test_pandas_udf_cogrouped_map (0s) ... 14 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_pandas_udf_grouped_agg Finished test(python3.8): pyspark.sql.tests.test_pandas_udf_grouped_agg (1s) ... 15 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_pandas_udf_grouped_map Finished test(python3.8): pyspark.sql.tests.test_pandas_udf_grouped_map (1s) ... 20 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_pandas_udf_scalar Finished test(python3.8): pyspark.sql.tests.test_pandas_udf_scalar (1s) ... 49 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_pandas_udf_window Finished test(python3.8): pyspark.sql.tests.test_pandas_udf_window (1s) ... 14 tests were skipped Starting test(python3.8): pyspark.sql.tests.test_readwriter Finished test(python3.8): pyspark.sql.tests.test_functions (29s) Starting test(python3.8): pyspark.sql.tests.test_serde Finished test(python3.8): pyspark.sql.tests.test_group (20s) Starting test(python3.8): pyspark.sql.tests.test_session Finished test(python3.8): pyspark.mllib.tests.test_streaming_algorithms (126s) Starting test(python3.8): pyspark.sql.tests.test_streaming Finished test(python3.8): pyspark.sql.tests.test_serde (25s) Starting test(python3.8): pyspark.sql.tests.test_types Finished test(python3.8): pyspark.sql.tests.test_readwriter (38s) Starting test(python3.8): pyspark.sql.tests.test_udf Finished test(python3.8): pyspark.sql.tests.test_session (32s) Starting test(python3.8): pyspark.sql.tests.test_utils Finished test(python3.8): pyspark.sql.tests.test_utils (17s) Starting test(python3.8): pyspark.streaming.tests.test_context Finished test(python3.8): pyspark.sql.tests.test_types (45s) Starting test(python3.8): pyspark.streaming.tests.test_dstream Finished test(python3.8): pyspark.sql.tests.test_udf (44s) Starting test(python3.8): pyspark.streaming.tests.test_kinesis Finished test(python3.8): pyspark.streaming.tests.test_kinesis (0s) ... 
2 tests were skipped Starting test(python3.8): pyspark.streaming.tests.test_listener Finished test(python3.8): pyspark.streaming.tests.test_context (28s) Starting test(python3.8): pyspark.tests.test_appsubmit Finished test(python3.8): pyspark.sql.tests.test_streaming (60s) Starting test(python3.8): pyspark.tests.test_broadcast Finished test(python3.8): pyspark.streaming.tests.test_listener (11s) Starting test(python3.8): pyspark.tests.test_conf Finished test(python3.8): pyspark.tests.test_conf (17s) Starting test(python3.8): pyspark.tests.test_context Finished test(python3.8): pyspark.tests.test_broadcast (39s) Starting test(python3.8): pyspark.tests.test_daemon Finished test(python3.8): pyspark.tests.test_daemon (5s) Starting test(python3.8): pyspark.tests.test_join Finished test(python3.8): pyspark.tests.test_context (31s) Starting test(python3.8): pyspark.tests.test_profiler Finished test(python3.8): pyspark.tests.test_join (9s) Starting test(python3.8): pyspark.tests.test_rdd Finished test(python3.8): pyspark.tests.test_profiler (12s) Starting test(python3.8): pyspark.tests.test_readwrite Finished test(python3.8): pyspark.tests.test_readwrite (23s) ... 3 tests were skipped Starting test(python3.8): pyspark.tests.test_serializers Finished test(python3.8): pyspark.tests.test_appsubmit (94s) Starting test(python3.8): pyspark.tests.test_shuffle Finished test(python3.8): pyspark.streaming.tests.test_dstream (110s) Starting test(python3.8): pyspark.tests.test_taskcontext Finished test(python3.8): pyspark.tests.test_rdd (42s) Starting test(python3.8): pyspark.tests.test_util Finished test(python3.8): pyspark.tests.test_serializers (11s) Starting test(python3.8): pyspark.tests.test_worker Finished test(python3.8): pyspark.tests.test_shuffle (12s) Starting test(python3.8): pyspark.accumulators Finished test(python3.8): pyspark.tests.test_util (7s) Starting test(python3.8): pyspark.broadcast Finished test(python3.8): pyspark.accumulators (8s) Starting test(python3.8): pyspark.conf Finished test(python3.8): pyspark.broadcast (8s) Starting test(python3.8): pyspark.context Finished test(python3.8): pyspark.tests.test_worker (19s) Starting test(python3.8): pyspark.ml.classification Finished test(python3.8): pyspark.conf (4s) Starting test(python3.8): pyspark.ml.clustering Finished test(python3.8): pyspark.context (22s) Starting test(python3.8): pyspark.ml.evaluation Finished test(python3.8): pyspark.tests.test_taskcontext (49s) Starting test(python3.8): pyspark.ml.feature Finished test(python3.8): pyspark.ml.clustering (43s) Starting test(python3.8): pyspark.ml.fpm Finished test(python3.8): pyspark.ml.evaluation (27s) Starting test(python3.8): pyspark.ml.image Finished test(python3.8): pyspark.ml.image (8s) Starting test(python3.8): pyspark.ml.linalg.__init__ Finished test(python3.8): pyspark.ml.linalg.__init__ (0s) Starting test(python3.8): pyspark.ml.recommendation Finished test(python3.8): pyspark.ml.classification (63s) Starting test(python3.8): pyspark.ml.regression Finished test(python3.8): pyspark.ml.fpm (23s) Starting test(python3.8): pyspark.ml.stat Finished test(python3.8): pyspark.ml.stat (30s) Starting test(python3.8): pyspark.ml.tuning Finished test(python3.8): pyspark.ml.regression (51s) Starting test(python3.8): pyspark.mllib.classification Finished test(python3.8): pyspark.ml.feature (93s) Starting test(python3.8): pyspark.mllib.clustering Finished test(python3.8): pyspark.ml.tuning (39s) Starting test(python3.8): pyspark.mllib.evaluation Finished test(python3.8): 
pyspark.mllib.classification (38s) Starting test(python3.8): pyspark.mllib.feature Finished test(python3.8): pyspark.mllib.evaluation (25s) Starting test(python3.8): pyspark.mllib.fpm Finished test(python3.8): pyspark.mllib.clustering (64s) Starting test(python3.8): pyspark.mllib.linalg.__init__ Finished test(python3.8): pyspark.ml.recommendation (131s) Starting test(python3.8): pyspark.mllib.linalg.distributed Finished test(python3.8): pyspark.mllib.linalg.__init__ (0s) Starting test(python3.8): pyspark.mllib.random Finished test(python3.8): pyspark.mllib.feature (36s) Starting test(python3.8): pyspark.mllib.recommendation Finished test(python3.8): pyspark.mllib.fpm (31s) Starting test(python3.8): pyspark.mllib.regression Finished test(python3.8): pyspark.mllib.random (16s) Starting test(python3.8): pyspark.mllib.stat.KernelDensity Finished test(python3.8): pyspark.mllib.stat.KernelDensity (1s) Starting test(python3.8): pyspark.mllib.stat._statistics Finished test(python3.8): pyspark.mllib.stat._statistics (25s) Starting test(python3.8): pyspark.mllib.tree Finished test(python3.8): pyspark.mllib.regression (44s) Starting test(python3.8): pyspark.mllib.util Finished test(python3.8): pyspark.mllib.recommendation (49s) Starting test(python3.8): pyspark.profiler Finished test(python3.8): pyspark.mllib.linalg.distributed (53s) Starting test(python3.8): pyspark.rdd Finished test(python3.8): pyspark.profiler (14s) Starting test(python3.8): pyspark.serializers Finished test(python3.8): pyspark.mllib.tree (30s) Starting test(python3.8): pyspark.shuffle Finished test(python3.8): pyspark.shuffle (2s) Starting test(python3.8): pyspark.sql.avro.functions Finished test(python3.8): pyspark.mllib.util (30s) Starting test(python3.8): pyspark.sql.catalog Finished test(python3.8): pyspark.serializers (17s) Starting test(python3.8): pyspark.sql.column Finished test(python3.8): pyspark.rdd (31s) Starting test(python3.8): pyspark.sql.conf Finished test(python3.8): pyspark.sql.conf (7s) Starting test(python3.8): pyspark.sql.context Finished test(python3.8): pyspark.sql.avro.functions (19s) Starting test(python3.8): pyspark.sql.dataframe Finished test(python3.8): pyspark.sql.catalog (16s) Starting test(python3.8): pyspark.sql.functions Finished test(python3.8): pyspark.sql.column (27s) Starting test(python3.8): pyspark.sql.group Finished test(python3.8): pyspark.sql.context (26s) Starting test(python3.8): pyspark.sql.readwriter Finished test(python3.8): pyspark.sql.group (52s) Starting test(python3.8): pyspark.sql.session Finished test(python3.8): pyspark.sql.dataframe (73s) Starting test(python3.8): pyspark.sql.streaming Finished test(python3.8): pyspark.sql.functions (75s) Starting test(python3.8): pyspark.sql.types Finished test(python3.8): pyspark.sql.readwriter (57s) Starting test(python3.8): pyspark.sql.udf Finished test(python3.8): pyspark.sql.types (13s) Starting test(python3.8): pyspark.sql.window Finished test(python3.8): pyspark.sql.session (32s) Starting test(python3.8): pyspark.streaming.util Finished test(python3.8): pyspark.streaming.util (1s) Starting test(python3.8): pyspark.util Finished test(python3.8): pyspark.util (0s) Finished test(python3.8): pyspark.sql.streaming (30s) Finished test(python3.8): pyspark.sql.udf (27s) Finished test(python3.8): pyspark.sql.window (22s) Tests passed in 855 seconds ```

Closes #26194 from HyukjinKwon/SPARK-29536. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- python/pyspark/cloudpickle.py | 257 +++++++++++++++++++++++++++++----- python/setup.py | 1 + 2 files changed, 221 insertions(+), 37 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 7df5f6c748ad..09d3a5e7cfb6 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -44,7 +44,6 @@ import dis from functools import partial -import importlib import io import itertools import logging @@ -56,12 +55,26 @@ import traceback import types import weakref +import uuid +import threading + + +try: + from enum import Enum +except ImportError: + Enum = None # cloudpickle is meant for inter process communication: we expect all # communicating processes to run the same Python version hence we favor # communication speed over compatibility: DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL +# Track the provenance of reconstructed dynamic classes to make it possible to +# recontruct instances from the matching singleton class definition when +# appropriate and preserve the usual "isinstance" semantics of Python objects. +_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() +_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary() +_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock() if sys.version_info[0] < 3: # pragma: no branch from pickle import Pickler @@ -71,12 +84,37 @@ from StringIO import StringIO string_types = (basestring,) # noqa PY3 = False + PY2 = True + PY2_WRAPPER_DESCRIPTOR_TYPE = type(object.__init__) + PY2_METHOD_WRAPPER_TYPE = type(object.__eq__) + PY2_CLASS_DICT_BLACKLIST = (PY2_METHOD_WRAPPER_TYPE, + PY2_WRAPPER_DESCRIPTOR_TYPE) else: types.ClassType = type from pickle import _Pickler as Pickler from io import BytesIO as StringIO string_types = (str,) PY3 = True + PY2 = False + + +def _ensure_tracking(class_def): + with _DYNAMIC_CLASS_TRACKER_LOCK: + class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def) + if class_tracker_id is None: + class_tracker_id = uuid.uuid4().hex + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id + _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def + return class_tracker_id + + +def _lookup_class_or_track(class_tracker_id, class_def): + if class_tracker_id is not None: + with _DYNAMIC_CLASS_TRACKER_LOCK: + class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault( + class_tracker_id, class_def) + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id + return class_def def _make_cell_set_template_code(): @@ -112,7 +150,7 @@ def inner(value): # NOTE: we are marking the cell variable as a free variable intentionally # so that we simulate an inner function instead of the outer function. This # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way. 
- if not PY3: # pragma: no branch + if PY2: # pragma: no branch return types.CodeType( co.co_argcount, co.co_nlocals, @@ -130,24 +168,43 @@ def inner(value): (), ) else: - return types.CodeType( - co.co_argcount, - co.co_kwonlyargcount, - co.co_nlocals, - co.co_stacksize, - co.co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_cellvars, # this is the trickery - (), - ) - + if hasattr(types.CodeType, "co_posonlyargcount"): # pragma: no branch + return types.CodeType( + co.co_argcount, + co.co_posonlyargcount, # Python3.8 with PEP570 + co.co_kwonlyargcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) + else: + return types.CodeType( + co.co_argcount, + co.co_kwonlyargcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) _cell_set_template_code = _make_cell_set_template_code() @@ -220,7 +277,7 @@ def _walk_global_ops(code): global-referencing instructions in *code*. """ code = getattr(code, 'co_code', b'') - if not PY3: # pragma: no branch + if PY2: # pragma: no branch code = map(ord, code) n = len(code) @@ -250,6 +307,39 @@ def _walk_global_ops(code): yield op, instr.arg +def _extract_class_dict(cls): + """Retrieve a copy of the dict of a class without the inherited methods""" + clsdict = dict(cls.__dict__) # copy dict proxy to a dict + if len(cls.__bases__) == 1: + inherited_dict = cls.__bases__[0].__dict__ + else: + inherited_dict = {} + for base in reversed(cls.__bases__): + inherited_dict.update(base.__dict__) + to_remove = [] + for name, value in clsdict.items(): + try: + base_value = inherited_dict[name] + if value is base_value: + to_remove.append(name) + elif PY2: + # backward compat for Python 2 + if hasattr(value, "im_func"): + if value.im_func is getattr(base_value, "im_func", None): + to_remove.append(name) + elif isinstance(value, PY2_CLASS_DICT_BLACKLIST): + # On Python 2 we have no way to pickle those specific + # methods types nor to check that they are actually + # inherited. So we assume that they are always inherited + # from builtin types. 
+ to_remove.append(name) + except KeyError: + pass + for name in to_remove: + clsdict.pop(name) + return clsdict + + class CloudPickler(Pickler): dispatch = Pickler.dispatch.copy() @@ -277,7 +367,7 @@ def save_memoryview(self, obj): dispatch[memoryview] = save_memoryview - if not PY3: # pragma: no branch + if PY2: # pragma: no branch def save_buffer(self, obj): self.save(str(obj)) @@ -300,12 +390,23 @@ def save_codeobject(self, obj): Save a code object """ if PY3: # pragma: no branch - args = ( - obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, - obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, - obj.co_cellvars - ) + if hasattr(obj, "co_posonlyargcount"): # pragma: no branch + args = ( + obj.co_argcount, obj.co_posonlyargcount, + obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, + obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, + obj.co_cellvars + ) + else: + args = ( + obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, + obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts, + obj.co_names, obj.co_varnames, obj.co_filename, + obj.co_name, obj.co_firstlineno, obj.co_lnotab, + obj.co_freevars, obj.co_cellvars + ) else: args = ( obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, @@ -460,15 +561,40 @@ def func(): # then discards the reference to it self.write(pickle.POP) - def save_dynamic_class(self, obj): + def _save_dynamic_enum(self, obj, clsdict): + """Special handling for dynamic Enum subclasses + + Use a dedicated Enum constructor (inspired by EnumMeta.__call__) as the + EnumMeta metaclass has complex initialization that makes the Enum + subclasses hold references to their own instances. """ - Save a class that can't be stored as module global. + members = dict((e.name, e.value) for e in obj) + + # Python 2.7 with enum34 can have no qualname: + qualname = getattr(obj, "__qualname__", None) + + self.save_reduce(_make_skeleton_enum, + (obj.__bases__, obj.__name__, qualname, members, + obj.__module__, _ensure_tracking(obj), None), + obj=obj) + + # Cleanup the clsdict that will be passed to _rehydrate_skeleton_class: + # Those attributes are already handled by the metaclass. + for attrname in ["_generate_next_value_", "_member_names_", + "_member_map_", "_member_type_", + "_value2member_map_"]: + clsdict.pop(attrname, None) + for member in members: + clsdict.pop(member) + + def save_dynamic_class(self, obj): + """Save a class that can't be stored as module global. This method is used to serialize classes that are defined inside functions, or that otherwise can't be serialized as attribute lookups from global modules. """ - clsdict = dict(obj.__dict__) # copy dict proxy to a dict + clsdict = _extract_class_dict(obj) clsdict.pop('__weakref__', None) # For ABCMeta in python3.7+, remove _abc_impl as it is not picklable. @@ -496,8 +622,8 @@ def save_dynamic_class(self, obj): for k in obj.__slots__: clsdict.pop(k, None) - # If type overrides __dict__ as a property, include it in the type kwargs. - # In Python 2, we can't set this attribute after construction. + # If type overrides __dict__ as a property, include it in the type + # kwargs. In Python 2, we can't set this attribute after construction. 
__dict__ = clsdict.pop('__dict__', None) if isinstance(__dict__, property): type_kwargs['__dict__'] = __dict__ @@ -524,8 +650,16 @@ def save_dynamic_class(self, obj): write(pickle.MARK) # Create and memoize an skeleton class with obj's name and bases. - tp = type(obj) - self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) + if Enum is not None and issubclass(obj, Enum): + # Special handling of Enum subclasses + self._save_dynamic_enum(obj, clsdict) + else: + # "Regular" class definition: + tp = type(obj) + self.save_reduce(_make_skeleton_class, + (tp, obj.__name__, obj.__bases__, type_kwargs, + _ensure_tracking(obj), None), + obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -778,7 +912,7 @@ def save_inst(self, obj): save(stuff) write(pickle.BUILD) - if not PY3: # pragma: no branch + if PY2: # pragma: no branch dispatch[types.InstanceType] = save_inst def save_property(self, obj): @@ -1119,6 +1253,22 @@ def _make_skel_func(code, cell_count, base_globals=None): return types.FunctionType(code, base_globals, None, None, closure) +def _make_skeleton_class(type_constructor, name, bases, type_kwargs, + class_tracker_id, extra): + """Build dynamic class with an empty __dict__ to be filled once memoized + + If class_tracker_id is not None, try to lookup an existing class definition + matching that id. If none is found, track a newly reconstructed class + definition under that id so that other instances stemming from the same + class id will also reuse this class definition. + + The "extra" variable is meant to be a dict (or None) that can be used for + forward compatibility shall the need arise. + """ + skeleton_class = type_constructor(name, bases, type_kwargs) + return _lookup_class_or_track(class_tracker_id, skeleton_class) + + def _rehydrate_skeleton_class(skeleton_class, class_dict): """Put attributes from `class_dict` back on `skeleton_class`. @@ -1137,6 +1287,39 @@ def _rehydrate_skeleton_class(skeleton_class, class_dict): return skeleton_class +def _make_skeleton_enum(bases, name, qualname, members, module, + class_tracker_id, extra): + """Build dynamic enum with an empty __dict__ to be filled once memoized + + The creation of the enum class is inspired by the code of + EnumMeta._create_. + + If class_tracker_id is not None, try to lookup an existing enum definition + matching that id. If none is found, track a newly reconstructed enum + definition under that id so that other instances stemming from the same + class id will also reuse this enum definition. + + The "extra" variable is meant to be a dict (or None) that can be used for + forward compatibility shall the need arise. 
+ """ + # enums always inherit from their base Enum class at the last position in + # the list of base classes: + enum_base = bases[-1] + metacls = enum_base.__class__ + classdict = metacls.__prepare__(name, bases) + + for member_name, member_value in members.items(): + classdict[member_name] = member_value + enum_class = metacls.__new__(metacls, name, bases, classdict) + enum_class.__module__ = module + + # Python 2.7 compat + if qualname is not None: + enum_class.__qualname__ = qualname + + return _lookup_class_or_track(class_tracker_id, enum_class) + + def _is_dynamic(module): """ Return True if the module is special module that cannot be imported by its @@ -1176,4 +1359,4 @@ def _reduce_method_descriptor(obj): import copy_reg as copyreg except ImportError: import copyreg - copyreg.pickle(method_descriptor, _reduce_method_descriptor) + copyreg.pickle(method_descriptor, _reduce_method_descriptor) \ No newline at end of file diff --git a/python/setup.py b/python/setup.py index ee5c32683efa..ea672309703b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -230,6 +230,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) From 868d851dac6016a2fc5665fb2a3ea01ab184402a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 22 Oct 2019 17:49:44 +0800 Subject: [PATCH 05/58] [SPARK-29232][ML] Update the parameter maps of the DecisionTreeRegression/Classification Models ### What changes were proposed in this pull request? The trees (Array[```DecisionTreeRegressionModel```]) in ```RandomForestRegressionModel``` only contains the default parameter value. Need to update the parameter maps for these trees. Same issues in ```RandomForestClassifier```, ```GBTClassifier``` and ```GBTRegressor``` ### Why are the changes needed? User wants to access each individual tree and build the trees back up for the random forest estimator. This doesn't work because trees don't have the correct parameter values ### Does this PR introduce any user-facing change? Yes. Now the trees in ```RandomForestRegressionModel```, ```RandomForestClassifier```, ```GBTClassifier``` and ```GBTRegressor``` have the correct parameter values. ### How was this patch tested? Add tests Closes #26154 from huaxingao/spark-29232. 
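As an editorial illustration of the fix described above, here is a minimal sketch in plain Scala with invented names (not Spark's real `ml.param` API): after training, the parent estimator's parameter values are pushed down onto every per-tree sub-model, which would otherwise keep only the defaults — the role played by the added `trees.foreach(copyValues(_))` calls in this patch.

```scala
// Illustrative only: propagate the parent estimator's params to sub-models.
case class TreeParams(maxDepth: Int = 5, seed: Long = 0L)

class TreeModel(var params: TreeParams)

class ForestEstimator(params: TreeParams) {
  def fit(numTrees: Int): Seq[TreeModel] = {
    // The training routine builds sub-models with default parameter values...
    val trees = Seq.fill(numTrees)(new TreeModel(TreeParams()))
    // ...so copy the estimator's values onto each tree before returning them.
    trees.foreach(t => t.params = params)
    trees
  }
}

object ParamDemo extends App {
  val trees = new ForestEstimator(TreeParams(maxDepth = 3, seed = 123L)).fit(3)
  assert(trees.forall(t => t.params.maxDepth == 3 && t.params.seed == 123L))
}
```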
Authored-by: Huaxin Gao Signed-off-by: zhengruifeng --- .../ml/classification/GBTClassifier.scala | 1 + .../RandomForestClassifier.scala | 1 + .../spark/ml/regression/GBTRegressor.scala | 1 + .../ml/regression/RandomForestRegressor.scala | 1 + .../classification/GBTClassifierSuite.scala | 16 +++++++++++++++ .../RandomForestClassifierSuite.scala | 20 +++++++++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 16 ++++++++++++++- .../RandomForestRegressorSuite.scala | 19 ++++++++++++++++++ 8 files changed, 74 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 09f81b0dcbda..74624be360c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -203,6 +203,7 @@ class GBTClassifier @Since("1.4.0") ( } else { GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + baseLearners.foreach(copyValues(_)) val numFeatures = baseLearners.head.numFeatures instr.logNumFeatures(numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 731b43b67813..245cda35d8ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -143,6 +143,7 @@ class RandomForestClassifier @Since("1.4.0") ( val trees = RandomForest .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeClassificationModel]) + trees.foreach(copyValues(_)) val numFeatures = trees.head.numFeatures instr.logNumClasses(numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 00c0bc9f5e28..0cc06d82bf3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -181,6 +181,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + baseLearners.foreach(copyValues(_)) val numFeatures = baseLearners.head.numFeatures instr.logNumFeatures(numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 938aa5acac08..8f78fc1da18c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -130,6 +130,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S val trees = RandomForest .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeRegressionModel]) + trees.foreach(copyValues(_)) val numFeatures = trees.head.numFeatures instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index af3dd201d3b5..530ca20d0eb0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -456,6 +456,22 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { } } + test("tree params") { + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setCheckpointInterval(5) + .setSeed(123) + val model = gbt.fit(df) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getCheckpointInterval === model.getCheckpointInterval) + assert(i.getSeed === model.getSeed) + }) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index f03ed0b76eb8..5958bfcf5ea6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -230,6 +230,26 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { } } + test("tree params") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("entropy") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getSeed === model.getSeed) + assert(i.getImpurity === model.getImpurity) + }) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 60007975c3b5..e2462af2ac1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -296,7 +296,21 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } - ///////////////////////////////////////////////////////////////////////////// + test("tree params") { + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setCheckpointInterval(5) + .setSeed(123) + val model = gbt.fit(trainData.toDF) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getCheckpointInterval === model.getCheckpointInterval) + assert(i.getSeed === model.getSeed) + }) + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala 
index 0243e8d2335e..f3b0f0470e57 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -139,6 +139,25 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ } } + test("tree params") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(3) + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getSeed === model.getSeed) + assert(i.getImpurity === model.getImpurity) + assert(i.getMaxBins === model.getMaxBins) + }) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// From 3163b6b43b99ca02642cf935d885ed2d0f98d633 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 22 Oct 2019 03:20:49 -0700 Subject: [PATCH 06/58] [SPARK-29516][SQL][TEST] Test ThriftServerQueryTestSuite asynchronously ### What changes were proposed in this pull request? This PR test `ThriftServerQueryTestSuite` in an asynchronous way. ### Why are the changes needed? The default value of `spark.sql.hive.thriftServer.async` is `true`. ### Does this PR introduce any user-facing change? No ### How was this patch tested? ``` build/sbt "hive-thriftserver/test-only *.ThriftServerQueryTestSuite" -Phive-thriftserver build/mvn -Dtest=none -DwildcardSuites=org.apache.spark.sql.hive.thriftserver.ThriftServerQueryTestSuite test -Phive-thriftserver ``` Closes #26172 from wangyum/SPARK-29516. Authored-by: Yuming Wang Signed-off-by: Yuming Wang --- .../thriftserver/ThriftServerQueryTestSuite.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index abe91a280a14..04b1de00ccbf 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File -import java.sql.{DriverManager, Statement, Timestamp} +import java.sql.{DriverManager, SQLException, Statement, Timestamp} import java.util.{Locale, MissingFormatArgumentException} import scala.util.{Random, Try} @@ -75,11 +75,6 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { } } - override def sparkConf: SparkConf = super.sparkConf - // Hive Thrift server should not executes SQL queries in an asynchronous way - // because we may set session configuration. - .set(HiveUtils.HIVE_THRIFT_SERVER_ASYNC, false) - override val isTestWithConfigSets = false /** List of test cases to ignore, in lower cases. */ @@ -208,6 +203,12 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { s"Exception did not match for query #$i\n${expected.sql}, " + s"expected: ${expected.output}, but got: ${output.output}") + // SQLException should not exactly match. We only assert the result contains Exception. 
+ case _ if output.output.startsWith(classOf[SQLException].getName) => + assert(expected.output.contains("Exception"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + case _ => assertResult(expected.output, s"Result did not match for query #$i\n${expected.sql}") { output.output From bb49c80c890452dc047a1975b16dcd876705ad23 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 22 Oct 2019 19:08:09 +0800 Subject: [PATCH 07/58] [SPARK-21492][SQL] Fix memory leak in SortMergeJoin ### What changes were proposed in this pull request? We shall have a new mechanism that the downstream operators may notify its parents that they may release the output data stream. In this PR, we implement the mechanism as below: - Add function named `cleanupResources` in SparkPlan, which default call children's `cleanupResources` function, the operator which need a resource cleanup should rewrite this with the self cleanup and also call `super.cleanupResources`, like SortExec in this PR. - Add logic support on the trigger side, in this PR is SortMergeJoinExec, which make sure and call the `cleanupResources` to do the cleanup job for all its upstream(children) operator. ### Why are the changes needed? Bugfix for SortMergeJoin memory leak, and implement a general framework for SparkPlan resource cleanup. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? UT: Add new test suite JoinWithResourceCleanSuite to check both standard and code generation scenario. Integrate Test: Test with driver/executor default memory set 1g, local mode 10 thread. The below test(thanks taosaildrone for providing this test [here](https://github.com/apache/spark/pull/23762#issuecomment-463303175)) will pass with this PR. ``` from pyspark.sql.functions import rand, col spark.conf.set("spark.sql.join.preferSortMergeJoin", "true") spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1) # spark.conf.set("spark.sql.sortMergeJoinExec.eagerCleanupResources", "true") r1 = spark.range(1, 1001).select(col("id").alias("timestamp1")) r1 = r1.withColumn('value', rand()) r2 = spark.range(1000, 1001).select(col("id").alias("timestamp2")) r2 = r2.withColumn('value2', rand()) joined = r1.join(r2, r1.timestamp1 == r2.timestamp2, "inner") joined = joined.coalesce(1) joined.explain() joined.show() ``` Closes #26164 from xuanyuanking/SPARK-21492. Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- .../execution/UnsafeExternalRowSorter.java | 11 +++++- .../apache/spark/sql/execution/SortExec.scala | 27 +++++++++++-- .../spark/sql/execution/SparkPlan.scala | 9 +++++ .../execution/joins/SortMergeJoinExec.scala | 39 +++++++++++++------ .../org/apache/spark/sql/JoinSuite.scala | 33 +++++++++++++++- 5 files changed, 102 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 863d80b5cb9c..3123f2187da8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -52,6 +52,12 @@ public final class UnsafeExternalRowSorter { private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; + // This flag makes sure the cleanupResource() has been called. 
After the cleanup work, + // iterator.next should always return false. Downstream operator triggers the resource + // cleanup while they found there's no need to keep the iterator any more. + // See more details in SPARK-21492. + private boolean isReleased = false; + public abstract static class PrefixComputer { public static class Prefix { @@ -157,7 +163,8 @@ public long getSortTimeNanos() { return sorter.getSortTimeNanos(); } - private void cleanupResources() { + public void cleanupResources() { + isReleased = true; sorter.cleanupResources(); } @@ -176,7 +183,7 @@ public Iterator sort() throws IOException { @Override public boolean hasNext() { - return sortedIterator.hasNext(); + return !isReleased && sortedIterator.hasNext(); } @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 0a955d6a7523..32d21d05e5f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -62,6 +62,14 @@ case class SortExec( "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) + private[sql] var rowSorter: UnsafeExternalRowSorter = _ + + /** + * This method gets invoked only once for each SortExec instance to initialize an + * UnsafeExternalRowSorter, both `plan.execute` and code generation are using it. + * In the code generation code path, we need to call this function outside the class so we + * should make it public. + */ def createSorter(): UnsafeExternalRowSorter = { val ordering = newOrdering(sortOrder, output) @@ -87,13 +95,13 @@ case class SortExec( } val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = UnsafeExternalRowSorter.create( + rowSorter = UnsafeExternalRowSorter.create( schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) + rowSorter.setTestSpillFrequency(testSpillFrequency) } - sorter + rowSorter } protected override def doExecute(): RDD[InternalRow] = { @@ -181,4 +189,17 @@ case class SortExec( |$sorterVariable.insertRow((UnsafeRow)${row.value}); """.stripMargin } + + /** + * In SortExec, we overwrites cleanupResources to close UnsafeExternalRowSorter. + */ + override protected[sql] def cleanupResources(): Unit = { + if (rowSorter != null) { + // There's possible for rowSorter is null here, for example, in the scenario of empty + // iterator in the current task, the downstream physical node(like SortMergeJoinExec) will + // trigger cleanupResources before rowSorter initialized in createSorter. + rowSorter.cleanupResources() + } + super.cleanupResources() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b4cdf9e16b7e..125f76282e3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -507,6 +507,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } newOrdering(order, Seq.empty) } + + /** + * Cleans up the resources used by the physical operator (if any). 
In general, all the resources + * should be cleaned up when the task finishes but operators like SortMergeJoinExec and LimitExec + * may want eager cleanup to free up tight resources (e.g., memory). + */ + protected[sql] def cleanupResources(): Unit = { + children.foreach(_.cleanupResources()) + } } trait LeafExecNode extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 189727a9bc88..26fb0e5ffb1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -191,7 +191,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), inMemoryThreshold, - spillThreshold + spillThreshold, + cleanupResources ) private[this] val joinRow = new JoinedRow @@ -235,7 +236,8 @@ case class SortMergeJoinExec( streamedIter = RowIterator.fromScala(leftIter), bufferedIter = RowIterator.fromScala(rightIter), inMemoryThreshold, - spillThreshold + spillThreshold, + cleanupResources ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( @@ -249,7 +251,8 @@ case class SortMergeJoinExec( streamedIter = RowIterator.fromScala(rightIter), bufferedIter = RowIterator.fromScala(leftIter), inMemoryThreshold, - spillThreshold + spillThreshold, + cleanupResources ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( @@ -283,7 +286,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), inMemoryThreshold, - spillThreshold + spillThreshold, + cleanupResources ) private[this] val joinRow = new JoinedRow @@ -318,7 +322,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), inMemoryThreshold, - spillThreshold + spillThreshold, + cleanupResources ) private[this] val joinRow = new JoinedRow @@ -360,7 +365,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), inMemoryThreshold, - spillThreshold + spillThreshold, + cleanupResources ) private[this] val joinRow = new JoinedRow @@ -640,6 +646,9 @@ case class SortMergeJoinExec( (evaluateVariables(leftVars), "") } + val thisPlan = ctx.addReferenceObj("plan", this) + val eagerCleanup = s"$thisPlan.cleanupResources();" + s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { | ${leftVarDecl.mkString("\n")} @@ -653,6 +662,7 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} + |$eagerCleanup """.stripMargin } } @@ -678,6 +688,7 @@ case class SortMergeJoinExec( * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by * internal buffer * @param spillThreshold Threshold for number of rows to be spilled by internal buffer + * @param eagerCleanupResources the eager cleanup function to be invoked when no join row found */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, @@ -686,7 +697,8 @@ private[joins] class SortMergeJoinScanner( streamedIter: RowIterator, bufferedIter: RowIterator, inMemoryThreshold: Int, - spillThreshold: Int) { + spillThreshold: Int, + eagerCleanupResources: () => Unit) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -710,7 +722,8 @@ private[joins] class 
SortMergeJoinScanner( def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches /** - * Advances both input iterators, stopping when we have found rows with matching join keys. + * Advances both input iterators, stopping when we have found rows with matching join keys. If no + * join rows found, try to do the eager resources cleanup. * @return true if matching rows have been found and false otherwise. If this returns true, then * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join * results. @@ -720,7 +733,7 @@ private[joins] class SortMergeJoinScanner( // Advance the streamed side of the join until we find the next row whose join key contains // no nulls or we hit the end of the streamed iterator. } - if (streamedRow == null) { + val found = if (streamedRow == null) { // We have consumed the entire streamed iterator, so there can be no more matches. matchJoinKey = null bufferedMatches.clear() @@ -760,17 +773,19 @@ private[joins] class SortMergeJoinScanner( true } } + if (!found) eagerCleanupResources() + found } /** * Advances the streamed input iterator and buffers all rows from the buffered input that - * have matching keys. + * have matching keys. If no join rows found, try to do the eager resources cleanup. * @return true if the streamed iterator returned a row, false otherwise. If this returns true, * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer * join results. */ final def findNextOuterJoinRows(): Boolean = { - if (!advancedStreamed()) { + val found = if (!advancedStreamed()) { // We have consumed the entire streamed iterator, so there can be no more matches. matchJoinKey = null bufferedMatches.clear() @@ -800,6 +815,8 @@ private[joins] class SortMergeJoinScanner( // If there is a streamed input then we always return true true } + if (!found) eagerCleanupResources() + found } // --- Private methods -------------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 72742644ff34..62f2d21e5270 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,12 +22,14 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer +import org.mockito.Mockito._ + import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.Filter -import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec} +import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.BatchEvalPythonExec import org.apache.spark.sql.internal.SQLConf @@ -37,6 +39,23 @@ import org.apache.spark.sql.types.StructType class JoinSuite extends QueryTest with SharedSparkSession { import testImplicits._ + private def attachCleanupResourceChecker(plan: SparkPlan): Unit = { + // SPARK-21492: Check cleanupResources are finally triggered in SortExec node for every + // test case + plan.foreachUp { + case s: SortExec => + val sortExec = spy(s) + verify(sortExec, atLeastOnce).cleanupResources() + 
verify(sortExec.rowSorter, atLeastOnce).cleanupResources() + case _ => + } + } + + override protected def checkAnswer(df: => DataFrame, rows: Seq[Row]): Unit = { + attachCleanupResourceChecker(df.queryExecution.sparkPlan) + super.checkAnswer(df, rows) + } + setupTestData() def statisticSizeInByte(df: DataFrame): BigInt = { @@ -1039,4 +1058,16 @@ class JoinSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(1, 2, 1, 2) :: Nil) } + + test("SPARK-21492: cleanupResource without code generation") { + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 10, 1, 2) + val df2 = spark.range(10).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } From b4844eea1fc0cfb82cfe7e13f22655b9729c3ad4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Oct 2019 19:17:28 +0800 Subject: [PATCH 08/58] [SPARK-29517][SQL] TRUNCATE TABLE should look up catalog/table like v2 commands ### What changes were proposed in this pull request? Add TruncateTableStatement and make TRUNCATE TABLE go through the same catalog/table resolution framework of v2 commands. ### Why are the changes needed? It's important to make all the commands have the same table resolution behavior, to avoid confusing end-users. e.g. ``` USE my_catalog DESC t // success and describe the table t from my_catalog TRUNCATE TABLE t // report table not found as there is no table t in the session catalog ``` ### Does this PR introduce any user-facing change? yes. When running TRUNCATE TABLE, Spark fails the command if the current catalog is set to a v2 catalog, or the table name specified a v2 catalog. ### How was this patch tested? Unit tests. Closes #26174 from viirya/SPARK-29517. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 14 ++++++++++++ .../catalyst/plans/logical/statements.scala | 7 ++++++ .../sql/catalyst/parser/DDLParserSuite.scala | 10 +++++++++ .../analysis/ResolveSessionCatalog.scala | 8 ++++++- .../spark/sql/execution/SparkSqlParser.scala | 14 ------------ .../sql/connector/DataSourceV2SQLSuite.scala | 22 +++++++++++++++++++ 7 files changed, 61 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1839203e3b23..4c93f1fe1197 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -211,7 +211,7 @@ statement | CLEAR CACHE #clearCache | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE tableIdentifier partitionSpec? #loadData - | TRUNCATE TABLE tableIdentifier partitionSpec? #truncateTable + | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable | MSCK REPAIR TABLE multipartIdentifier #repairTable | op=(ADD | LIST) identifier .*? #manageResource | SET ROLE .*? 
#failNativeCommand diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8af7cf9ad800..862903246ed3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2728,4 +2728,18 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) { RepairTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier())) } + + /** + * Create a [[TruncateTableStatement]] command. + * + * For example: + * {{{ + * TRUNCATE TABLE multi_part_name [PARTITION (partcol1=val1, partcol2=val2 ...)] + * }}} + */ + override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) { + TruncateTableStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 72d5cbb7d904..1a69a6ab3380 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -316,3 +316,10 @@ case class AnalyzeColumnStatement( * A REPAIR TABLE statement, as parsed from SQL */ case class RepairTableStatement(tableName: Seq[String]) extends ParsedStatement + +/** + * A TRUNCATE TABLE statement, as parsed from SQL + */ +case class TruncateTableStatement( + tableName: Seq[String], + partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 0eaf74f65506..0d87d0ce9b0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -961,6 +961,16 @@ class DDLParserSuite extends AnalysisTest { RepairTableStatement(Seq("a", "b", "c"))) } + test("TRUNCATE table") { + comparePlans( + parsePlan("TRUNCATE TABLE a.b.c"), + TruncateTableStatement(Seq("a", "b", "c"), None)) + + comparePlans( + parsePlan("TRUNCATE TABLE a.b.c PARTITION(ds='2017-06-10')"), + TruncateTableStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10")))) + } + private case class TableSpec( name: Seq[String], schema: Option[StructType], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 72f539f72008..978214778a4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import 
org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowTablesCommand} +import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowTablesCommand, TruncateTableCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf @@ -282,6 +282,12 @@ class ResolveSessionCatalog( AlterTableRecoverPartitionsCommand( v1TableName.asTableIdentifier, "MSCK REPAIR TABLE") + + case TruncateTableStatement(tableName, partitionSpec) => + val v1TableName = parseV1Table(tableName, "TRUNCATE TABLE") + TruncateTableCommand( + v1TableName.asTableIdentifier, + partitionSpec) } private def parseV1Table(tableName: Seq[String], sql: String): Seq[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index cdee11781324..a51d29431dec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -346,20 +346,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ) } - /** - * Create a [[TruncateTableCommand]] command. - * - * For example: - * {{{ - * TRUNCATE TABLE tablename [PARTITION (partcol1=val1, partcol2=val2 ...)] - * }}} - */ - override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) { - TruncateTableCommand( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) - } - /** * Create a [[CreateDatabaseCommand]] command. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index d253e6078ddc..01c051f15635 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1210,6 +1210,28 @@ class DataSourceV2SQLSuite } } + test("TRUNCATE TABLE") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + sql( + s""" + |CREATE TABLE $t (id bigint, data string) + |USING foo + |PARTITIONED BY (id) + """.stripMargin) + + val e1 = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $t") + } + assert(e1.message.contains("TRUNCATE TABLE is only supported with v1 tables")) + + val e2 = intercept[AnalysisException] { + sql(s"TRUNCATE TABLE $t PARTITION(id='1')") + } + assert(e2.message.contains("TRUNCATE TABLE is only supported with v1 tables")) + } + } + private def assertAnalysisError(sqlStatement: String, expectedError: String): Unit = { val errMsg = intercept[AnalysisException] { sql(sqlStatement) From 877993847c0baa016003639e16708373e57ca64b Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 22 Oct 2019 08:55:37 -0500 Subject: [PATCH 09/58] [SPARK-28787][DOC][SQL] Document LOAD DATA statement in SQL Reference ### What changes were proposed in this pull request? Document LOAD DATA statement in SQL Reference ### Why are the changes needed? To complete the SQL Reference ### Does this PR introduce any user-facing change? Yes ### How was this patch tested? Tested using jykyll build --serve Here are the screen shots: ![image](https://user-images.githubusercontent.com/13592258/64073167-e7cd0800-cc4e-11e9-9fcc-92fe4cb5a942.png) ![image](https://user-images.githubusercontent.com/13592258/64073169-ee5b7f80-cc4e-11e9-9a36-cc023bcd32b1.png) ![image](https://user-images.githubusercontent.com/13592258/64073170-f4516080-cc4e-11e9-9101-2609a01fe6fe.png) Closes #25522 from huaxingao/spark-28787. Authored-by: Huaxin Gao Signed-off-by: Sean Owen --- docs/sql-ref-syntax-dml-load.md | 103 +++++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 3 deletions(-) diff --git a/docs/sql-ref-syntax-dml-load.md b/docs/sql-ref-syntax-dml-load.md index fd25ba314e0b..c2a6102db4aa 100644 --- a/docs/sql-ref-syntax-dml-load.md +++ b/docs/sql-ref-syntax-dml-load.md @@ -1,7 +1,7 @@ --- layout: global -title: LOAD -displayTitle: LOAD +title: LOAD DATA +displayTitle: LOAD DATA license: | Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with @@ -19,4 +19,101 @@ license: | limitations under the License. --- -**This page is under construction** +### Description +`LOAD DATA` statement loads the data into a table from the user specified directory or file. If a directory is specified then all the files from the directory are loaded. If a file is specified then only the single file is loaded. Additionally the `LOAD DATA` statement takes an optional partition specification. When a partition is specified, the data files (when input source is a directory) or the single file (when input source is a file) are loaded into the partition of the target table. + +### Syntax +{% highlight sql %} +LOAD DATA [ LOCAL ] INPATH path [ OVERWRITE ] INTO TABLE table_name + [ PARTITION ( partition_col_name = partition_col_val [ , ... ] ) ] +{% endhighlight %} + +### Parameters +
+<dl>
+  <dt><code><em>path</em></code></dt>
+  <dd>Path of the file system. It can be either an absolute or a relative path.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>table_name</em></code></dt>
+  <dd>The name of an existing table.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>PARTITION ( partition_col_name = partition_col_val [ , ... ] )</em></code></dt>
+  <dd>Specifies one or more partition column and value pairs.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>LOCAL</em></code></dt>
+  <dd>If specified, it causes the <code>INPATH</code> to be resolved against the local file system, instead of the default file system, which is typically a distributed storage.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>OVERWRITE</em></code></dt>
+  <dd>By default, new data is appended to the table. If <code>OVERWRITE</code> is used, the table is instead overwritten with new data.</dd>
+</dl>
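
A minimal sketch of running the same statement through the Scala API. The session, the table definition (mirroring the `test_load` example below), and the `/tmp/students` input path are assumptions for illustration, not part of the patch:

{% highlight scala %}
import org.apache.spark.sql.SparkSession

// Assumed: a session with Hive support and a local directory /tmp/students
// containing data files laid out for the target table.
val spark = SparkSession.builder()
  .appName("load-data-example")
  .enableHiveSupport()
  .getOrCreate()

spark.sql("CREATE TABLE IF NOT EXISTS test_load (name VARCHAR(64), address VARCHAR(64), student_id INT)")

// LOCAL resolves the path against the driver's local file system;
// OVERWRITE replaces the current contents of the table instead of appending to it.
spark.sql("LOAD DATA LOCAL INPATH '/tmp/students' OVERWRITE INTO TABLE test_load")

spark.sql("SELECT * FROM test_load").show()
{% endhighlight %}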
+ +### Examples +{% highlight sql %} + -- Example without partition specification. + -- Assuming the students table has already been created and populated. + SELECT * FROM students; + + + -------------- + ------------------------------ + -------------- + + | name | address | student_id | + + -------------- + ------------------------------ + -------------- + + | Amy Smith | 123 Park Ave, San Jose | 111111 | + + -------------- + ------------------------------ + -------------- + + + CREATE TABLE test_load (name VARCHAR(64), address VARCHAR(64), student_id INT); + + -- Assuming the students table is in '/user/hive/warehouse/' + LOAD DATA LOCAL INPATH '/user/hive/warehouse/students' OVERWRITE INTO TABLE test_load; + + SELECT * FROM test_load; + + + -------------- + ------------------------------ + -------------- + + | name | address | student_id | + + -------------- + ------------------------------ + -------------- + + | Amy Smith | 123 Park Ave, San Jose | 111111 | + + -------------- + ------------------------------ + -------------- + + + -- Example with partition specification. + CREATE TABLE test_partition (c1 INT, c2 INT, c3 INT) USING HIVE PARTITIONED BY (c2, c3); + + INSERT INTO test_partition PARTITION (c2 = 2, c3 = 3) VALUES (1); + + INSERT INTO test_partition PARTITION (c2 = 5, c3 = 6) VALUES (4); + + INSERT INTO test_partition PARTITION (c2 = 8, c3 = 9) VALUES (7); + + SELECT * FROM test_partition; + + + ------- + ------- + ----- + + | c1 | c2 | c3 | + + ------- + --------------- + + | 1 | 2 | 3 | + + ------- + ------- + ----- + + | 4 | 5 | 6 | + + ------- + ------- + ----- + + | 7 | 8 | 9 | + + ------- + ------- + ----- + + + CREATE TABLE test_load_partition (c1 INT, c2 INT, c3 INT) USING HIVE PARTITIONED BY (c2, c3); + + -- Assuming the test_partition table is in '/user/hive/warehouse/' + LOAD DATA LOCAL INPATH '/user/hive/warehouse/test_partition/c2=2/c3=3' + OVERWRITE INTO TABLE test_load_partition PARTITION (c2=2, c3=3); + + SELECT * FROM test_load_partition; + + + ------- + ------- + ----- + + | c1 | c2 | c3 | + + ------- + --------------- + + | 1 | 2 | 3 | + + ------- + ------- + ----- + + + +{% endhighlight %} + From c1c64851ed4b8dac3ca4becaea9e6721eb25c589 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 22 Oct 2019 08:56:44 -0500 Subject: [PATCH 10/58] [SPARK-28793][DOC][SQL] Document CREATE FUNCTION in SQL Reference ### What changes were proposed in this pull request? Document CREATE FUNCTION statement in SQL Reference Guide. ### Why are the changes needed? Currently Spark lacks documentation on the supported SQL constructs causing confusion among users who sometimes have to look at the code to understand the usage. This is aimed at addressing this issue. ### Does this PR introduce any user-facing change? Yes. **Before:** There was no documentation for this. **After.** Screen Shot 2019-09-22 at 3 01 52 PM Screen Shot 2019-09-22 at 3 02 11 PM Screen Shot 2019-09-22 at 3 02 39 PM Screen Shot 2019-09-22 at 3 04 04 PM ### How was this patch tested? Tested using jykyll build --serve Closes #25894 from dilipbiswal/sql-ref-create-function. 
Authored-by: Dilip Biswal Signed-off-by: Sean Owen --- docs/sql-getting-started.md | 3 + docs/sql-ref-syntax-ddl-create-function.md | 151 ++++++++++++++++++++- 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/docs/sql-getting-started.md b/docs/sql-getting-started.md index 5d18c48879f9..0ded2654719c 100644 --- a/docs/sql-getting-started.md +++ b/docs/sql-getting-started.md @@ -346,6 +346,9 @@ For example: +## Scalar Functions +(to be filled soon) + ## Aggregations The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common diff --git a/docs/sql-ref-syntax-ddl-create-function.md b/docs/sql-ref-syntax-ddl-create-function.md index f95a9eba42c2..4c09ebafb1f5 100644 --- a/docs/sql-ref-syntax-ddl-create-function.md +++ b/docs/sql-ref-syntax-ddl-create-function.md @@ -19,4 +19,153 @@ license: | limitations under the License. --- -**This page is under construction** +### Description +The `CREATE FUNCTION` statement is used to create a temporary or permanent function +in Spark. Temporary functions are scoped at a session level where as permanent +functions are created in the persistent catalog and are made available to +all sessions. The resources specified in the `USING` clause are made available +to all executors when they are executed for the first time. In addition to the +SQL interface, spark allows users to create custom user defined scalar and +aggregate functions using Scala, Python and Java APIs. Please refer to +[scalar_functions](sql-getting-started.html#scalar-functions) and +[aggregate functions](sql-getting-started#aggregations) for more information. + +### Syntax +{% highlight sql %} +CREATE [ OR REPLACE ] [ TEMPORARY ] FUNCTION [ IF NOT EXISTS ] + function_name AS class_name [ resource_locations ] +{% endhighlight %} + +### Parameters +
+<dl>
+  <dt><code><em>OR REPLACE</em></code></dt>
+  <dd>
+    If specified, the resources for the function are reloaded. This is mainly useful
+    to pick up any changes made to the implementation of the function. This
+    parameter is mutually exclusive to <code>IF NOT EXISTS</code> and the two can not
+    be specified together.
+  </dd>
+  <dt><code><em>TEMPORARY</em></code></dt>
+  <dd>
+    Indicates the scope of the function being created. When TEMPORARY is specified, the
+    created function is valid and visible in the current session. No persistent
+    entry is made in the catalog for this kind of function.
+  </dd>
+  <dt><code><em>IF NOT EXISTS</em></code></dt>
+  <dd>
+    If specified, creates the function only when it does not exist. The creation
+    of the function succeeds (no error is thrown) if the specified function already
+    exists in the system. This parameter is mutually exclusive to <code>OR REPLACE</code>
+    and the two can not be specified together.
+  </dd>
+  <dt><code><em>function_name</em></code></dt>
+  <dd>
+    Specifies the name of the function to be created. The function name may be
+    optionally qualified with a database name. <br><br>
+    <b>Syntax:</b>
+      <code>
+        [database_name.]function_name
+      </code>
+  </dd>
+  <dt><code><em>class_name</em></code></dt>
+  <dd>
+    Specifies the name of the class that provides the implementation for the function to be created.
+    The implementing class should extend one of the base classes as follows
+    (a minimal Scala sketch of the last option follows this parameter list):
+    <ul>
+      <li>Should extend <code>UDF</code> or <code>UDAF</code> in <code>org.apache.hadoop.hive.ql.exec</code> package.</li>
+      <li>Should extend <code>AbstractGenericUDAFResolver</code>, <code>GenericUDF</code>, or
+          <code>GenericUDTF</code> in <code>org.apache.hadoop.hive.ql.udf.generic</code> package.</li>
+      <li>Should extend <code>UserDefinedAggregateFunction</code> in <code>org.apache.spark.sql.expressions</code> package.</li>
+    </ul>
+  </dd>
+  <dt><code><em>resource_locations</em></code></dt>
+  <dd>
+    Specifies the list of resources that contain the implementation of the function
+    along with its dependencies. <br><br>
+    <b>Syntax:</b>
+      <code>
+        USING { { (JAR | FILE) resource_uri } , ... }
+      </code>
+  </dd>
+</dl>
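
A minimal Scala sketch of the third `class_name` option above, an aggregate extending `UserDefinedAggregateFunction`. The class name `LongSum`, the function name `long_sum`, and the jar path are illustrative only, not part of the patch:

{% highlight scala %}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Sums LONG values. Compile and package into e.g. /tmp/LongSum.jar, then register with:
//   CREATE FUNCTION long_sum AS 'LongSum' USING JAR '/tmp/LongSum.jar';
class LongSum extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("value", LongType)
  override def bufferSchema: StructType = new StructType().add("total", LongType)
  override def dataType: DataType = LongType
  override def deterministic: Boolean = true

  // Start every aggregation buffer at zero.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  // Fold one input row into the buffer, ignoring NULL inputs.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getLong(0)
  }

  // Combine two partial buffers produced on different partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }

  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}
{% endhighlight %}

Once the jar is available to the cluster, the `CREATE FUNCTION ... USING JAR` form shown in the Examples section below registers it for use from SQL.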
+ +### Examples +{% highlight sql %} +-- 1. Create a simple UDF `SimpleUdf` that increments the supplied integral value by 10. +-- import org.apache.hadoop.hive.ql.exec.UDF; +-- public class SimpleUdf extends UDF { +-- public int evaluate(int value) { +-- return value + 10; +-- } +-- } +-- 2. Compile and place it in a JAR file called `SimpleUdf.jar` in /tmp. + +-- Create a table called `test` and insert two rows. +CREATE TABLE test(c1 INT); +INSERT INTO test VALUES (1), (2); + +-- Create a permanent function called `simple_udf`. +CREATE FUNCTION simple_udf AS 'SimpleUdf' + USING JAR '/tmp/SimpleUdf.jar'; + +-- Verify that the function is in the registry. +SHOW USER FUNCTIONS; + +------------------+ + | function| + +------------------+ + |default.simple_udf| + +------------------+ + +-- Invoke the function. Every selected value should be incremented by 10. +SELECT simple_udf(c1) AS function_return_value FROM t1; + +---------------------+ + |function_return_value| + +---------------------+ + | 11| + | 12| + +---------------------+ + +-- Created a temporary function. +CREATE TEMPORARY FUNCTION simple_temp_udf AS 'SimpleUdf' + USING JAR '/tmp/SimpleUdf.jar'; + +-- Verify that the newly created temporary function is in the registry. +-- Please note that the temporary function does not have a qualified +-- database associated with it. +SHOW USER FUNCTIONS; + +------------------+ + | function| + +------------------+ + |default.simple_udf| + | simple_temp_udf| + +------------------+ + +-- 1. Modify `SimpleUdf`'s implementation to add supplied integral value by 20. +-- import org.apache.hadoop.hive.ql.exec.UDF; + +-- public class SimpleUdfR extends UDF { +-- public int evaluate(int value) { +-- return value + 20; +-- } +-- } +-- 2. Compile and place it in a jar file called `SimpleUdfR.jar` in /tmp. + +-- Replace the implementation of `simple_udf` +CREATE OR REPLACE FUNCTION simple_udf AS 'SimpleUdfR' + USING JAR '/tmp/SimpleUdfR.jar'; + +-- Invoke the function. Every selected value should be incremented by 20. +SELECT simple_udf(c1) AS function_return_value FROM t1; ++---------------------+ +|function_return_value| ++---------------------+ +| 21| +| 22| ++---------------------+ + +{% endhighlight %} + +### Related statements +- [SHOW FUNCTIONS](sql-ref-syntax-aux-show-functions.html) +- [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html) +- [DROP FUNCTION](sql-ref-syntax-ddl-drop-function.html) From 2036a8cca7a428672310ae11e71d0f1f51074cac Mon Sep 17 00:00:00 2001 From: chenjuanni Date: Tue, 22 Oct 2019 08:58:12 -0500 Subject: [PATCH 11/58] [SPARK-29488][WEBUI] In Web UI, stage page has js error when sort table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? In Web UI, stage page has js error when sort table. https://issues.apache.org/jira/browse/SPARK-29488 ### Why are the changes needed? In Web UI, follow the steps below, get js error "Uncaught TypeError: Failed to execute 'removeChild' on 'Node': parameter 1 is not of type 'Node'.". 1) Click "Summary Metrics..." 's tablehead "Min" 2) Click "Aggregated Metrics by Executor" 's tablehead "Task Time" 3) Click "Summary Metrics..." 's tablehead "Min"(the same as step 1.) ### Does this PR introduce any user-facing change? No. ### How was this patch tested? In Web UI, follow the steps below, no error occur. 1) Click "Summary Metrics..." 
's tablehead "Min" 2) Click "Aggregated Metrics by Executor" 's tablehead "Task Time" 3) Click "Summary Metrics..." 's tablehead "Min"(the same as step 1.) ![image](https://user-images.githubusercontent.com/7802338/66899878-464b1b80-f02e-11e9-9660-6cdaab283491.png) Closes #26136 from cjn082030/SPARK-1. Authored-by: chenjuanni Signed-off-by: Sean Owen --- .../org/apache/spark/ui/static/sorttable.js | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index 9960d5c34d1f..ecd580e5c64a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -97,9 +97,14 @@ sorttable = { sorttable.reverse(this.sorttable_tbody); this.className = this.className.replace('sorttable_sorted', 'sorttable_sorted_reverse'); - this.removeChild(document.getElementById('sorttable_sortfwdind')); + rowlists = this.parentNode.getElementsByTagName("span"); + for (var j=0; j < rowlists.length; j++) { + if (rowlists[j].className.search(/\bsorttable_sortfwdind\b/)) { + rowlists[j].parentNode.removeChild(rowlists[j]); + } + } sortrevind = document.createElement('span'); - sortrevind.id = "sorttable_sortrevind"; + sortrevind.class = "sorttable_sortrevind"; sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾'; this.appendChild(sortrevind); return; @@ -110,9 +115,14 @@ sorttable = { sorttable.reverse(this.sorttable_tbody); this.className = this.className.replace('sorttable_sorted_reverse', 'sorttable_sorted'); - this.removeChild(document.getElementById('sorttable_sortrevind')); + rowlists = this.parentNode.getElementsByTagName("span"); + for (var j=0; j < rowlists.length; j++) { + if (rowlists[j].className.search(/\sorttable_sortrevind\b/)) { + rowlists[j].parentNode.removeChild(rowlists[j]); + } + } sortfwdind = document.createElement('span'); - sortfwdind.id = "sorttable_sortfwdind"; + sortfwdind.class = "sorttable_sortfwdind"; sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); return; @@ -126,14 +136,17 @@ sorttable = { cell.className = cell.className.replace('sorttable_sorted',''); } }); - sortfwdind = document.getElementById('sorttable_sortfwdind'); - if (sortfwdind) { sortfwdind.parentNode.removeChild(sortfwdind); } - sortrevind = document.getElementById('sorttable_sortrevind'); - if (sortrevind) { sortrevind.parentNode.removeChild(sortrevind); } + rowlists = this.parentNode.getElementsByTagName("span"); + for (var j=0; j < rowlists.length; j++) { + if (rowlists[j].className.search(/\bsorttable_sortfwdind\b/) + || rowlists[j].className.search(/\sorttable_sortrevind\b/) ) { + rowlists[j].parentNode.removeChild(rowlists[j]); + } + } this.className += ' sorttable_sorted'; sortfwdind = document.createElement('span'); - sortfwdind.id = "sorttable_sortfwdind"; + sortfwdind.class = "sorttable_sortfwdind"; sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); From 80094688fdbf5cc0d10c295fadb92965b460de5d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 22 Oct 2019 14:14:59 -0700 Subject: [PATCH 12/58] [SPARK-29556][CORE] Avoid putting request path in error response in ErrorServlet ### What changes were proposed in this pull request? Don't include `$path` from user query in the error response. ### Why are the changes needed? The path could contain input that is then rendered as HTML in the error response. 
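
For illustration only (the patch itself simply drops the path), a Scala sketch of the two usual options when request-derived text would otherwise land in an HTML body: omit it, or escape it first. The `escapeHtml` helper and the sample path are hypothetical, not Spark code.

```
// Hypothetical helper, not Spark code: escape untrusted text before it is
// embedded in an HTML error body.
def escapeHtml(s: String): String = s.flatMap {
  case '<' => "&lt;"
  case '>' => "&gt;"
  case '&' => "&amp;"
  case '"' => "&quot;"
  case c   => c.toString
}

val path = "/v1/submissions/<script>alert(1)</script>" // attacker-controlled input
val escaped = s"Malformed path ${escapeHtml(path)}."   // rendered safely if it must be shown
val omitted = "Malformed path."                        // what this patch opts for
```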
It's not clear whether it's exploitable, but better safe than sorry as the path info really isn't that important in this context. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing tests. Closes #26211 from srowen/SPARK-29556. Authored-by: Sean Owen Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/rest/RestSubmissionServer.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e59bf3f0eaf4..f60d940b8c82 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -317,8 +317,7 @@ private class ErrorServlet extends RestServlet { versionMismatch = true s"Unknown protocol version '$unknownVersion'." case _ => - // never reached - s"Malformed path $path." + "Malformed path." } msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." val error = handleError(msg) From 3bf5355e24094153db5cac4d34bf5ccead31772a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 22 Oct 2019 14:47:17 -0700 Subject: [PATCH 13/58] [SPARK-29539][SQL] SHOW PARTITIONS should look up catalog/table like v2 commands ### What changes were proposed in this pull request? Add ShowPartitionsStatement and make SHOW PARTITIONS go through the same catalog/table resolution framework of v2 commands. ### Why are the changes needed? It's important to make all the commands have the same table resolution behavior, to avoid confusing end-users. ### Does this PR introduce any user-facing change? Yes. When running SHOW PARTITIONS, Spark fails the command if the current catalog is set to a v2 catalog, or the table name specified a v2 catalog. ### How was this patch tested? Unit tests. Closes #26198 from huaxingao/spark-29539. Authored-by: Huaxin Gao Signed-off-by: Liang-Chi Hsieh --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 17 ++++++++++++ .../catalyst/plans/logical/statements.scala | 6 +++++ .../sql/catalyst/parser/DDLParserSuite.scala | 26 +++++++++++++++++++ .../analysis/ResolveSessionCatalog.scala | 8 +++++- .../spark/sql/execution/SparkSqlParser.scala | 17 ------------ .../sql/connector/DataSourceV2SQLSuite.scala | 22 ++++++++++++++++ .../execution/command/DDLParserSuite.scala | 21 --------------- 8 files changed, 79 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 4c93f1fe1197..963077c35df9 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -194,7 +194,7 @@ statement ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM | IN) db=errorCapturingIdentifier)? #showColumns - | SHOW PARTITIONS tableIdentifier partitionSpec? #showPartitions + | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions | SHOW identifier? FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? 
#showFunctions | SHOW CREATE TABLE tableIdentifier #showCreateTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 862903246ed3..548042bc9767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2742,4 +2742,21 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging visitMultipartIdentifier(ctx.multipartIdentifier), Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } + + /** + * A command for users to list the partition names of a table. If partition spec is specified, + * partitions that match the spec are returned. Otherwise an empty result set is returned. + * + * This function creates a [[ShowPartitionsStatement]] logical plan + * + * The syntax of using this command in SQL is: + * {{{ + * SHOW PARTITIONS multi_part_name [partition_spec]; + * }}} + */ + override def visitShowPartitions(ctx: ShowPartitionsContext): LogicalPlan = withOrigin(ctx) { + val table = visitMultipartIdentifier(ctx.multipartIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec) + ShowPartitionsStatement(table, partitionKeys) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 1a69a6ab3380..a73a2975aa9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -323,3 +323,9 @@ case class RepairTableStatement(tableName: Seq[String]) extends ParsedStatement case class TruncateTableStatement( tableName: Seq[String], partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement + +/** + * A SHOW PARTITIONS statement, as parsed from SQL + */ +case class ShowPartitionsStatement(tableName: Seq[String], + partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 0d87d0ce9b0f..1dacb2384ac1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -971,6 +971,32 @@ class DDLParserSuite extends AnalysisTest { TruncateTableStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10")))) } + test("SHOW PARTITIONS") { + val sql1 = "SHOW PARTITIONS t1" + val sql2 = "SHOW PARTITIONS db1.t1" + val sql3 = "SHOW PARTITIONS t1 PARTITION(partcol1='partvalue', partcol2='partvalue')" + val sql4 = "SHOW PARTITIONS a.b.c" + val sql5 = "SHOW PARTITIONS a.b.c PARTITION(ds='2017-06-10')" + + val parsed1 = parsePlan(sql1) + val expected1 = ShowPartitionsStatement(Seq("t1"), None) + val parsed2 = parsePlan(sql2) + val expected2 = ShowPartitionsStatement(Seq("db1", "t1"), None) + val parsed3 = parsePlan(sql3) + val expected3 = ShowPartitionsStatement(Seq("t1"), + Some(Map("partcol1" -> "partvalue", "partcol2" -> "partvalue"))) + val parsed4 = parsePlan(sql4) + val expected4 = ShowPartitionsStatement(Seq("a", "b", "c"), None) + val parsed5 = parsePlan(sql5) + val expected5 = 
ShowPartitionsStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10"))) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + comparePlans(parsed5, expected5) + } + private case class TableSpec( name: Seq[String], schema: Option[StructType], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 978214778a4a..4a2e6731d9d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowTablesCommand, TruncateTableCommand} +import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf @@ -288,6 +288,12 @@ class ResolveSessionCatalog( TruncateTableCommand( v1TableName.asTableIdentifier, partitionSpec) + + case ShowPartitionsStatement(tableName, partitionSpec) => + val v1TableName = parseV1Table(tableName, "SHOW PARTITIONS") + ShowPartitionsCommand( + v1TableName.asTableIdentifier, + partitionSpec) } private def parseV1Table(tableName: Seq[String], sql: String): Seq[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index a51d29431dec..3f3f6b373eb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -135,23 +135,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ShowColumnsCommand(Option(ctx.db).map(_.getText), visitTableIdentifier(ctx.tableIdentifier)) } - /** - * A command for users to list the partition names of a table. If partition spec is specified, - * partitions that match the spec are returned. Otherwise an empty result set is returned. 
- * - * This function creates a [[ShowPartitionsCommand]] logical plan - * - * The syntax of using this command in SQL is: - * {{{ - * SHOW PARTITIONS table_identifier [partition_spec]; - * }}} - */ - override def visitShowPartitions(ctx: ShowPartitionsContext): LogicalPlan = withOrigin(ctx) { - val table = visitTableIdentifier(ctx.tableIdentifier) - val partitionKeys = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec) - ShowPartitionsCommand(table, partitionKeys) - } - /** * Creates a [[ShowCreateTableCommand]] */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 01c051f15635..39709ab426a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1232,6 +1232,28 @@ class DataSourceV2SQLSuite } } + test("SHOW PARTITIONS") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + sql( + s""" + |CREATE TABLE $t (id bigint, data string) + |USING foo + |PARTITIONED BY (id) + """.stripMargin) + + val e1 = intercept[AnalysisException] { + val partition = sql(s"SHOW PARTITIONS $t") + } + assert(e1.message.contains("SHOW PARTITIONS is only supported with v1 tables")) + + val e2 = intercept[AnalysisException] { + val partition2 = sql(s"SHOW PARTITIONS $t PARTITION(id='1')") + } + assert(e2.message.contains("SHOW PARTITIONS is only supported with v1 tables")) + } + } + private def assertAnalysisError(sqlStatement: String, expectedError: String): Unit = { val errMsg = intercept[AnalysisException] { sql(sqlStatement) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 5a5899cbccc5..0640d0540baa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -870,27 +870,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { comparePlans(parsed4, expected4) } - - test("show partitions") { - val sql1 = "SHOW PARTITIONS t1" - val sql2 = "SHOW PARTITIONS db1.t1" - val sql3 = "SHOW PARTITIONS t1 PARTITION(partcol1='partvalue', partcol2='partvalue')" - - val parsed1 = parser.parsePlan(sql1) - val expected1 = - ShowPartitionsCommand(TableIdentifier("t1", None), None) - val parsed2 = parser.parsePlan(sql2) - val expected2 = - ShowPartitionsCommand(TableIdentifier("t1", Some("db1")), None) - val expected3 = - ShowPartitionsCommand(TableIdentifier("t1", None), - Some(Map("partcol1" -> "partvalue", "partcol2" -> "partvalue"))) - val parsed3 = parser.parsePlan(sql3) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - } - test("support for other types in DBPROPERTIES") { val sql = """ From f23c5d7f6705348ddeac0b714b29374cba3a4efe Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 22 Oct 2019 16:30:29 -0700 Subject: [PATCH 14/58] [SPARK-29560][BUILD] Add typesafe bintray repo for sbt-mima-plugin ### What changes were proposed in this pull request? This add `typesafe` bintray repo for `sbt-mima-plugin`. ### Why are the changes needed? 
Since Oct 21, the following plugin causes [Jenkins failures](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.4-test-sbt-hadoop-2.6/611/console ) due to the missing jar. - `branch-2.4`: `sbt-mima-plugin:0.1.17` is missing. - `master`: `sbt-mima-plugin:0.3.0` is missing. These versions of `sbt-mima-plugin` seems to be removed from the old repo. ``` $ rm -rf ~/.ivy2/ $ build/sbt scalastyle test:scalastyle ... [warn] :::::::::::::::::::::::::::::::::::::::::::::: [warn] :: UNRESOLVED DEPENDENCIES :: [warn] :::::::::::::::::::::::::::::::::::::::::::::: [warn] :: com.typesafe#sbt-mima-plugin;0.1.17: not found [warn] :::::::::::::::::::::::::::::::::::::::::::::: ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Check `GitHub Action` linter result. This PR should pass. Or, manual check. (Note that Jenkins PR builder didn't fail until now due to the local cache.) Closes #26217 from dongjoon-hyun/SPARK-29560. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- project/plugins.sbt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/plugins.sbt b/project/plugins.sbt index d1fe59a47217..02525c27b6aa 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -32,6 +32,9 @@ addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") +// SPARK-29560 Only sbt-mima-plugin needs this repo +resolvers += Resolver.url("bintray", + new java.net.URL("https://dl.bintray.com/typesafe/sbt-plugins"))(Resolver.defaultIvyPatterns) addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.3.0") // sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6 From e6749092f7a2cc1943899fde8d830ec2b8fa2186 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Wed, 23 Oct 2019 10:24:38 +0900 Subject: [PATCH 15/58] [SPARK-29107][SQL][TESTS] Port window.sql (Part 1) ### What changes were proposed in this pull request? This PR ports window.sql from PostgreSQL regression tests https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/sql/window.sql from lines 1~319 The expected results can be found in the link: https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/expected/window.out ### Why are the changes needed? To ensure compatibility with PostgreSQL. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Pass the Jenkins. And, Comparison with PgSQL results. Closes #26119 from DylanGuedes/spark-29107. 
Authored-by: DylanGuedes Signed-off-by: HyukjinKwon --- .../inputs/postgreSQL/window_part1.sql | 352 +++++++++ .../results/postgreSQL/window_part1.sql.out | 725 ++++++++++++++++++ 2 files changed, 1077 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part1.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql new file mode 100644 index 000000000000..ae2a015ada24 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql @@ -0,0 +1,352 @@ +-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group +-- +-- Window Functions Testing +-- https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/sql/window.sql#L1-L319 + +CREATE TEMPORARY VIEW tenk2 AS SELECT * FROM tenk1; + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- CREATE TABLE empsalary ( +-- depname string, +-- empno integer, +-- salary int, +-- enroll_date date +-- ) USING parquet; + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- INSERT INTO empsalary VALUES ('develop', 10, 5200, '2007-08-01'); +-- INSERT INTO empsalary VALUES ('sales', 1, 5000, '2006-10-01'); +-- INSERT INTO empsalary VALUES ('personnel', 5, 3500, '2007-12-10'); +-- INSERT INTO empsalary VALUES ('sales', 4, 4800, '2007-08-08'); +-- INSERT INTO empsalary VALUES ('personnel', 2, 3900, '2006-12-23'); +-- INSERT INTO empsalary VALUES ('develop', 7, 4200, '2008-01-01'); +-- INSERT INTO empsalary VALUES ('develop', 9, 4500, '2008-01-01'); +-- INSERT INTO empsalary VALUES ('sales', 3, 4800, '2007-08-01'); +-- INSERT INTO empsalary VALUES ('develop', 8, 6000, '2006-10-01'); +-- INSERT INTO empsalary VALUES ('develop', 11, 5200, '2007-08-15'); + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- SELECT depname, empno, salary, sum(salary) OVER (PARTITION BY depname) FROM empsalary ORDER BY depname, salary; + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- SELECT depname, empno, salary, rank() OVER (PARTITION BY depname ORDER BY salary) FROM empsalary; + +-- with GROUP BY +SELECT four, ten, SUM(SUM(four)) OVER (PARTITION BY four), AVG(ten) FROM tenk1 +GROUP BY four, ten ORDER BY four, ten; + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- SELECT depname, empno, salary, sum(salary) OVER w FROM empsalary WINDOW w AS (PARTITION BY depname); + +-- [SPARK-28064] Order by does not accept a call to rank() +-- SELECT depname, empno, salary, rank() OVER w FROM empsalary WINDOW w AS (PARTITION BY depname ORDER BY salary) ORDER BY rank() OVER w; + +-- empty window specification +SELECT COUNT(*) OVER () FROM tenk1 WHERE unique2 < 10; + +SELECT COUNT(*) OVER w FROM tenk1 WHERE unique2 < 10 WINDOW w AS (); + +-- no window operation +SELECT four FROM tenk1 WHERE FALSE WINDOW w AS (PARTITION BY ten); + +-- cumulative aggregate +SELECT sum(four) OVER (PARTITION BY ten ORDER BY unique2) AS sum_1, ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT row_number() OVER (ORDER BY unique2) FROM tenk1 WHERE unique2 < 10; + +SELECT rank() OVER (PARTITION BY four ORDER BY ten) AS rank_1, ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT dense_rank() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT percent_rank() OVER (PARTITION BY 
four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT cume_dist() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT ntile(3) OVER (ORDER BY ten, four), ten, four FROM tenk1 WHERE unique2 < 10; + +-- [SPARK-28065] ntile does not accept NULL as input +-- SELECT ntile(NULL) OVER (ORDER BY ten, four), ten, four FROM tenk1 LIMIT 2; + +SELECT lag(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +-- [SPARK-28068] `lag` second argument must be a literal in Spark +-- SELECT lag(ten, four) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +-- [SPARK-28068] `lag` second argument must be a literal in Spark +-- SELECT lag(ten, four, 0) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT lead(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT lead(ten * 2, 1) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT lead(ten * 2, 1, -1) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT first(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +-- last returns the last row of the frame, which is CURRENT ROW in ORDER BY window. +SELECT last(four) OVER (ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10; + +SELECT last(ten) OVER (PARTITION BY four), ten, four FROM +(SELECT * FROM tenk1 WHERE unique2 < 10 ORDER BY four, ten)s +ORDER BY four, ten; + +-- [SPARK-27951] ANSI SQL: NTH_VALUE function +-- SELECT nth_value(ten, four + 1) OVER (PARTITION BY four), ten, four +-- FROM (SELECT * FROM tenk1 WHERE unique2 < 10 ORDER BY four, ten)s; + +SELECT ten, two, sum(hundred) AS gsum, sum(sum(hundred)) OVER (PARTITION BY two ORDER BY ten) AS wsum +FROM tenk1 GROUP BY ten, two; + +SELECT count(*) OVER (PARTITION BY four), four FROM (SELECT * FROM tenk1 WHERE two = 1)s WHERE unique2 < 10; + +SELECT (count(*) OVER (PARTITION BY four ORDER BY ten) + + sum(hundred) OVER (PARTITION BY four ORDER BY ten)) AS cntsum + FROM tenk1 WHERE unique2 < 10; + +-- opexpr with different windows evaluation. 
+SELECT * FROM( + SELECT count(*) OVER (PARTITION BY four ORDER BY ten) + + sum(hundred) OVER (PARTITION BY two ORDER BY ten) AS total, + count(*) OVER (PARTITION BY four ORDER BY ten) AS fourcount, + sum(hundred) OVER (PARTITION BY two ORDER BY ten) AS twosum + FROM tenk1 +)sub WHERE total <> fourcount + twosum; + +SELECT avg(four) OVER (PARTITION BY four ORDER BY thousand / 100) FROM tenk1 WHERE unique2 < 10; + +SELECT ten, two, sum(hundred) AS gsum, sum(sum(hundred)) OVER win AS wsum +FROM tenk1 GROUP BY ten, two WINDOW win AS (PARTITION BY two ORDER BY ten); + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- more than one window with GROUP BY +-- SELECT sum(salary), +-- row_number() OVER (ORDER BY depname), +-- sum(sum(salary)) OVER (ORDER BY depname DESC) +-- FROM empsalary GROUP BY depname; + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- identical windows with different names +-- SELECT sum(salary) OVER w1, count(*) OVER w2 +-- FROM empsalary WINDOW w1 AS (ORDER BY salary), w2 AS (ORDER BY salary); + +-- subplan +-- [SPARK-28379] Correlated scalar subqueries must be aggregated +-- SELECT lead(ten, (SELECT two FROM tenk1 WHERE s.unique2 = unique2)) OVER (PARTITION BY four ORDER BY ten) +-- FROM tenk1 s WHERE unique2 < 10; + +-- empty table +SELECT count(*) OVER (PARTITION BY four) FROM (SELECT * FROM tenk1 WHERE FALSE)s; + +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- mixture of agg/wfunc in the same window +-- SELECT sum(salary) OVER w, rank() OVER w FROM empsalary WINDOW w AS (PARTITION BY depname ORDER BY salary DESC); + +-- Cannot safely cast 'enroll_date': StringType to DateType; +-- SELECT empno, depname, salary, bonus, depadj, MIN(bonus) OVER (ORDER BY empno), MAX(depadj) OVER () FROM( +-- SELECT *, +-- CASE WHEN enroll_date < '2008-01-01' THEN 2008 - extract(year FROM enroll_date) END * 500 AS bonus, +-- CASE WHEN +-- AVG(salary) OVER (PARTITION BY depname) < salary +-- THEN 200 END AS depadj FROM empsalary +-- )s; + +create temporary view int4_tbl as select * from values + (0), + (123456), + (-123456), + (2147483647), + (-2147483647) + as int4_tbl(f1); + +-- window function over ungrouped agg over empty row set (bug before 9.1) +SELECT SUM(COUNT(f1)) OVER () FROM int4_tbl WHERE f1=42; + +-- window function with ORDER BY an expression involving aggregates (9.1 bug) +select ten, + sum(unique1) + sum(unique2) as res, + rank() over (order by sum(unique1) + sum(unique2)) as rank +from tenk1 +group by ten order by ten; + +-- window and aggregate with GROUP BY expression (9.2 bug) +-- explain +-- select first(max(x)) over (), y +-- from (select unique1 as x, ten+four as y from tenk1) ss +-- group by y; + +-- test non-default frame specifications +SELECT four, ten, +sum(ten) over (partition by four order by ten), +last(ten) over (partition by four order by ten) +FROM (select distinct ten, four from tenk1) ss; + +SELECT four, ten, +sum(ten) over (partition by four order by ten range between unbounded preceding and current row), +last(ten) over (partition by four order by ten range between unbounded preceding and current row) +FROM (select distinct ten, four from tenk1) ss; + +SELECT four, ten, +sum(ten) over (partition by four order by ten range between unbounded preceding and unbounded following), +last(ten) over (partition by four order by ten range between unbounded preceding and unbounded following) +FROM (select distinct ten, four from tenk1) ss; + +-- [SPARK-29451] Some queries with divisions in SQL windows are 
failling in Thrift +-- SELECT four, ten/4 as two, +-- sum(ten/4) over (partition by four order by ten/4 range between unbounded preceding and current row), +-- last(ten/4) over (partition by four order by ten/4 range between unbounded preceding and current row) +-- FROM (select distinct ten, four from tenk1) ss; + +-- [SPARK-29451] Some queries with divisions in SQL windows are failling in Thrift +-- SELECT four, ten/4 as two, +-- sum(ten/4) over (partition by four order by ten/4 rows between unbounded preceding and current row), +-- last(ten/4) over (partition by four order by ten/4 rows between unbounded preceding and current row) +-- FROM (select distinct ten, four from tenk1) ss; + +SELECT sum(unique1) over (order by four range between current row and unbounded following), +unique1, four +FROM tenk1 WHERE unique1 < 10; + +SELECT sum(unique1) over (rows between current row and unbounded following), +unique1, four +FROM tenk1 WHERE unique1 < 10; + +SELECT sum(unique1) over (rows between 2 preceding and 2 following), +unique1, four +FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude no others), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude current row), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude group), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude ties), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT first(unique1) over (ORDER BY four rows between current row and 2 following exclude current row), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT first(unique1) over (ORDER BY four rows between current row and 2 following exclude group), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT first(unique1) over (ORDER BY four rows between current row and 2 following exclude ties), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT last(unique1) over (ORDER BY four rows between current row and 2 following exclude current row), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT last(unique1) over (ORDER BY four rows between current row and 2 following exclude group), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT last(unique1) over (ORDER BY four rows between current row and 2 following exclude ties), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10; + +SELECT sum(unique1) over (rows between 2 preceding and 1 preceding), +unique1, four +FROM tenk1 WHERE unique1 < 10; + +SELECT sum(unique1) over (rows between 1 following and 3 following), +unique1, four +FROM tenk1 WHERE unique1 < 10; + +SELECT sum(unique1) over (rows between unbounded preceding and 1 following), +unique1, four +FROM tenk1 WHERE unique1 < 10; + +-- 
[SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (w range between current row and unbounded following), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four); + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (w range between unbounded preceding and current row exclude current row), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four); + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (w range between unbounded preceding and current row exclude group), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four); + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- SELECT sum(unique1) over (w range between unbounded preceding and current row exclude ties), +-- unique1, four +-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four); + +-- [SPARK-27951] ANSI SQL: NTH_VALUE function +-- SELECT first_value(unique1) over w, +-- nth_value(unique1, 2) over w AS nth_2, +-- last_value(unique1) over w, unique1, four +-- FROM tenk1 WHERE unique1 < 10 +-- WINDOW w AS (order by four range between current row and unbounded following); + +-- [SPARK-28501] Frame bound value must be a literal. +-- SELECT sum(unique1) over +-- (order by unique1 +-- rows (SELECT unique1 FROM tenk1 ORDER BY unique1 LIMIT 1) + 1 PRECEDING), +-- unique1 +-- FROM tenk1 WHERE unique1 < 10; + +CREATE TEMP VIEW v_window AS +SELECT i.id, sum(i.id) over (order by i.id rows between 1 preceding and 1 following) as sum_rows +FROM range(1, 11) i; + +SELECT * FROM v_window; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- CREATE OR REPLACE TEMP VIEW v_window AS +-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following +-- exclude current row) as sum_rows FROM range(1, 10) i; + +-- SELECT * FROM v_window; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- CREATE OR REPLACE TEMP VIEW v_window AS +-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following +-- exclude group) as sum_rows FROM range(1, 10) i; +-- SELECT * FROM v_window; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- CREATE OR REPLACE TEMP VIEW v_window AS +-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following +-- exclude ties) as sum_rows FROM generate_series(1, 10) i; + +-- [SPARK-28428] Spark `exclude` always expecting `()` +-- CREATE OR REPLACE TEMP VIEW v_window AS +-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following +-- exclude no others) as sum_rows FROM generate_series(1, 10) i; +-- SELECT * FROM v_window; + +-- [SPARK-28648] Adds support to `groups` unit type in window clauses +-- CREATE OR REPLACE TEMP VIEW v_window AS +-- SELECT i.id, sum(i.id) over (order by i.id groups between 1 preceding and 1 following) as sum_rows FROM range(1, 11) i; +-- SELECT * FROM v_window; + +DROP VIEW v_window; +-- [SPARK-29540] Thrift in some cases can't parse string to date +-- DROP TABLE empsalary; +DROP VIEW tenk2; +DROP VIEW int4_tbl; diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part1.sql.out new file mode 100644 index 000000000000..45bc98ae9764 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part1.sql.out @@ -0,0 +1,725 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 43 + + +-- !query 0 +CREATE 
TEMPORARY VIEW tenk2 AS SELECT * FROM tenk1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT four, ten, SUM(SUM(four)) OVER (PARTITION BY four), AVG(ten) FROM tenk1 +GROUP BY four, ten ORDER BY four, ten +-- !query 1 schema +struct +-- !query 1 output +0 0 0 0.0 +0 2 0 2.0 +0 4 0 4.0 +0 6 0 6.0 +0 8 0 8.0 +1 1 2500 1.0 +1 3 2500 3.0 +1 5 2500 5.0 +1 7 2500 7.0 +1 9 2500 9.0 +2 0 5000 0.0 +2 2 5000 2.0 +2 4 5000 4.0 +2 6 5000 6.0 +2 8 5000 8.0 +3 1 7500 1.0 +3 3 7500 3.0 +3 5 7500 5.0 +3 7 7500 7.0 +3 9 7500 9.0 + + +-- !query 2 +SELECT COUNT(*) OVER () FROM tenk1 WHERE unique2 < 10 +-- !query 2 schema +struct +-- !query 2 output +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 + + +-- !query 3 +SELECT COUNT(*) OVER w FROM tenk1 WHERE unique2 < 10 WINDOW w AS () +-- !query 3 schema +struct +-- !query 3 output +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 + + +-- !query 4 +SELECT four FROM tenk1 WHERE FALSE WINDOW w AS (PARTITION BY ten) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +SELECT sum(four) OVER (PARTITION BY ten ORDER BY unique2) AS sum_1, ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 5 schema +struct +-- !query 5 output +0 0 0 +0 0 0 +0 4 0 +1 7 1 +1 9 1 +2 0 2 +3 1 3 +3 3 3 +4 1 1 +5 1 1 + + +-- !query 6 +SELECT row_number() OVER (ORDER BY unique2) FROM tenk1 WHERE unique2 < 10 +-- !query 6 schema +struct +-- !query 6 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 7 +SELECT rank() OVER (PARTITION BY four ORDER BY ten) AS rank_1, ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 7 schema +struct +-- !query 7 output +1 0 0 +1 0 0 +1 0 2 +1 1 1 +1 1 1 +1 1 3 +2 3 3 +3 4 0 +3 7 1 +4 9 1 + + +-- !query 8 +SELECT dense_rank() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 8 schema +struct +-- !query 8 output +1 0 0 +1 0 0 +1 0 2 +1 1 1 +1 1 1 +1 1 3 +2 3 3 +2 4 0 +2 7 1 +3 9 1 + + +-- !query 9 +SELECT percent_rank() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 9 schema +struct +-- !query 9 output +0.0 0 0 +0.0 0 0 +0.0 0 2 +0.0 1 1 +0.0 1 1 +0.0 1 3 +0.6666666666666666 7 1 +1.0 3 3 +1.0 4 0 +1.0 9 1 + + +-- !query 10 +SELECT cume_dist() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 10 schema +struct +-- !query 10 output +0.5 1 1 +0.5 1 1 +0.5 1 3 +0.6666666666666666 0 0 +0.6666666666666666 0 0 +0.75 7 1 +1.0 0 2 +1.0 3 3 +1.0 4 0 +1.0 9 1 + + +-- !query 11 +SELECT ntile(3) OVER (ORDER BY ten, four), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 11 schema +struct +-- !query 11 output +1 0 0 +1 0 0 +1 0 2 +1 1 1 +2 1 1 +2 1 3 +2 3 3 +3 4 0 +3 7 1 +3 9 1 + + +-- !query 12 +SELECT lag(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 12 schema +struct +-- !query 12 output +0 0 0 +0 4 0 +1 1 1 +1 3 3 +1 7 1 +7 9 1 +NULL 0 0 +NULL 0 2 +NULL 1 1 +NULL 1 3 + + +-- !query 13 +SELECT lead(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 13 schema +struct +-- !query 13 output +0 0 0 +1 1 1 +3 1 3 +4 0 0 +7 1 1 +9 7 1 +NULL 0 2 +NULL 3 3 +NULL 4 0 +NULL 9 1 + + +-- !query 14 +SELECT lead(ten * 2, 1) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 14 schema +struct +-- !query 14 output +0 0 0 +14 1 1 +18 7 1 +2 1 1 +6 1 3 +8 0 0 +NULL 0 2 +NULL 3 3 +NULL 4 0 +NULL 9 1 + + +-- !query 15 +SELECT lead(ten * 2, 1, -1) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 
< 10 +-- !query 15 schema +struct +-- !query 15 output +-1 0 2 +-1 3 3 +-1 4 0 +-1 9 1 +0 0 0 +14 1 1 +18 7 1 +2 1 1 +6 1 3 +8 0 0 + + +-- !query 16 +SELECT first(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 16 schema +struct +-- !query 16 output +0 0 0 +0 0 0 +0 0 2 +0 4 0 +1 1 1 +1 1 1 +1 1 3 +1 3 3 +1 7 1 +1 9 1 + + +-- !query 17 +SELECT last(four) OVER (ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10 +-- !query 17 schema +struct +-- !query 17 output +0 4 0 +1 1 1 +1 1 1 +1 1 3 +1 7 1 +1 9 1 +2 0 0 +2 0 0 +2 0 2 +3 3 3 + + +-- !query 18 +SELECT last(ten) OVER (PARTITION BY four), ten, four FROM +(SELECT * FROM tenk1 WHERE unique2 < 10 ORDER BY four, ten)s +ORDER BY four, ten +-- !query 18 schema +struct +-- !query 18 output +4 0 0 +4 0 0 +4 4 0 +9 1 1 +9 1 1 +9 7 1 +9 9 1 +0 0 2 +3 1 3 +3 3 3 + + +-- !query 19 +SELECT ten, two, sum(hundred) AS gsum, sum(sum(hundred)) OVER (PARTITION BY two ORDER BY ten) AS wsum +FROM tenk1 GROUP BY ten, two +-- !query 19 schema +struct +-- !query 19 output +0 0 45000 45000 +1 1 46000 46000 +2 0 47000 92000 +3 1 48000 94000 +4 0 49000 141000 +5 1 50000 144000 +6 0 51000 192000 +7 1 52000 196000 +8 0 53000 245000 +9 1 54000 250000 + + +-- !query 20 +SELECT count(*) OVER (PARTITION BY four), four FROM (SELECT * FROM tenk1 WHERE two = 1)s WHERE unique2 < 10 +-- !query 20 schema +struct +-- !query 20 output +2 3 +2 3 +4 1 +4 1 +4 1 +4 1 + + +-- !query 21 +SELECT (count(*) OVER (PARTITION BY four ORDER BY ten) + + sum(hundred) OVER (PARTITION BY four ORDER BY ten)) AS cntsum + FROM tenk1 WHERE unique2 < 10 +-- !query 21 schema +struct +-- !query 21 output +136 +22 +22 +24 +24 +51 +82 +87 +92 +92 + + +-- !query 22 +SELECT * FROM( + SELECT count(*) OVER (PARTITION BY four ORDER BY ten) + + sum(hundred) OVER (PARTITION BY two ORDER BY ten) AS total, + count(*) OVER (PARTITION BY four ORDER BY ten) AS fourcount, + sum(hundred) OVER (PARTITION BY two ORDER BY ten) AS twosum + FROM tenk1 +)sub WHERE total <> fourcount + twosum +-- !query 22 schema +struct +-- !query 22 output + + + +-- !query 23 +SELECT avg(four) OVER (PARTITION BY four ORDER BY thousand / 100) FROM tenk1 WHERE unique2 < 10 +-- !query 23 schema +struct +-- !query 23 output +0.0 +0.0 +0.0 +1.0 +1.0 +1.0 +1.0 +2.0 +3.0 +3.0 + + +-- !query 24 +SELECT ten, two, sum(hundred) AS gsum, sum(sum(hundred)) OVER win AS wsum +FROM tenk1 GROUP BY ten, two WINDOW win AS (PARTITION BY two ORDER BY ten) +-- !query 24 schema +struct +-- !query 24 output +0 0 45000 45000 +1 1 46000 46000 +2 0 47000 92000 +3 1 48000 94000 +4 0 49000 141000 +5 1 50000 144000 +6 0 51000 192000 +7 1 52000 196000 +8 0 53000 245000 +9 1 54000 250000 + + +-- !query 25 +SELECT count(*) OVER (PARTITION BY four) FROM (SELECT * FROM tenk1 WHERE FALSE)s +-- !query 25 schema +struct +-- !query 25 output + + + +-- !query 26 +create temporary view int4_tbl as select * from values + (0), + (123456), + (-123456), + (2147483647), + (-2147483647) + as int4_tbl(f1) +-- !query 26 schema +struct<> +-- !query 26 output + + + +-- !query 27 +SELECT SUM(COUNT(f1)) OVER () FROM int4_tbl WHERE f1=42 +-- !query 27 schema +struct +-- !query 27 output +0 + + +-- !query 28 +select ten, + sum(unique1) + sum(unique2) as res, + rank() over (order by sum(unique1) + sum(unique2)) as rank +from tenk1 +group by ten order by ten +-- !query 28 schema +struct +-- !query 28 output +0 9976146 4 +1 10114187 9 +2 10059554 8 +3 9878541 1 +4 9881005 2 +5 9981670 5 +6 9947099 3 +7 10120309 10 +8 9991305 6 +9 10040184 7 + + 
+-- !query 29 +SELECT four, ten, +sum(ten) over (partition by four order by ten), +last(ten) over (partition by four order by ten) +FROM (select distinct ten, four from tenk1) ss +-- !query 29 schema +struct +-- !query 29 output +0 0 0 0 +0 2 2 2 +0 4 6 4 +0 6 12 6 +0 8 20 8 +1 1 1 1 +1 3 4 3 +1 5 9 5 +1 7 16 7 +1 9 25 9 +2 0 0 0 +2 2 2 2 +2 4 6 4 +2 6 12 6 +2 8 20 8 +3 1 1 1 +3 3 4 3 +3 5 9 5 +3 7 16 7 +3 9 25 9 + + +-- !query 30 +SELECT four, ten, +sum(ten) over (partition by four order by ten range between unbounded preceding and current row), +last(ten) over (partition by four order by ten range between unbounded preceding and current row) +FROM (select distinct ten, four from tenk1) ss +-- !query 30 schema +struct +-- !query 30 output +0 0 0 0 +0 2 2 2 +0 4 6 4 +0 6 12 6 +0 8 20 8 +1 1 1 1 +1 3 4 3 +1 5 9 5 +1 7 16 7 +1 9 25 9 +2 0 0 0 +2 2 2 2 +2 4 6 4 +2 6 12 6 +2 8 20 8 +3 1 1 1 +3 3 4 3 +3 5 9 5 +3 7 16 7 +3 9 25 9 + + +-- !query 31 +SELECT four, ten, +sum(ten) over (partition by four order by ten range between unbounded preceding and unbounded following), +last(ten) over (partition by four order by ten range between unbounded preceding and unbounded following) +FROM (select distinct ten, four from tenk1) ss +-- !query 31 schema +struct +-- !query 31 output +0 0 20 8 +0 2 20 8 +0 4 20 8 +0 6 20 8 +0 8 20 8 +1 1 25 9 +1 3 25 9 +1 5 25 9 +1 7 25 9 +1 9 25 9 +2 0 20 8 +2 2 20 8 +2 4 20 8 +2 6 20 8 +2 8 20 8 +3 1 25 9 +3 3 25 9 +3 5 25 9 +3 7 25 9 +3 9 25 9 + + +-- !query 32 +SELECT sum(unique1) over (order by four range between current row and unbounded following), +unique1, four +FROM tenk1 WHERE unique1 < 10 +-- !query 32 schema +struct +-- !query 32 output +10 3 3 +10 7 3 +18 2 2 +18 6 2 +33 1 1 +33 5 1 +33 9 1 +45 0 0 +45 4 0 +45 8 0 + + +-- !query 33 +SELECT sum(unique1) over (rows between current row and unbounded following), +unique1, four +FROM tenk1 WHERE unique1 < 10 +-- !query 33 schema +struct +-- !query 33 output +0 0 0 +10 3 3 +15 5 1 +23 8 0 +32 9 1 +38 6 2 +39 1 1 +41 2 2 +45 4 0 +7 7 3 + + +-- !query 34 +SELECT sum(unique1) over (rows between 2 preceding and 2 following), +unique1, four +FROM tenk1 WHERE unique1 < 10 +-- !query 34 schema +struct +-- !query 34 output +10 0 0 +13 2 2 +15 7 3 +22 1 1 +23 3 3 +26 6 2 +29 9 1 +31 8 0 +32 5 1 +7 4 0 + + +-- !query 35 +SELECT sum(unique1) over (rows between 2 preceding and 1 preceding), +unique1, four +FROM tenk1 WHERE unique1 < 10 +-- !query 35 schema +struct +-- !query 35 output +10 0 0 +13 3 3 +15 8 0 +17 5 1 +3 6 2 +4 2 2 +6 1 1 +7 9 1 +8 7 3 +NULL 4 0 + + +-- !query 36 +SELECT sum(unique1) over (rows between 1 following and 3 following), +unique1, four +FROM tenk1 WHERE unique1 < 10 +-- !query 36 schema +struct +-- !query 36 output +0 7 3 +10 5 1 +15 8 0 +16 2 2 +16 9 1 +22 6 2 +23 1 1 +7 3 3 +9 4 0 +NULL 0 0 + + +-- !query 37 +SELECT sum(unique1) over (rows between unbounded preceding and 1 following), +unique1, four +FROM tenk1 WHERE unique1 < 10 +-- !query 37 schema +struct +-- !query 37 output +13 1 1 +22 6 2 +30 9 1 +35 8 0 +38 5 1 +45 0 0 +45 3 3 +45 7 3 +6 4 0 +7 2 2 + + +-- !query 38 +CREATE TEMP VIEW v_window AS +SELECT i.id, sum(i.id) over (order by i.id rows between 1 preceding and 1 following) as sum_rows +FROM range(1, 11) i +-- !query 38 schema +struct<> +-- !query 38 output + + + +-- !query 39 +SELECT * FROM v_window +-- !query 39 schema +struct +-- !query 39 output +1 3 +10 19 +2 6 +3 9 +4 12 +5 15 +6 18 +7 21 +8 24 +9 27 + + +-- !query 40 +DROP VIEW v_window +-- !query 40 schema +struct<> +-- !query 40 
output + + + +-- !query 41 +DROP VIEW tenk2 +-- !query 41 schema +struct<> +-- !query 41 output + + + +-- !query 42 +DROP VIEW int4_tbl +-- !query 42 schema +struct<> +-- !query 42 output + From c128ac564d198effe6bb9754489ea32133dfeb89 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 23 Oct 2019 12:17:20 +0800 Subject: [PATCH 16/58] [SPARK-29511][SQL] DataSourceV2: Support CREATE NAMESPACE ### What changes were proposed in this pull request? This PR adds `CREATE NAMESPACE` support for V2 catalogs. ### Why are the changes needed? Currently, you cannot explicitly create namespaces for v2 catalogs. ### Does this PR introduce any user-facing change? The user can now perform the following: ```SQL CREATE NAMESPACE mycatalog.ns ``` to create a namespace `ns` inside `mycatalog` V2 catalog. ### How was this patch tested? Added unit tests. Closes #26166 from imback82/create_namespace. Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- docs/sql-keywords.md | 1 + .../spark/sql/catalyst/parser/SqlBase.g4 | 7 +- .../catalyst/analysis/ResolveCatalogs.scala | 7 ++ .../sql/catalyst/parser/AstBuilder.scala | 40 +++++++++ .../catalyst/plans/logical/statements.scala | 13 +++ .../catalyst/plans/logical/v2Commands.scala | 8 ++ .../sql/catalyst/parser/DDLParserSuite.scala | 84 +++++++++++++++++++ .../catalyst/parser/ParserUtilsSuite.scala | 2 +- .../analysis/ResolveSessionCatalog.scala | 15 +++- .../spark/sql/execution/SparkSqlParser.scala | 27 ------ .../datasources/v2/CreateNamespaceExec.scala | 55 ++++++++++++ .../datasources/v2/DataSourceV2Strategy.scala | 5 +- .../sql/connector/DataSourceV2SQLSuite.scala | 28 ++++++- .../execution/command/DDLParserSuite.scala | 58 ------------- 14 files changed, 259 insertions(+), 91 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md index 7a0e3efee8ff..b4f8d8be11c4 100644 --- a/docs/sql-keywords.md +++ b/docs/sql-keywords.md @@ -210,6 +210,7 @@ Below is a list of all the keywords in Spark SQL. PRECEDINGnon-reservednon-reservednon-reserved PRIMARYreservednon-reservedreserved PRINCIPALSnon-reservednon-reservednon-reserved + PROPERTIESnon-reservednon-reservednon-reserved PURGEnon-reservednon-reservednon-reserved QUERYnon-reservednon-reservednon-reserved RANGEnon-reservednon-reservedreserved diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 963077c35df9..7e5e16b8e32b 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -83,10 +83,10 @@ statement : query #statementDefault | ctes? dmlStatementNoWith #dmlStatement | USE NAMESPACE? multipartIdentifier #use - | CREATE database (IF NOT EXISTS)? db=errorCapturingIdentifier + | CREATE (database | NAMESPACE) (IF NOT EXISTS)? 
multipartIdentifier ((COMMENT comment=STRING) | locationSpec | - (WITH DBPROPERTIES tablePropertyList))* #createDatabase + (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace | ALTER database db=errorCapturingIdentifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | ALTER database db=errorCapturingIdentifier @@ -1039,6 +1039,7 @@ ansiNonReserved | POSITION | PRECEDING | PRINCIPALS + | PROPERTIES | PURGE | QUERY | RANGE @@ -1299,6 +1300,7 @@ nonReserved | PRECEDING | PRIMARY | PRINCIPALS + | PROPERTIES | PURGE | QUERY | RANGE @@ -1564,6 +1566,7 @@ POSITION: 'POSITION'; PRECEDING: 'PRECEDING'; PRIMARY: 'PRIMARY'; PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; PURGE: 'PURGE'; QUERY: 'QUERY'; RANGE: 'RANGE'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 13a79a82a385..6553b3d57d7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -168,6 +168,13 @@ class ResolveCatalogs(val catalogManager: CatalogManager) s"Can not specify catalog `${catalog.name}` for view ${viewName.quoted} " + s"because view support in catalog has not been implemented yet") + case c @ CreateNamespaceStatement(NonSessionCatalog(catalog, nameParts), _, _) => + CreateNamespace( + catalog.asNamespaceCatalog, + nameParts, + c.ifNotExists, + c.properties) + case ShowNamespacesStatement(Some(CatalogAndNamespace(catalog, namespace)), pattern) => ShowNamespaces(catalog.asNamespaceCatalog, namespace, pattern) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 548042bc9767..7c67952aba40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2307,6 +2307,46 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + /** + * Create a [[CreateNamespaceStatement]] command. 
+ * + * For example: + * {{{ + * CREATE NAMESPACE [IF NOT EXISTS] ns1.ns2.ns3 + * create_namespace_clauses; + * + * create_namespace_clauses (order insensitive): + * [COMMENT namespace_comment] + * [LOCATION path] + * [WITH PROPERTIES (key1=val1, key2=val2, ...)] + * }}} + */ + override def visitCreateNamespace(ctx: CreateNamespaceContext): LogicalPlan = withOrigin(ctx) { + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + checkDuplicateClauses(ctx.PROPERTIES, "WITH PROPERTIES", ctx) + checkDuplicateClauses(ctx.DBPROPERTIES, "WITH DBPROPERTIES", ctx) + + if (!ctx.PROPERTIES.isEmpty && !ctx.DBPROPERTIES.isEmpty) { + throw new ParseException(s"Either PROPERTIES or DBPROPERTIES is allowed.", ctx) + } + + var properties = ctx.tablePropertyList.asScala.headOption + .map(visitPropertyKeyValues) + .getOrElse(Map.empty) + Option(ctx.comment).map(string).map { + properties += CreateNamespaceStatement.COMMENT_PROPERTY_KEY -> _ + } + ctx.locationSpec.asScala.headOption.map(visitLocationSpec).map { + properties += CreateNamespaceStatement.LOCATION_PROPERTY_KEY -> _ + } + + CreateNamespaceStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + ctx.EXISTS != null, + properties) + } + /** * Create a [[ShowNamespacesStatement]] command. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index a73a2975aa9c..3bd16187320f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -282,6 +282,19 @@ case class InsertIntoStatement( case class ShowTablesStatement(namespace: Option[Seq[String]], pattern: Option[String]) extends ParsedStatement +/** + * A CREATE NAMESPACE statement, as parsed from SQL. + */ +case class CreateNamespaceStatement( + namespace: Seq[String], + ifNotExists: Boolean, + properties: Map[String, String]) extends ParsedStatement + +object CreateNamespaceStatement { + val COMMENT_PROPERTY_KEY: String = "comment" + val LOCATION_PROPERTY_KEY: String = "location" +} + /** * A SHOW NAMESPACES statement, as parsed from SQL. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index f89dfb1ec47d..8f5731a4a7a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -237,6 +237,14 @@ case class ReplaceTableAsSelect( } } +/** + * The logical plan of the CREATE NAMESPACE command that works for v2 catalogs. + */ +case class CreateNamespace( + catalog: SupportsNamespaces, + namespace: Seq[String], + ifNotExists: Boolean, + properties: Map[String, String]) extends Command /** * The logical plan of the SHOW NAMESPACES command that works for v2 catalogs. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 1dacb2384ac1..38ef357036a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -845,6 +845,90 @@ class DDLParserSuite extends AnalysisTest { ShowTablesStatement(Some(Seq("tbl")), Some("*dog*"))) } + test("create namespace -- backward compatibility with DATABASE/DBPROPERTIES") { + val expected = CreateNamespaceStatement( + Seq("a", "b", "c"), + ifNotExists = true, + Map( + "a" -> "a", + "b" -> "b", + "c" -> "c", + "comment" -> "namespace_comment", + "location" -> "/home/user/db")) + + comparePlans( + parsePlan( + """ + |CREATE NAMESPACE IF NOT EXISTS a.b.c + |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c') + |COMMENT 'namespace_comment' LOCATION '/home/user/db' + """.stripMargin), + expected) + + comparePlans( + parsePlan( + """ + |CREATE DATABASE IF NOT EXISTS a.b.c + |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') + |COMMENT 'namespace_comment' LOCATION '/home/user/db' + """.stripMargin), + expected) + } + + test("create namespace -- check duplicates") { + def createDatabase(duplicateClause: String): String = { + s""" + |CREATE NAMESPACE IF NOT EXISTS a.b.c + |$duplicateClause + |$duplicateClause + """.stripMargin + } + val sql1 = createDatabase("COMMENT 'namespace_comment'") + val sql2 = createDatabase("LOCATION '/home/user/db'") + val sql3 = createDatabase("WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')") + val sql4 = createDatabase("WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") + + intercept(sql1, "Found duplicate clauses: COMMENT") + intercept(sql2, "Found duplicate clauses: LOCATION") + intercept(sql3, "Found duplicate clauses: WITH PROPERTIES") + intercept(sql4, "Found duplicate clauses: WITH DBPROPERTIES") + } + + test("create namespace - property values must be set") { + assertUnsupported( + sql = "CREATE NAMESPACE a.b.c WITH PROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + } + + test("create namespace -- either PROPERTIES or DBPROPERTIES is allowed") { + val sql = + s""" + |CREATE NAMESPACE IF NOT EXISTS a.b.c + |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c') + |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') + """.stripMargin + intercept(sql, "Either PROPERTIES or DBPROPERTIES is allowed") + } + + test("create namespace - support for other types in PROPERTIES") { + val sql = + """ + |CREATE NAMESPACE a.b.c + |LOCATION '/home/user/db' + |WITH PROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE) + """.stripMargin + comparePlans( + parsePlan(sql), + CreateNamespaceStatement( + Seq("a", "b", "c"), + ifNotExists = false, + Map( + "a" -> "1", + "b" -> "0.1", + "c" -> "true", + "location" -> "/home/user/db"))) + } + test("show databases: basic") { comparePlans( parsePlan("SHOW DATABASES"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index 07f77ea889db..c6434f2bdd3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -50,7 +50,7 @@ class ParserUtilsSuite extends SparkFunSuite { |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') 
""".stripMargin ) { parser => - parser.statement().asInstanceOf[CreateDatabaseContext] + parser.statement().asInstanceOf[CreateNamespaceContext] } val emptyContext = buildContext("") { parser => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 4a2e6731d9d8..4cca9846e996 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand} +import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf @@ -255,6 +255,19 @@ class ResolveSessionCatalog( case DropViewStatement(SessionCatalog(catalog, viewName), ifExists) => DropTableCommand(viewName.asTableIdentifier, ifExists, isView = true, purge = false) + case c @ CreateNamespaceStatement(SessionCatalog(catalog, nameParts), _, _) => + if (nameParts.length != 1) { + throw new AnalysisException( + s"The database name is not valid: ${nameParts.quoted}") + } + + val comment = c.properties.get(CreateNamespaceStatement.COMMENT_PROPERTY_KEY) + val location = c.properties.get(CreateNamespaceStatement.LOCATION_PROPERTY_KEY) + val newProperties = c.properties - + CreateNamespaceStatement.COMMENT_PROPERTY_KEY - + CreateNamespaceStatement.LOCATION_PROPERTY_KEY + CreateDatabaseCommand(nameParts.head, c.ifNotExists, location, comment, newProperties) + case ShowTablesStatement(Some(SessionCatalog(catalog, nameParts)), pattern) => if (nameParts.length != 1) { throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3f3f6b373eb0..38f3c6e1b750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -329,33 +329,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ) } - /** - * Create a [[CreateDatabaseCommand]] command. 
- * - * For example: - * {{{ - * CREATE DATABASE [IF NOT EXISTS] database_name - * create_database_clauses; - * - * create_database_clauses (order insensitive): - * [COMMENT database_comment] - * [LOCATION path] - * [WITH DBPROPERTIES (key1=val1, key2=val2, ...)] - * }}} - */ - override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) { - checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) - checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) - checkDuplicateClauses(ctx.DBPROPERTIES, "WITH DBPROPERTIES", ctx) - - CreateDatabaseCommand( - ctx.db.getText, - ctx.EXISTS != null, - ctx.locationSpec.asScala.headOption.map(visitLocationSpec), - Option(ctx.comment).map(string), - ctx.tablePropertyList.asScala.headOption.map(visitPropertyKeyValues).getOrElse(Map.empty)) - } - /** * Create an [[AlterDatabasePropertiesCommand]] command. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala new file mode 100644 index 000000000000..0f69f85dd837 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.SupportsNamespaces + +/** + * Physical plan node for creating a namespace. + */ +case class CreateNamespaceExec( + catalog: SupportsNamespaces, + namespace: Seq[String], + ifNotExists: Boolean, + private var properties: Map[String, String]) + extends V2CommandExec { + override protected def run(): Seq[InternalRow] = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + val ns = namespace.toArray + if (!catalog.namespaceExists(ns)) { + try { + catalog.createNamespace(ns, properties.asJava) + } catch { + case _: NamespaceAlreadyExistsException if ifNotExists => + logWarning(s"Namespace ${namespace.quoted} was created concurrently. 
Ignoring.") + } + } else if (!ifNotExists) { + throw new NamespaceAlreadyExistsException(ns) + } + + Seq.empty + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c8d29520bcfc..49035c3cc3da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables} import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} @@ -289,6 +289,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { case AlterTable(catalog, ident, _, changes) => AlterTableExec(catalog, ident, changes) :: Nil + case CreateNamespace(catalog, namespace, ifNotExists, properties) => + CreateNamespaceExec(catalog, namespace, ifNotExists, properties) :: Nil + case r: ShowNamespaces => ShowNamespacesExec(r.output, r.catalog, r.namespace, r.pattern) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 39709ab426a0..2ea26787dbb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import scala.collection.JavaConverters._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.internal.SQLConf @@ -764,6 +764,32 @@ class DataSourceV2SQLSuite assert(expected === df.collect()) } + test("CreateNameSpace: basic tests") { + // Session catalog is used. 
+ sql("CREATE NAMESPACE ns") + testShowNamespaces("SHOW NAMESPACES", Seq("default", "ns")) + + // V2 non-session catalog is used. + sql("CREATE NAMESPACE testcat.ns1.ns2") + testShowNamespaces("SHOW NAMESPACES IN testcat", Seq("ns1")) + testShowNamespaces("SHOW NAMESPACES IN testcat.ns1", Seq("ns1.ns2")) + + // TODO: Add tests for validating namespace metadata when DESCRIBE NAMESPACE is available. + } + + test("CreateNameSpace: test handling of 'IF NOT EXIST'") { + sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1") + + // The 'ns1' namespace already exists, so this should fail. + val exception = intercept[NamespaceAlreadyExistsException] { + sql("CREATE NAMESPACE testcat.ns1") + } + assert(exception.getMessage.contains("Namespace 'ns1' already exists")) + + // The following will be no-op since the namespace already exists. + sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns1") + } + test("ShowNamespaces: show root namespaces with default v2 catalog") { spark.conf.set("spark.sql.default.catalog", "testcat") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 0640d0540baa..a9b94bea9517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -74,46 +74,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { }.head } - test("create database") { - val sql = - """ - |CREATE DATABASE IF NOT EXISTS database_name - |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') - |COMMENT 'database_comment' LOCATION '/home/user/db' - """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = CreateDatabaseCommand( - "database_name", - ifNotExists = true, - Some("/home/user/db"), - Some("database_comment"), - Map("a" -> "a", "b" -> "b", "c" -> "c")) - comparePlans(parsed, expected) - } - - test("create database -- check duplicates") { - def createDatabase(duplicateClause: String): String = { - s""" - |CREATE DATABASE IF NOT EXISTS database_name - |$duplicateClause - |$duplicateClause - """.stripMargin - } - val sql1 = createDatabase("COMMENT 'database_comment'") - val sql2 = createDatabase("LOCATION '/home/user/db'") - val sql3 = createDatabase("WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") - - intercept(sql1, "Found duplicate clauses: COMMENT") - intercept(sql2, "Found duplicate clauses: LOCATION") - intercept(sql3, "Found duplicate clauses: WITH DBPROPERTIES") - } - - test("create database - property values must be set") { - assertUnsupported( - sql = "CREATE DATABASE my_db WITH DBPROPERTIES('key_without_value', 'key_with_value'='x')", - containsThesePhrases = Seq("key_without_value")) - } - test("drop database") { val sql1 = "DROP DATABASE IF EXISTS database_name RESTRICT" val sql2 = "DROP DATABASE IF EXISTS database_name CASCADE" @@ -870,24 +830,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { comparePlans(parsed4, expected4) } - test("support for other types in DBPROPERTIES") { - val sql = - """ - |CREATE DATABASE database_name - |LOCATION '/home/user/db' - |WITH DBPROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE) - """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = CreateDatabaseCommand( - "database_name", - ifNotExists = false, - Some("/home/user/db"), - None, - Map("a" -> "1", "b" -> "0.1", "c" -> "true")) - - comparePlans(parsed, expected) - } - test("Test CTAS #1") { val s1 = """ From 
8c3469009cf84c95a81bb684244aac29b650d225 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 23 Oct 2019 00:14:43 -0700 Subject: [PATCH 17/58] [SPARK-29546][TESTS] Recover jersey-guava test dependency in docker-integration-tests ### What changes were proposed in this pull request? Currently, `docker-integration-tests` is broken in both JDK8/11. This PR aims to recover JDBC integration test for JDK8/11. ### Why are the changes needed? While SPARK-28737 upgraded `Jersey` to 2.29 for JDK11, `docker-integration-tests` is broken because `com.spotify.docker-client` still depends on `jersey-guava`. The latest `com.spotify.docker-client` also has this problem. - https://mvnrepository.com/artifact/com.spotify/docker-client/5.0.2 -> https://mvnrepository.com/artifact/org.glassfish.jersey.core/jersey-client/2.19 -> https://mvnrepository.com/artifact/org.glassfish.jersey.core/jersey-common/2.19 -> https://mvnrepository.com/artifact/org.glassfish.jersey.bundles.repackaged/jersey-guava/2.19 ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manual because this is an integration test suite. ``` $ java -version openjdk version "1.8.0_222" OpenJDK Runtime Environment (AdoptOpenJDK)(build 1.8.0_222-b10) OpenJDK 64-Bit Server VM (AdoptOpenJDK)(build 25.222-b10, mixed mode) $ build/mvn install -DskipTests $ build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.12 test ``` ``` $ java -version openjdk version "11.0.5" 2019-10-15 OpenJDK Runtime Environment AdoptOpenJDK (build 11.0.5+10) OpenJDK 64-Bit Server VM AdoptOpenJDK (build 11.0.5+10, mixed mode) $ build/mvn install -DskipTests $ build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.12 test ``` **BEFORE** ``` *** RUN ABORTED *** com.spotify.docker.client.exceptions.DockerException: java.util.concurrent.ExecutionException: javax.ws.rs.ProcessingException: java.lang.NoClassDefFoundError: jersey/repackaged/com/google/common/util/concurrent/MoreExecutors at com.spotify.docker.client.DefaultDockerClient.propagate(DefaultDockerClient.java:1607) at com.spotify.docker.client.DefaultDockerClient.request(DefaultDockerClient.java:1538) at com.spotify.docker.client.DefaultDockerClient.ping(DefaultDockerClient.java:387) at org.apache.spark.sql.jdbc.DockerJDBCIntegrationSuite.beforeAll(DockerJDBCIntegrationSuite.scala:81) ``` **AFTER** ``` Run completed in 47 seconds, 999 milliseconds. Total number of tests run: 30 Suites: completed 6, aborted 0 Tests: succeeded 30, failed 0, canceled 0, ignored 6, pending 0 All tests passed. ``` Closes #26203 from dongjoon-hyun/SPARK-29546. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- external/docker-integration-tests/pom.xml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index a4956ff5ee9c..aff79b8b8e64 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -106,6 +106,14 @@ test-jar test + + + org.glassfish.jersey.bundles.repackaged + jersey-guava + 2.25.1 + test + mysql mysql-connector-java From cbe6eadc0c1d0384c1ee03f3a5b28cc583a60717 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 23 Oct 2019 10:56:19 +0200 Subject: [PATCH 18/58] [SPARK-29352][SQL][SS] Track active streaming queries in the SparkSession.sharedState ### What changes were proposed in this pull request? 
This moves the tracking of active queries from a per SparkSession state, to the shared SparkSession for better safety in isolated Spark Session environments. ### Why are the changes needed? We have checks to prevent the restarting of the same stream on the same spark session, but we can actually make that better in multi-tenant environments by actually putting that state in the SharedState instead of SessionState. This would allow a more comprehensive check for multi-tenant clusters. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Added tests to StreamingQueryManagerSuite Closes #26018 from brkyvz/sharedStreamingQueryManager. Lead-authored-by: Burak Yavuz Co-authored-by: Burak Yavuz Signed-off-by: Burak Yavuz --- .../spark/sql/internal/SharedState.scala | 10 ++- .../sql/streaming/StreamingQueryManager.scala | 22 +++-- .../StreamingQueryManagerSuite.scala | 80 ++++++++++++++++++- 3 files changed, 102 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index f1a648176c3b..d097f9f18f89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.internal import java.net.URL -import java.util.Locale +import java.util.{Locale, UUID} +import java.util.concurrent.ConcurrentHashMap import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab} import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.status.ElementTrackingStore import org.apache.spark.util.Utils @@ -110,6 +112,12 @@ private[sql] class SharedState( */ val cacheManager: CacheManager = new CacheManager + /** + * A map of active streaming queries to the session specific StreamingQueryManager that manages + * the lifecycle of that stream. + */ + private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamingQueryManager]() + /** * A status store to query SQL status/metrics of this Spark application, based on SQL-specific * [[org.apache.spark.scheduler.SparkListenerEvent]]s. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 9abe38dfda0b..9b43a83e7b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -352,8 +352,10 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } } - // Make sure no other query with same id is active - if (activeQueries.values.exists(_.id == query.id)) { + // Make sure no other query with same id is active across all sessions + val activeOption = + Option(sparkSession.sharedState.activeStreamingQueries.putIfAbsent(query.id, this)) + if (activeOption.isDefined || activeQueries.values.exists(_.id == query.id)) { throw new IllegalStateException( s"Cannot start query with id ${query.id} as another query with same id is " + s"already active. 
Perhaps you are attempting to restart a query from checkpoint " + @@ -370,9 +372,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo query.streamingQuery.start() } catch { case e: Throwable => - activeQueriesLock.synchronized { - activeQueries -= query.id - } + unregisterTerminatedStream(query.id) throw e } query @@ -380,9 +380,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo /** Notify (by the StreamingQuery) that the query has been terminated */ private[sql] def notifyQueryTermination(terminatedQuery: StreamingQuery): Unit = { - activeQueriesLock.synchronized { - activeQueries -= terminatedQuery.id - } + unregisterTerminatedStream(terminatedQuery.id) awaitTerminationLock.synchronized { if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) { lastTerminatedQuery = terminatedQuery @@ -391,4 +389,12 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } stateStoreCoordinator.deactivateInstances(terminatedQuery.runId) } + + private def unregisterTerminatedStream(terminatedQueryId: UUID): Unit = { + activeQueriesLock.synchronized { + // remove from shared state only if the streaming query manager also matches + sparkSession.sharedState.activeStreamingQueries.remove(terminatedQueryId, this) + activeQueries -= terminatedQueryId + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index b26d2556b2e3..09580b94056b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.concurrent.CountDownLatch import scala.concurrent.Future @@ -28,7 +29,7 @@ import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Encoders} import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.streaming.util.BlockingSource @@ -242,6 +243,83 @@ class StreamingQueryManagerSuite extends StreamTest { } } + testQuietly("can't start a streaming query with the same name in the same session") { + val ds1 = makeDataset._2 + val ds2 = makeDataset._2 + val queryName = "abc" + + val query1 = ds1.writeStream.format("noop").queryName(queryName).start() + try { + val e = intercept[IllegalArgumentException] { + ds2.writeStream.format("noop").queryName(queryName).start() + } + assert(e.getMessage.contains("query with that name is already active")) + } finally { + query1.stop() + } + } + + testQuietly("can start a streaming query with the same name in a different session") { + val session2 = spark.cloneSession() + + val ds1 = MemoryStream(Encoders.INT, spark.sqlContext).toDS() + val ds2 = MemoryStream(Encoders.INT, session2.sqlContext).toDS() + val queryName = "abc" + + val query1 = ds1.writeStream.format("noop").queryName(queryName).start() + val query2 = ds2.writeStream.format("noop").queryName(queryName).start() + + query1.stop() + query2.stop() + } + + testQuietly("can't start multiple instances of the same streaming query in the same session") { + withTempDir { dir => + val (ms1, ds1) = makeDataset + val (ms2, ds2) = 
makeDataset + val chkLocation = new File(dir, "_checkpoint").getCanonicalPath + val dataLocation = new File(dir, "data").getCanonicalPath + + val query1 = ds1.writeStream.format("parquet") + .option("checkpointLocation", chkLocation).start(dataLocation) + ms1.addData(1, 2, 3) + try { + val e = intercept[IllegalStateException] { + ds2.writeStream.format("parquet") + .option("checkpointLocation", chkLocation).start(dataLocation) + } + assert(e.getMessage.contains("same id")) + } finally { + query1.stop() + } + } + } + + testQuietly( + "can't start multiple instances of the same streaming query in the different sessions") { + withTempDir { dir => + val session2 = spark.cloneSession() + + val ms1 = MemoryStream(Encoders.INT, spark.sqlContext) + val ds2 = MemoryStream(Encoders.INT, session2.sqlContext).toDS() + val chkLocation = new File(dir, "_checkpoint").getCanonicalPath + val dataLocation = new File(dir, "data").getCanonicalPath + + val query1 = ms1.toDS().writeStream.format("parquet") + .option("checkpointLocation", chkLocation).start(dataLocation) + ms1.addData(1, 2, 3) + try { + val e = intercept[IllegalStateException] { + ds2.writeStream.format("parquet") + .option("checkpointLocation", chkLocation).start(dataLocation) + } + assert(e.getMessage.contains("same id")) + } finally { + query1.stop() + } + } + } + /** Run a body of code by defining a query on each dataset */ private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[StreamingQuery] => Unit): Unit = { failAfter(streamingTimeout) { From 70dd9c0cabb52fac3ab20fbde7eeda41b19bad61 Mon Sep 17 00:00:00 2001 From: turbofei Date: Wed, 23 Oct 2019 20:31:06 +0900 Subject: [PATCH 19/58] [SPARK-29542][SQL][DOC] Make the descriptions of spark.sql.files.* be clearly ### What changes were proposed in this pull request? As described in [SPARK-29542](https://issues.apache.org/jira/browse/SPARK-29542) , the descriptions of `spark.sql.files.*` are confused. In this PR, I make their descriptions be clearly. ### Why are the changes needed? It makes the descriptions of `spark.sql.files.*` be clearly. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing UT. Closes #26200 from turboFei/SPARK-29542-partition-maxSize. Authored-by: turbofei Signed-off-by: HyukjinKwon --- .../org/apache/spark/sql/internal/SQLConf.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 75db52e334b8..7f75bf84d65a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -980,7 +980,9 @@ object SQLConf { .createWithDefault(true) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") - .doc("The maximum number of bytes to pack into a single partition when reading files.") + .doc("The maximum number of bytes to pack into a single partition when reading files. " + + "This configuration is effective only when using file-based sources such as Parquet, JSON " + + "and ORC.") .bytesConf(ByteUnit.BYTE) .createWithDefault(128 * 1024 * 1024) // parquet.block.size @@ -989,19 +991,24 @@ object SQLConf { .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + " the same time. This is used when putting multiple files into a partition. 
It's better to" + " over estimated, then the partitions with small files will be faster than partitions with" + - " bigger files (which is scheduled first).") + " bigger files (which is scheduled first). This configuration is effective only when using" + + " file-based sources such as Parquet, JSON and ORC.") .longConf .createWithDefault(4 * 1024 * 1024) val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles") .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + - "encountering corrupted files and the contents that have been read will still be returned.") + "encountering corrupted files and the contents that have been read will still be returned. " + + "This configuration is effective only when using file-based sources such as Parquet, JSON " + + "and ORC.") .booleanConf .createWithDefault(false) val IGNORE_MISSING_FILES = buildConf("spark.sql.files.ignoreMissingFiles") .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + - "encountering missing files and the contents that have been read will still be returned.") + "encountering missing files and the contents that have been read will still be returned. " + + "This configuration is effective only when using file-based sources such as Parquet, JSON " + + "and ORC.") .booleanConf .createWithDefault(false) From 0a7095156bdb565133f7dcc74546c51a5e5d2414 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Wed, 23 Oct 2019 13:46:09 +0200 Subject: [PATCH 20/58] [SPARK-29499][CORE][PYSPARK] Add mapPartitionsWithIndex for RDDBarrier ### What changes were proposed in this pull request? Add mapPartitionsWithIndex for RDDBarrier. ### Why are the changes needed? There is only one method in `RDDBarrier`. We often use the partition index as a label for the current partition. We need to get the index from `TaskContext` index in the method of `mapPartitions` which is not convenient. ### Does this PR introduce any user-facing change? No ### How was this patch tested? New UT. Closes #26148 from ConeyLiu/barrier-index. Authored-by: Xianyang Liu Signed-off-by: Xingbo Jiang --- .../org/apache/spark/rdd/RDDBarrier.scala | 22 ++++++++ .../apache/spark/rdd/RDDBarrierSuite.scala | 9 ++++ dev/sparktestsupport/modules.py | 1 + python/pyspark/rdd.py | 14 ++++++ python/pyspark/tests/test_rddbarrier.py | 50 +++++++++++++++++++ 5 files changed, 96 insertions(+) create mode 100644 python/pyspark/tests/test_rddbarrier.py diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala index 42802f7113a1..b70ea0073c9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -54,5 +54,27 @@ class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) { ) } + /** + * :: Experimental :: + * Returns a new RDD by applying a function to each partition of the wrapped RDD, while tracking + * the index of the original partition. And all tasks are launched together in a barrier stage. + * The interface is the same as [[org.apache.spark.rdd.RDD#mapPartitionsWithIndex]]. + * Please see the API doc there. 
+ * @see [[org.apache.spark.BarrierTaskContext]] + */ + @Experimental + @Since("3.0.0") + def mapPartitionsWithIndex[S: ClassTag]( + f: (Int, Iterator[T]) => Iterator[S], + preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope { + val cleanedF = rdd.sparkContext.clean(f) + new MapPartitionsRDD( + rdd, + (_: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter), + preservesPartitioning, + isFromBarrier = true + ) + } + // TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout. } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala index 2f6c4d6a42ea..f048f9543013 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala @@ -29,6 +29,15 @@ class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext { assert(rdd2.isBarrier()) } + test("RDDBarrier mapPartitionsWithIndex") { + val rdd = sc.parallelize(1 to 12, 4) + assert(rdd.isBarrier() === false) + + val rdd2 = rdd.barrier().mapPartitionsWithIndex((index, iter) => Iterator(index)) + assert(rdd2.isBarrier()) + assert(rdd2.collect().toList === List(0, 1, 2, 3)) + } + test("create an RDDBarrier in the middle of a chain of RDDs") { val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2) val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1)) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index c7ea065b28ed..1443584ccbcb 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -329,6 +329,7 @@ def __hash__(self): "pyspark.tests.test_join", "pyspark.tests.test_profiler", "pyspark.tests.test_rdd", + "pyspark.tests.test_rddbarrier", "pyspark.tests.test_readwrite", "pyspark.tests.test_serializers", "pyspark.tests.test_shuffle", diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1edffaa4ca16..52ab86c0d88e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2535,6 +2535,20 @@ def func(s, iterator): return f(iterator) return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True) + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + .. note:: Experimental + + Returns a new RDD by applying a function to each partition of the wrapped RDD, while + tracking the index of the original partition. And all tasks are launched together + in a barrier stage. + The interface is the same as :func:`RDD.mapPartitionsWithIndex`. + Please see the API doc there. + + .. versionadded:: 3.0.0 + """ + return PipelinedRDD(self.rdd, f, preservesPartitioning, isFromBarrier=True) + class PipelinedRDD(RDD): diff --git a/python/pyspark/tests/test_rddbarrier.py b/python/pyspark/tests/test_rddbarrier.py new file mode 100644 index 000000000000..8534fb4abb87 --- /dev/null +++ b/python/pyspark/tests/test_rddbarrier.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.testing.utils import ReusedPySparkTestCase + + +class RDDBarrierTests(ReusedPySparkTestCase): + def test_map_partitions(self): + """Test RDDBarrier.mapPartitions""" + rdd = self.sc.parallelize(range(12), 4) + self.assertFalse(rdd._is_barrier()) + + rdd1 = rdd.barrier().mapPartitions(lambda it: it) + self.assertTrue(rdd1._is_barrier()) + + def test_map_partitions_with_index(self): + """Test RDDBarrier.mapPartitionsWithIndex""" + rdd = self.sc.parallelize(range(12), 4) + self.assertFalse(rdd._is_barrier()) + + def f(index, iterator): + yield index + rdd1 = rdd.barrier().mapPartitionsWithIndex(f) + self.assertTrue(rdd1._is_barrier()) + self.assertEqual(rdd1.collect(), [0, 1, 2, 3]) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_rddbarrier import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From df00b5c17d1770078f25f66504043bb3d6514ef7 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Wed, 23 Oct 2019 15:23:25 +0200 Subject: [PATCH 21/58] [SPARK-29569][BUILD][DOCS] Copy and paste minified jquery instead when post-processing badges in JavaDoc ### What changes were proposed in this pull request? This PR fixes our documentation build to copy minified jquery file instead. The original file `jquery.js` seems missing as of Scala 2.12 upgrade. Scala 2.12 seems started to use minified `jquery.min.js` instead. Since we dropped Scala 2.11, we won't have to take care about legacy `jquery.js` anymore. Note that, there seem multiple weird stuff in the current ScalaDoc (e.g., some pages are weird, it starts from `scala.collection.*` or some pages are missing, or some docs are truncated, some badges look missing). It needs a separate double check and investigation. This PR targets to make the documentation generation pass in order to unblock Spark 3.0 preview. ### Why are the changes needed? To fix and make our official documentation build able to run. ### Does this PR introduce any user-facing change? It will enable to build the documentation in our official way. **Before:** ``` Making directory api/scala cp -r ../target/scala-2.12/unidoc/. api/scala Making directory api/java cp -r ../target/javaunidoc/. api/java Updating JavaDoc files for badge post-processing Copying jquery.js from Scala API to Java API for page post-processing of badges jekyll 3.8.6 | Error: No such file or directory rb_sysopen - ./api/scala/lib/jquery.js ``` **After:** ``` Making directory api/scala cp -r ../target/scala-2.12/unidoc/. api/scala Making directory api/java cp -r ../target/javaunidoc/. api/java Updating JavaDoc files for badge post-processing Copying jquery.min.js from Scala API to Java API for page post-processing of badges Copying api_javadocs.js to Java API for page post-processing of badges Appending content of api-javadocs.css to JavaDoc stylesheet.css for badge styles ... ``` ### How was this patch tested? 
Manually tested via: ``` SKIP_PYTHONDOC=1 SKIP_RDOC=1 SKIP_SQLDOC=1 jekyll build ``` Closes #26228 from HyukjinKwon/SPARK-29569. Authored-by: HyukjinKwon Signed-off-by: Xingbo Jiang --- docs/_plugins/copy_api_dirs.rb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 2d1a9547e373..f95e4e2f9779 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -96,9 +96,9 @@ end # End updating JavaDoc files for badge post-processing - puts "Copying jquery.js from Scala API to Java API for page post-processing of badges" - jquery_src_file = "./api/scala/lib/jquery.js" - jquery_dest_file = "./api/java/lib/jquery.js" + puts "Copying jquery.min.js from Scala API to Java API for page post-processing of badges" + jquery_src_file = "./api/scala/lib/jquery.min.js" + jquery_dest_file = "./api/java/lib/jquery.min.js" mkdir_p("./api/java/lib") cp(jquery_src_file, jquery_dest_file) From 53a5f17803851dc232ec3b39242e85b881ade6ef Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 23 Oct 2019 08:26:47 -0700 Subject: [PATCH 22/58] [SPARK-29513][SQL] REFRESH TABLE should look up catalog/table like v2 commands ### What changes were proposed in this pull request? Add RefreshTableStatement and make REFRESH TABLE go through the same catalog/table resolution framework of v2 commands. ### Why are the changes needed? It's important to make all the commands have the same table resolution behavior, to avoid confusing end-users. e.g. ``` USE my_catalog DESC t // success and describe the table t from my_catalog REFRESH TABLE t // report table not found as there is no table t in the session catalog ``` ### Does this PR introduce any user-facing change? yes. When running REFRESH TABLE, Spark fails the command if the current catalog is set to a v2 catalog, or the table name specified a v2 catalog. ### How was this patch tested? New unit tests Closes #26183 from imback82/refresh_table. Lead-authored-by: Terry Kim Co-authored-by: Terry Kim Signed-off-by: Liang-Chi Hsieh --- .../spark/sql/catalyst/parser/SqlBase.g4 | 4 +- .../catalyst/analysis/ResolveCatalogs.scala | 3 + .../sql/catalyst/parser/AstBuilder.scala | 12 ++++ .../catalyst/plans/logical/statements.scala | 8 ++- .../catalyst/plans/logical/v2Commands.scala | 7 +++ .../sql/catalyst/parser/DDLParserSuite.scala | 6 ++ .../sql/connector/InMemoryTableCatalog.scala | 10 ++++ .../analysis/ResolveSessionCatalog.scala | 5 +- .../spark/sql/execution/SparkSqlParser.scala | 7 --- .../datasources/v2/DataSourceV2Strategy.scala | 5 +- .../datasources/v2/RefreshTableExec.scala | 33 +++++++++++ .../sql/connector/DataSourceV2SQLSuite.scala | 59 +++++++++---------- 12 files changed, 115 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 7e5e16b8e32b..970d244071e0 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -201,9 +201,9 @@ statement | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction | (DESC | DESCRIBE) database EXTENDED? db=errorCapturingIdentifier #describeDatabase | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? - multipartIdentifier partitionSpec? 
describeColName? #describeTable + multipartIdentifier partitionSpec? describeColName? #describeTable | (DESC | DESCRIBE) QUERY? query #describeQuery - | REFRESH TABLE tableIdentifier #refreshTable + | REFRESH TABLE multipartIdentifier #refreshTable | REFRESH (STRING | .*?) #refreshResource | CACHE LAZY? TABLE tableIdentifier (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 6553b3d57d7f..9803fda0678f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -137,6 +137,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager) writeOptions = c.options.filterKeys(_ != "path"), ignoreIfExists = c.ifNotExists) + case RefreshTableStatement(NonSessionCatalog(catalog, tableName)) => + RefreshTable(catalog.asTableCatalog, tableName.asIdentifier) + case c @ ReplaceTableStatement( NonSessionCatalog(catalog, tableName), _, _, _, _, _, _, _, _, _) => ReplaceTable( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7c67952aba40..940dfd0fc333 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2799,4 +2799,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val partitionKeys = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec) ShowPartitionsStatement(table, partitionKeys) } + + /** + * Create a [[RefreshTableStatement]]. 
+ * + * For example: + * {{{ + * REFRESH TABLE multi_part_name + * }}} + */ + override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { + RefreshTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier())) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 3bd16187320f..127d9026f802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -340,5 +340,11 @@ case class TruncateTableStatement( /** * A SHOW PARTITIONS statement, as parsed from SQL */ -case class ShowPartitionsStatement(tableName: Seq[String], +case class ShowPartitionsStatement( + tableName: Seq[String], partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement + +/** + * A REFRESH TABLE statement, as parsed from SQL + */ +case class RefreshTableStatement(tableName: Seq[String]) extends ParsedStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 8f5731a4a7a7..d80c1c034a86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -348,3 +348,10 @@ case class SetCatalogAndNamespace( catalogManager: CatalogManager, catalogName: Option[String], namespace: Option[Seq[String]]) extends Command + +/** + * The logical plan of the REFRESH TABLE command that works for v2 catalogs. 
+ */ +case class RefreshTable( + catalog: TableCatalog, + ident: Identifier) extends Command diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 38ef357036a0..8e605bd15f69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1081,6 +1081,12 @@ class DDLParserSuite extends AnalysisTest { comparePlans(parsed5, expected5) } + test("REFRESH TABLE table") { + comparePlans( + parsePlan("REFRESH TABLE a.b.c"), + RefreshTableStatement(Seq("a", "b", "c"))) + } + private case class TableSpec( name: Seq[String], schema: Option[StructType], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala index 8724a38d08d1..ece903a4c283 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala @@ -34,6 +34,8 @@ class BasicInMemoryTableCatalog extends TableCatalog { protected val tables: util.Map[Identifier, InMemoryTable] = new ConcurrentHashMap[Identifier, InMemoryTable]() + private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() + private var _name: Option[String] = None override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { @@ -55,6 +57,10 @@ class BasicInMemoryTableCatalog extends TableCatalog { } } + override def invalidateTable(ident: Identifier): Unit = { + invalidatedTables.add(ident) + } + override def createTable( ident: Identifier, schema: StructType, @@ -104,6 +110,10 @@ class BasicInMemoryTableCatalog extends TableCatalog { } } + def isTableInvalidated(ident: Identifier): Boolean = { + invalidatedTables.contains(ident) + } + def clearTables(): Unit = { tables.clear() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 4cca9846e996..230b8f3906bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand} -import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, RefreshTable} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, 
MetadataBuilder, StructField, StructType} @@ -216,6 +216,9 @@ class ResolveSessionCatalog( ignoreIfExists = c.ifNotExists) } + case RefreshTableStatement(SessionCatalog(_, tableName)) => + RefreshTable(tableName.asTableIdentifier) + // For REPLACE TABLE [AS SELECT], we should fail if the catalog is resolved to the // session catalog and the table provider is not v2. case c @ ReplaceTableStatement( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 38f3c6e1b750..2439621f7725 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -143,13 +143,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ShowCreateTableCommand(table) } - /** - * Create a [[RefreshTable]] logical plan. - */ - override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { - RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) - } - /** * Create a [[RefreshResource]] logical plan. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 49035c3cc3da..4a7cb7db45de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, RefreshTable, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables} import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} @@ -193,6 +193,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { catalog, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil } + case RefreshTable(catalog, ident) => + RefreshTableExec(catalog, ident) :: Nil + case ReplaceTable(catalog, ident, schema, parts, props, orCreate) => catalog match { case staging: StagingTableCatalog => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala new file mode 100644 index 000000000000..2a19ff304a9e --- 
/dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} + +case class RefreshTableExec( + catalog: TableCatalog, + ident: Identifier) extends V2CommandExec { + override protected def run(): Seq[InternalRow] = { + catalog.invalidateTable(ident) + Seq.empty + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 2ea26787dbb1..463147903c92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1112,6 +1112,20 @@ class DataSourceV2SQLSuite } } + test("REFRESH TABLE: v2 table") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string) USING foo") + + val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] + val identifier = Identifier.of(Array("ns1", "ns2"), "tbl") + + assert(!testCatalog.isTableInvalidated(identifier)) + sql(s"REFRESH TABLE $t") + assert(testCatalog.isTableInvalidated(identifier)) + } + } + test("REPLACE TABLE: v1 table") { val e = intercept[AnalysisException] { sql(s"CREATE OR REPLACE TABLE tbl (a int) USING ${classOf[SimpleScanSource].getName}") @@ -1211,16 +1225,8 @@ class DataSourceV2SQLSuite val t = "testcat.ns1.ns2.tbl" withTable(t) { spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - - val e = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $t COMPUTE STATISTICS") - } - assert(e.message.contains("ANALYZE TABLE is only supported with v1 tables")) - - val e2 = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $t COMPUTE STATISTICS FOR ALL COLUMNS") - } - assert(e2.message.contains("ANALYZE TABLE is only supported with v1 tables")) + testV1Command("ANALYZE TABLE", s"$t COMPUTE STATISTICS") + testV1Command("ANALYZE TABLE", s"$t COMPUTE STATISTICS FOR ALL COLUMNS") } } @@ -1228,11 +1234,7 @@ class DataSourceV2SQLSuite val t = "testcat.ns1.ns2.tbl" withTable(t) { spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - - val e = intercept[AnalysisException] { - sql(s"MSCK REPAIR TABLE $t") - } - assert(e.message.contains("MSCK REPAIR TABLE is only supported with v1 tables")) + testV1Command("MSCK REPAIR TABLE", t) } } @@ -1246,15 +1248,8 @@ class 
DataSourceV2SQLSuite |PARTITIONED BY (id) """.stripMargin) - val e1 = intercept[AnalysisException] { - sql(s"TRUNCATE TABLE $t") - } - assert(e1.message.contains("TRUNCATE TABLE is only supported with v1 tables")) - - val e2 = intercept[AnalysisException] { - sql(s"TRUNCATE TABLE $t PARTITION(id='1')") - } - assert(e2.message.contains("TRUNCATE TABLE is only supported with v1 tables")) + testV1Command("TRUNCATE TABLE", t) + testV1Command("TRUNCATE TABLE", s"$t PARTITION(id='1')") } } @@ -1268,16 +1263,16 @@ class DataSourceV2SQLSuite |PARTITIONED BY (id) """.stripMargin) - val e1 = intercept[AnalysisException] { - val partition = sql(s"SHOW PARTITIONS $t") - } - assert(e1.message.contains("SHOW PARTITIONS is only supported with v1 tables")) + testV1Command("SHOW PARTITIONS", t) + testV1Command("SHOW PARTITIONS", s"$t PARTITION(id='1')") + } + } - val e2 = intercept[AnalysisException] { - val partition2 = sql(s"SHOW PARTITIONS $t PARTITION(id='1')") - } - assert(e2.message.contains("SHOW PARTITIONS is only supported with v1 tables")) + private def testV1Command(sqlCommand: String, sqlParams: String): Unit = { + val e = intercept[AnalysisException] { + sql(s"$sqlCommand $sqlParams") } + assert(e.message.contains(s"$sqlCommand is only supported with v1 tables")) } private def assertAnalysisError(sqlStatement: String, expectedError: String): Unit = { From bfbf2821f34afba1c3a8a720084b5421a9de77eb Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 24 Oct 2019 00:41:48 +0800 Subject: [PATCH 23/58] [SPARK-29503][SQL] Remove conversion CreateNamedStruct to CreateNamedStructUnsafe ### What changes were proposed in this pull request? There's a case where MapObjects has a lambda function which creates nested struct - unsafe data in safe data struct. In this case, MapObjects doesn't copy the row returned from lambda function (as outmost data type is safe data struct), which misses copying nested unsafe data. The culprit is that `UnsafeProjection.toUnsafeExprs` converts `CreateNamedStruct` to `CreateNamedStructUnsafe` (this is the only place where `CreateNamedStructUnsafe` is used) which incurs safe and unsafe being mixed up temporarily, which may not be needed at all at least logically, as it will finally assembly these evaluations to `UnsafeRow`. > Before the patch ``` /* 105 */ private ArrayData MapObjects_0(InternalRow i) { /* 106 */ boolean isNull_1 = i.isNullAt(0); /* 107 */ ArrayData value_1 = isNull_1 ? 
/* 108 */ null : (i.getArray(0)); /* 109 */ ArrayData value_0 = null; /* 110 */ /* 111 */ if (!isNull_1) { /* 112 */ /* 113 */ int dataLength_0 = value_1.numElements(); /* 114 */ /* 115 */ ArrayData[] convertedArray_0 = null; /* 116 */ convertedArray_0 = new ArrayData[dataLength_0]; /* 117 */ /* 118 */ /* 119 */ int loopIndex_0 = 0; /* 120 */ /* 121 */ while (loopIndex_0 < dataLength_0) { /* 122 */ value_MapObject_lambda_variable_1 = (int) (value_1.getInt(loopIndex_0)); /* 123 */ isNull_MapObject_lambda_variable_1 = value_1.isNullAt(loopIndex_0); /* 124 */ /* 125 */ ArrayData arrayData_0 = ArrayData.allocateArrayData( /* 126 */ -1, 1L, " createArray failed."); /* 127 */ /* 128 */ mutableStateArray_0[0].reset(); /* 129 */ /* 130 */ /* 131 */ mutableStateArray_0[0].zeroOutNullBytes(); /* 132 */ /* 133 */ /* 134 */ if (isNull_MapObject_lambda_variable_1) { /* 135 */ mutableStateArray_0[0].setNullAt(0); /* 136 */ } else { /* 137 */ mutableStateArray_0[0].write(0, value_MapObject_lambda_variable_1); /* 138 */ } /* 139 */ arrayData_0.update(0, (mutableStateArray_0[0].getRow())); /* 140 */ if (false) { /* 141 */ convertedArray_0[loopIndex_0] = null; /* 142 */ } else { /* 143 */ convertedArray_0[loopIndex_0] = arrayData_0 instanceof UnsafeArrayData? arrayData_0.copy() : arrayData_0; /* 144 */ } /* 145 */ /* 146 */ loopIndex_0 += 1; /* 147 */ } /* 148 */ /* 149 */ value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(convertedArray_0); /* 150 */ } /* 151 */ globalIsNull_0 = isNull_1; /* 152 */ return value_0; /* 153 */ } ``` > After the patch ``` /* 104 */ private ArrayData MapObjects_0(InternalRow i) { /* 105 */ boolean isNull_1 = i.isNullAt(0); /* 106 */ ArrayData value_1 = isNull_1 ? /* 107 */ null : (i.getArray(0)); /* 108 */ ArrayData value_0 = null; /* 109 */ /* 110 */ if (!isNull_1) { /* 111 */ /* 112 */ int dataLength_0 = value_1.numElements(); /* 113 */ /* 114 */ ArrayData[] convertedArray_0 = null; /* 115 */ convertedArray_0 = new ArrayData[dataLength_0]; /* 116 */ /* 117 */ /* 118 */ int loopIndex_0 = 0; /* 119 */ /* 120 */ while (loopIndex_0 < dataLength_0) { /* 121 */ value_MapObject_lambda_variable_1 = (int) (value_1.getInt(loopIndex_0)); /* 122 */ isNull_MapObject_lambda_variable_1 = value_1.isNullAt(loopIndex_0); /* 123 */ /* 124 */ ArrayData arrayData_0 = ArrayData.allocateArrayData( /* 125 */ -1, 1L, " createArray failed."); /* 126 */ /* 127 */ Object[] values_0 = new Object[1]; /* 128 */ /* 129 */ /* 130 */ if (isNull_MapObject_lambda_variable_1) { /* 131 */ values_0[0] = null; /* 132 */ } else { /* 133 */ values_0[0] = value_MapObject_lambda_variable_1; /* 134 */ } /* 135 */ /* 136 */ final InternalRow value_3 = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(values_0); /* 137 */ values_0 = null; /* 138 */ arrayData_0.update(0, value_3); /* 139 */ if (false) { /* 140 */ convertedArray_0[loopIndex_0] = null; /* 141 */ } else { /* 142 */ convertedArray_0[loopIndex_0] = arrayData_0 instanceof UnsafeArrayData? arrayData_0.copy() : arrayData_0; /* 143 */ } /* 144 */ /* 145 */ loopIndex_0 += 1; /* 146 */ } /* 147 */ /* 148 */ value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(convertedArray_0); /* 149 */ } /* 150 */ globalIsNull_0 = isNull_1; /* 151 */ return value_0; /* 152 */ } ``` ### Why are the changes needed? This patch fixes the bug described above. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? UT added which fails on master branch and passes on PR. Closes #26173 from HeartSaVioR/SPARK-29503. 
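For reference, a minimal sketch of the pattern this patch addresses, adapted from the unit test added in this PR. It assumes an active `SparkSession` named `spark` and disables whole-stage codegen so the interpreted `MapObjects` path is exercised; the config key and expected output come from the new test, everything else is illustrative.

```scala
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.functions.{array, struct}
import org.apache.spark.sql.types.ArrayType
import spark.implicits._

// Force the interpreted (non-whole-stage-codegen) path where the missing copy shows up.
spark.conf.set("spark.sql.codegen.wholeStage", "false")

val df = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")

// items: Seq[Int] => items.map { item => Seq(Struct(item)) }
// i.e. the lambda produces an unsafe struct inside a safe (Generic) array.
val result = df.select(
  new Column(MapObjects(
    (item: Expression) => array(struct(new Column(item))).expr,
    $"items".expr,
    df.schema("items").dataType.asInstanceOf[ArrayType].elementType
  )) as "items"
).collect()

// Expected after the fix: Row(Seq(Seq(Row(1)), Seq(Row(2)), Seq(Row(3)))).
// Without the fix the nested rows can come back corrupted, because the reused
// unsafe buffer returned by the lambda is never copied.
```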
Authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/Projection.scala | 8 +--- .../expressions/complexTypeCreator.scala | 48 +++++-------------- .../sql/catalyst/optimizer/ComplexTypes.scala | 6 +-- .../optimizer/NormalizeFloatingNumbers.scala | 5 +- .../sql/catalyst/optimizer/expressions.scala | 4 +- .../expressions/ComplexTypeSuite.scala | 1 - .../scala/org/apache/spark/sql/Column.scala | 2 +- .../spark/sql/DataFrameComplexTypeSuite.scala | 22 +++++++++ 8 files changed, 43 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index eaaf94baac21..300f075d3276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -127,12 +127,6 @@ object UnsafeProjection InterpretedUnsafeProjection.createProjection(in) } - protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { - exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - } - /** * Returns an UnsafeProjection for given StructType. * @@ -153,7 +147,7 @@ object UnsafeProjection * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - createObject(toUnsafeExprs(exprs)) + createObject(exprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index cae3c0528e13..3f722e8537c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -295,9 +295,20 @@ object CreateStruct extends FunctionBuilder { } /** - * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]]. + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) */ -trait CreateNamedStructLike extends Expression { +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.", + examples = """ + Examples: + > SELECT _FUNC_("a", 1, "b", 2, "c", 3); + {"a":1,"b":2,"c":3} + """) +// scalastyle:on line.size.limit +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -348,23 +359,6 @@ trait CreateNamedStructLike extends Expression { override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } -} - -/** - * Creates a struct with the given field names and values - * - * @param children Seq(name1, val1, name2, val2, ...) - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.", - examples = """ - Examples: - > SELECT _FUNC_("a", 1, "b", 2, "c", 3); - {"a":1,"b":2,"c":3} - """) -// scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName @@ -397,22 +391,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def prettyName: String = "named_struct" } -/** - * Creates a struct with the given field names and values. This is a variant that returns - * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with - * this expression automatically at runtime. - * - * @param children Seq(name1, val1, name2, val2, ...) - */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) - } - - override def prettyName: String = "named_struct_unsafe" -} - /** * Creates a map after splitting the input text into key/value pairs using delimiters */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index db7d6d3254bd..1743565ccb6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** - * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions. + * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions. */ object SimplifyExtractValueOps extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -37,8 +37,8 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { case a: Aggregate => a case p => p.transformExpressionsUp { // Remove redundant field extraction. - case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) => - createNamedStructLike.valExprs(ordinal) + case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => + createNamedStruct.valExprs(ordinal) // Remove redundant array indexing. 
case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index b036092cf1fc..ea01d9e63eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} @@ -114,9 +114,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case CreateNamedStruct(children) => CreateNamedStruct(children.map(normalize)) - case CreateNamedStructUnsafe(children) => - CreateNamedStructUnsafe(children.map(normalize)) - case CreateArray(children) => CreateArray(children.map(normalize)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 0a6737ba4211..36ad796c08a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -227,8 +227,8 @@ object OptimizeIn extends Rule[LogicalPlan] { if (newList.length == 1 // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, // TODO: we exclude them in this rule. 
- && !v.isInstanceOf[CreateNamedStructLike] - && !newList.head.isInstanceOf[CreateNamedStructLike]) { + && !v.isInstanceOf[CreateNamedStruct] + && !newList.head.isInstanceOf[CreateNamedStruct]) { EqualTo(v, newList.head) } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 0c4438987cd2..9039cd645159 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -369,7 +369,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) - checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } test("StringToMap") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7b903a3f7f14..ed10843b0859 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -200,7 +200,7 @@ class Column(val expr: Expression) extends Logging { UnresolvedAlias(a, Some(Column.generateAlias)) // Wait until the struct is resolved. This will generate a nicer looking alias. - case struct: CreateNamedStructLike => UnresolvedAlias(struct) + case struct: CreateNamedStruct => UnresolvedAlias(struct) case expr: Expression => Alias(expr, toPrettySQL(expr))() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index e9179a39d3b6..4f2564290662 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -18,8 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.DefinedByConstructorParams +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.objects.MapObjects import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.ArrayType /** * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). 
@@ -64,6 +68,24 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { val ds100_5 = Seq(S100_5()).toDS() ds100_5.rdd.count } + + test("SPARK-29503 nest unsafe struct inside safe array") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val df = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items") + + // items: Seq[Int] => items.map { item => Seq(Struct(item)) } + val result = df.select( + new Column(MapObjects( + (item: Expression) => array(struct(new Column(item))).expr, + $"items".expr, + df.schema("items").dataType.asInstanceOf[ArrayType].elementType + )) as "items" + ).collect() + + assert(result.size === 1) + assert(result === Row(Seq(Seq(Row(1)), Seq(Row(2)), Seq(Row(3)))) :: Nil) + } + } } class S100( From 7e8e4c0a146ef071808a611e256ab049b396212a Mon Sep 17 00:00:00 2001 From: jiake Date: Thu, 24 Oct 2019 01:18:07 +0800 Subject: [PATCH 24/58] [SPARK-29552][SQL] Execute the "OptimizeLocalShuffleReader" rule when creating new query stage and then can optimize the shuffle reader to local shuffle reader as much as possible ### What changes were proposed in this pull request? `OptimizeLocalShuffleReader` rule is very conservative and gives up optimization as long as there are extra shuffles introduced. It's very likely that most of the added local shuffle readers are fine and only one introduces extra shuffle. However, it's very hard to make `OptimizeLocalShuffleReader` optimal, a simple workaround is to run this rule again right before executing a query stage. ### Why are the changes needed? Optimize more shuffle reader to local shuffle reader. ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing ut Closes #26207 from JkSelf/resolve-multi-joins-issue. Authored-by: jiake Signed-off-by: Wenchen Fan --- .../execution/adaptive/AdaptiveSparkPlanExec.scala | 9 +++++++++ .../execution/adaptive/AdaptiveQueryExecSuite.scala | 11 +++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index f45e3560b2cf..f01947d8f5ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -92,6 +92,15 @@ case class AdaptiveSparkPlanExec( // optimizations should be stage-independent. @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq( ReuseAdaptiveSubquery(conf, subqueryCache), + + // When adding local shuffle readers in 'OptimizeLocalShuffleReader`, we revert all the local + // readers if additional shuffles are introduced. This may be too conservative: maybe there is + // only one local reader that introduces shuffle, and we can still keep other local readers. + // Here we re-execute this rule with the sub-plan-tree of a query stage, to make sure necessary + // local readers are added before executing the query stage. + // This rule must be executed before `ReduceNumShufflePartitions`, as local shuffle readers + // can't change number of partitions. 
+ OptimizeLocalShuffleReader(conf), ReduceNumShufflePartitions(conf), ApplyColumnarRulesAndInsertTransitions(session.sessionState.conf, session.sessionState.columnarRules), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 43802968c469..649467a27d93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -163,8 +163,9 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) - // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader. - checkNumLocalShuffleReaders(adaptivePlan, 1) + // The child of remaining one BroadcastHashJoin is not ShuffleQueryStage. + // So only two LocalShuffleReader. + checkNumLocalShuffleReaders(adaptivePlan, 2) } } @@ -188,7 +189,8 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) - // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader. + // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage. + // So only two LocalShuffleReader. checkNumLocalShuffleReaders(adaptivePlan, 1) } } @@ -213,7 +215,8 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) - // additional shuffle exchange introduced, only one shuffle reader to local shuffle reader. + // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage. + // So only two LocalShuffleReader. checkNumLocalShuffleReaders(adaptivePlan, 1) } } From 5867707835a442827c9df17f79c28c86eb9f2c68 Mon Sep 17 00:00:00 2001 From: Luca Canali Date: Wed, 23 Oct 2019 10:45:11 -0700 Subject: [PATCH 25/58] [SPARK-29557][BUILD] Update dropwizard/codahale metrics library to 3.2.6 ### What changes were proposed in this pull request? This proposes to update the dropwizard/codahale metrics library version used by Spark to `3.2.6` which is the last version supporting Ganglia. ### Why are the changes needed? Spark is currently using Dropwizard metrics version 3.1.5, a version that is no more actively developed nor maintained, according to the project's Github repo README. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing tests + manual tests on a YARN cluster. Closes #26212 from LucaCanali/updateDropwizardVersion. 
Authored-by: Luca Canali Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- dev/deps/spark-deps-hadoop-3.2 | 8 ++++---- pom.xml | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 26dc6e7bd8bf..f21e76bf4331 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -149,10 +149,10 @@ lz4-java-1.6.0.jar machinist_2.12-0.6.8.jar macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar -metrics-core-3.1.5.jar -metrics-graphite-3.1.5.jar -metrics-json-3.1.5.jar -metrics-jvm-3.1.5.jar +metrics-core-3.2.6.jar +metrics-graphite-3.2.6.jar +metrics-json-3.2.6.jar +metrics-jvm-3.2.6.jar minlog-1.3.0.jar netty-all-4.1.42.Final.jar objenesis-2.5.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.2 b/dev/deps/spark-deps-hadoop-3.2 index a92b7124cb4a..3ecc3c2b0d35 100644 --- a/dev/deps/spark-deps-hadoop-3.2 +++ b/dev/deps/spark-deps-hadoop-3.2 @@ -179,10 +179,10 @@ lz4-java-1.6.0.jar machinist_2.12-0.6.8.jar macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar -metrics-core-3.1.5.jar -metrics-graphite-3.1.5.jar -metrics-json-3.1.5.jar -metrics-jvm-3.1.5.jar +metrics-core-3.2.6.jar +metrics-graphite-3.2.6.jar +metrics-json-3.2.6.jar +metrics-jvm-3.2.6.jar minlog-1.3.0.jar mssql-jdbc-6.2.1.jre7.jar netty-all-4.1.42.Final.jar diff --git a/pom.xml b/pom.xml index 69b5b79b7b07..c42ef5c6626d 100644 --- a/pom.xml +++ b/pom.xml @@ -148,7 +148,7 @@ 0.9.3 2.4.0 2.0.8 - 3.1.5 + 3.2.6 1.8.2 hadoop2 1.8.10 From b91356e4c2a5c4a2e77c78a05a93a9d3979f1fce Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 23 Oct 2019 21:41:05 +0000 Subject: [PATCH 26/58] [SPARK-29533][SQL][TESTS][FOLLOWUP] Regenerate the result on EC2 ### What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/26189 to regenerate the result on EC2. ### Why are the changes needed? This will be used for the other PR reviews. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? N/A. Closes #26233 from dongjoon-hyun/SPARK-29533. 
Authored-by: Dongjoon Hyun Signed-off-by: DB Tsai --- .../IntervalBenchmark-jdk11-results.txt | 44 +++++++++---------- .../benchmarks/IntervalBenchmark-results.txt | 44 +++++++++---------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt b/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt index 2a3903200a8a..6605bd2c13b4 100644 --- a/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt @@ -1,25 +1,25 @@ -OpenJDK 64-Bit Server VM 11.0.2+9 on Mac OS X 10.15 -Intel(R) Core(TM) i7-4850HQ CPU @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1044-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast strings to intervals: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -string w/ interval 471 513 57 2.1 470.7 1.0X -string w/o interval 437 444 8 2.3 436.9 1.1X -1 units w/ interval 726 758 45 1.4 726.3 0.6X -1 units w/o interval 712 717 5 1.4 711.7 0.7X -2 units w/ interval 926 935 12 1.1 925.9 0.5X -2 units w/o interval 943 947 3 1.1 943.4 0.5X -3 units w/ interval 1089 1116 31 0.9 1089.0 0.4X -3 units w/o interval 1105 1108 3 0.9 1105.1 0.4X -4 units w/ interval 1260 1261 1 0.8 1260.4 0.4X -4 units w/o interval 1276 1277 1 0.8 1275.9 0.4X -5 units w/ interval 1436 1445 11 0.7 1435.6 0.3X -5 units w/o interval 1455 1463 6 0.7 1455.5 0.3X -6 units w/ interval 1634 1639 4 0.6 1634.4 0.3X -6 units w/o interval 1642 1644 3 0.6 1641.7 0.3X -7 units w/ interval 1829 1838 8 0.5 1828.6 0.3X -7 units w/o interval 1850 1853 4 0.5 1849.5 0.3X -8 units w/ interval 2065 2070 5 0.5 2065.4 0.2X -8 units w/o interval 2070 2090 21 0.5 2070.0 0.2X -9 units w/ interval 2279 2290 10 0.4 2278.7 0.2X -9 units w/o interval 2276 2285 8 0.4 2275.7 0.2X +string w/ interval 663 758 99 1.5 663.2 1.0X +string w/o interval 563 582 19 1.8 563.2 1.2X +1 units w/ interval 891 953 97 1.1 891.2 0.7X +1 units w/o interval 894 905 15 1.1 893.6 0.7X +2 units w/ interval 1142 1169 24 0.9 1141.8 0.6X +2 units w/o interval 1195 1200 7 0.8 1194.7 0.6X +3 units w/ interval 1449 1452 3 0.7 1449.5 0.5X +3 units w/o interval 1489 1491 1 0.7 1489.3 0.4X +4 units w/ interval 1688 1690 1 0.6 1688.4 0.4X +4 units w/o interval 1711 1722 11 0.6 1710.9 0.4X +5 units w/ interval 1961 1983 23 0.5 1961.3 0.3X +5 units w/o interval 1983 1991 10 0.5 1983.4 0.3X +6 units w/ interval 2217 2228 11 0.5 2216.9 0.3X +6 units w/o interval 2240 2244 5 0.4 2239.6 0.3X +7 units w/ interval 2454 2468 16 0.4 2454.1 0.3X +7 units w/o interval 2480 2491 15 0.4 2479.5 0.3X +8 units w/ interval 2762 2792 26 0.4 2761.9 0.2X +8 units w/o interval 2763 2778 14 0.4 2762.9 0.2X +9 units w/ interval 3036 3060 21 0.3 3036.4 0.2X +9 units w/o interval 3095 3111 15 0.3 3094.8 0.2X diff --git a/sql/core/benchmarks/IntervalBenchmark-results.txt b/sql/core/benchmarks/IntervalBenchmark-results.txt index 9010b980c07b..40169826cc62 100644 --- a/sql/core/benchmarks/IntervalBenchmark-results.txt +++ b/sql/core/benchmarks/IntervalBenchmark-results.txt @@ -1,25 +1,25 @@ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_202-b08 on Mac OS X 10.15 -Intel(R) Core(TM) i7-4850HQ CPU @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.0-1044-aws +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast strings to intervals: Best Time(ms) Avg Time(ms) Stdev(ms) 
Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -string w/ interval 420 435 18 2.4 419.8 1.0X -string w/o interval 359 365 10 2.8 358.7 1.2X -1 units w/ interval 752 759 8 1.3 752.0 0.6X -1 units w/o interval 762 766 4 1.3 762.0 0.6X -2 units w/ interval 961 970 8 1.0 960.7 0.4X -2 units w/o interval 970 976 9 1.0 970.2 0.4X -3 units w/ interval 1130 1136 7 0.9 1130.4 0.4X -3 units w/o interval 1150 1158 9 0.9 1150.3 0.4X -4 units w/ interval 1333 1336 3 0.7 1333.5 0.3X -4 units w/o interval 1354 1359 4 0.7 1354.5 0.3X -5 units w/ interval 1523 1525 2 0.7 1523.3 0.3X -5 units w/o interval 1549 1551 3 0.6 1549.4 0.3X -6 units w/ interval 1661 1663 2 0.6 1660.8 0.3X -6 units w/o interval 1691 1704 13 0.6 1691.2 0.2X -7 units w/ interval 1811 1817 8 0.6 1810.6 0.2X -7 units w/o interval 1853 1854 1 0.5 1853.2 0.2X -8 units w/ interval 2029 2037 8 0.5 2028.7 0.2X -8 units w/o interval 2075 2075 1 0.5 2074.5 0.2X -9 units w/ interval 2170 2175 5 0.5 2170.0 0.2X -9 units w/o interval 2204 2212 8 0.5 2203.6 0.2X +string w/ interval 600 641 37 1.7 600.3 1.0X +string w/o interval 536 544 12 1.9 536.4 1.1X +1 units w/ interval 1017 1027 9 1.0 1016.8 0.6X +1 units w/o interval 1055 1073 16 0.9 1054.7 0.6X +2 units w/ interval 1272 1292 29 0.8 1272.2 0.5X +2 units w/o interval 1309 1314 9 0.8 1309.0 0.5X +3 units w/ interval 1545 1566 20 0.6 1544.8 0.4X +3 units w/o interval 1606 1610 5 0.6 1605.8 0.4X +4 units w/ interval 1820 1826 6 0.5 1819.7 0.3X +4 units w/o interval 1882 1885 3 0.5 1881.5 0.3X +5 units w/ interval 2039 2043 7 0.5 2038.9 0.3X +5 units w/o interval 2131 2133 3 0.5 2130.6 0.3X +6 units w/ interval 2269 2272 4 0.4 2269.5 0.3X +6 units w/o interval 2327 2333 6 0.4 2327.2 0.3X +7 units w/ interval 2477 2485 10 0.4 2476.8 0.2X +7 units w/o interval 2536 2538 3 0.4 2536.0 0.2X +8 units w/ interval 2764 2781 27 0.4 2763.8 0.2X +8 units w/o interval 2843 2847 5 0.4 2842.9 0.2X +9 units w/ interval 2983 2997 12 0.3 2982.5 0.2X +9 units w/o interval 3071 3072 1 0.3 3071.1 0.2X From 7ecf968527a63bc5bd2397ed04f1149dd07821ca Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 23 Oct 2019 16:44:10 -0700 Subject: [PATCH 27/58] [SPARK-29567][TESTS] Update JDBC Integration Test Docker Images ### What changes were proposed in this pull request? This PR updates JDBC Integration Test DBMS Docker Images. | DBMS | Docker Image Tag | Release | | ------ | ------------------ | ------ | | MySQL | mysql:5.7.28 | Oct 13, 2019 | | PostgreSQL | postgres:12.0-alpine | Oct 3, 2019 | * For `MySQL`, `SET GLOBAL sql_mode = ''` is added to disable all strict modes because `test("Basic write test")` creates a table like the following. The latest MySQL rejects `0000-00-00 00:00:00` as TIMESTAMP and causes the test case failure. ``` mysql> desc datescopy; +-------+-----------+------+-----+---------------------+-----------------------------+ | Field | Type | Null | Key | Default | Extra | +-------+-----------+------+-----+---------------------+-----------------------------+ | d | date | YES | | NULL | | | t | timestamp | NO | | CURRENT_TIMESTAMP | on update CURRENT_TIMESTAMP | | dt | timestamp | NO | | 0000-00-00 00:00:00 | | | ts | timestamp | NO | | 0000-00-00 00:00:00 | | | yr | date | YES | | NULL | | +-------+-----------+------+-----+---------------------+-----------------------------+ ``` * For `PostgreSQL`, I chose the smallest image in `12` releases. It reduces the image size a lot, `312MB` -> `72.8MB`. 
This is good for CI/CI testing environment. ``` $ docker images | grep postgres postgres 12.0-alpine 5b681acb1cfc 2 days ago 72.8MB postgres 11.4 53912975086f 3 months ago 312MB ``` Note that - For `MsSqlServer`, we are using `2017-GA-ubuntu` and the next version `2019-CTP3.2-ubuntu` is still `Community Technology Preview` status. - For `DB2` and `Oracle`, the official images are not available. ### Why are the changes needed? This is to make it sure we are testing with the latest DBMS images during preparing `3.0.0`. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Since this is the integration test, we need to run this manually. ``` build/mvn install -DskipTests build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.12 test ``` Closes #26224 from dongjoon-hyun/SPARK-29567. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala | 4 +++- .../org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 9cd5c4ec41a5..bba1b5275269 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = "mysql:5.7.9" + override val imageName = "mysql:5.7.28" override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) @@ -39,6 +39,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { } override def dataPreparation(conn: Connection): Unit = { + // Since MySQL 5.7.14+, we need to disable strict mode + conn.prepareStatement("SET GLOBAL sql_mode = ''").executeUpdate() conn.prepareStatement("CREATE DATABASE foo").executeUpdate() conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate() conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 89da9a1de6f7..599f00def075 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = "postgres:11.4" + override val imageName = "postgres:12.0-alpine" override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) From fd899d6331f4777a36a3f2a79a6b2fa123dccc1a Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 23 Oct 2019 18:17:37 -0700 Subject: [PATCH 28/58] [SPARK-29576][CORE] Use Spark's CompressionCodec for Ser/Deser of MapOutputStatus ### What changes were proposed in this pull request? 
Instead of using ZStd codec directly, we use Spark's CompressionCodec which wraps ZStd codec in a buffered stream to avoid overhead excessive of JNI call while trying to compress/decompress small amount of data. Also, by using Spark's CompressionCodec, we can easily to make it configurable in the future if it's needed. ### Why are the changes needed? Faster performance. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests. Closes #26235 from dbtsai/optimizeDeser. Lead-authored-by: DB Tsai Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- ...tatusesSerDeserBenchmark-jdk11-results.txt | 38 +++++----- .../MapStatusesSerDeserBenchmark-results.txt | 38 +++++----- .../org/apache/spark/MapOutputTracker.scala | 70 ++++++++----------- .../spark/MapStatusesSerDeserBenchmark.scala | 9 +-- 4 files changed, 72 insertions(+), 83 deletions(-) diff --git a/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt b/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt index 7a6cfb7b23b9..db23cf5c12ea 100644 --- a/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt +++ b/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt @@ -2,10 +2,10 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 10 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 205 213 13 1.0 1023.6 1.0X -Deserialization 908 939 27 0.2 4540.2 0.2X +Serialization 170 178 9 1.2 849.7 1.0X +Deserialization 530 535 9 0.4 2651.1 0.3X -Compressed Serialized MapStatus sizes: 400 bytes +Compressed Serialized MapStatus sizes: 411 bytes Compressed Serialized Broadcast MapStatus sizes: 2 MB @@ -13,8 +13,8 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 10 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 195 204 24 1.0 976.9 1.0X -Deserialization 913 940 33 0.2 4566.7 0.2X +Serialization 157 165 7 1.3 785.4 1.0X +Deserialization 495 588 79 0.4 2476.7 0.3X Compressed Serialized MapStatus sizes: 2 MB Compressed Serialized Broadcast MapStatus sizes: 0 bytes @@ -24,21 +24,21 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 100 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 616 619 3 0.3 3079.1 1.0X -Deserialization 936 954 22 0.2 4680.5 0.7X +Serialization 344 351 4 0.6 1720.4 1.0X +Deserialization 527 579 99 0.4 2635.9 0.7X -Compressed Serialized MapStatus sizes: 418 bytes -Compressed Serialized Broadcast MapStatus sizes: 14 MB +Compressed Serialized MapStatus sizes: 427 bytes +Compressed Serialized Broadcast MapStatus sizes: 13 MB OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1044-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 100 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 586 588 3 0.3 2928.8 1.0X -Deserialization 929 933 4 0.2 4647.0 0.6X +Serialization 317 321 4 0.6 1583.8 1.0X +Deserialization 530 540 15 0.4 2648.3 0.6X -Compressed Serialized MapStatus sizes: 14 MB +Compressed Serialized MapStatus sizes: 13 MB Compressed Serialized Broadcast MapStatus sizes: 0 bytes @@ -46,21 +46,21 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 1000 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 4740 4916 249 0.0 23698.5 1.0X -Deserialization 1578 1597 27 0.1 7890.6 3.0X +Serialization 1738 1849 156 0.1 8692.0 1.0X +Deserialization 946 977 33 0.2 4730.2 1.8X -Compressed Serialized MapStatus sizes: 546 bytes -Compressed Serialized Broadcast MapStatus sizes: 123 MB +Compressed Serialized MapStatus sizes: 556 bytes +Compressed Serialized Broadcast MapStatus sizes: 121 MB OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1044-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 1000 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 4492 4573 115 0.0 22458.3 1.0X -Deserialization 1533 1547 20 0.1 7664.8 2.9X +Serialization 1379 1432 76 0.1 6892.6 1.0X +Deserialization 929 941 19 0.2 4645.5 1.5X -Compressed Serialized MapStatus sizes: 123 MB +Compressed Serialized MapStatus sizes: 121 MB Compressed Serialized Broadcast MapStatus sizes: 0 bytes diff --git a/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt b/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt index 0c649694f6b6..053f4bf77192 100644 --- a/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt +++ b/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt @@ -2,10 +2,10 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15. Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 10 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 236 245 18 0.8 1179.1 1.0X -Deserialization 842 885 37 0.2 4211.4 0.3X +Serialization 178 187 15 1.1 887.5 1.0X +Deserialization 530 558 32 0.4 2647.5 0.3X -Compressed Serialized MapStatus sizes: 400 bytes +Compressed Serialized MapStatus sizes: 411 bytes Compressed Serialized Broadcast MapStatus sizes: 2 MB @@ -13,8 +13,8 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15. 
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 10 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 213 219 8 0.9 1065.1 1.0X -Deserialization 846 870 33 0.2 4228.6 0.3X +Serialization 167 175 7 1.2 835.7 1.0X +Deserialization 523 537 22 0.4 2616.2 0.3X Compressed Serialized MapStatus sizes: 2 MB Compressed Serialized Broadcast MapStatus sizes: 0 bytes @@ -24,21 +24,21 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15. Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 100 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 624 709 167 0.3 3121.1 1.0X -Deserialization 885 908 22 0.2 4427.0 0.7X +Serialization 351 416 147 0.6 1754.4 1.0X +Deserialization 546 551 8 0.4 2727.6 0.6X -Compressed Serialized MapStatus sizes: 418 bytes -Compressed Serialized Broadcast MapStatus sizes: 14 MB +Compressed Serialized MapStatus sizes: 427 bytes +Compressed Serialized Broadcast MapStatus sizes: 13 MB OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.0-1044-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 100 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 603 604 2 0.3 3014.9 1.0X -Deserialization 892 895 5 0.2 4458.7 0.7X +Serialization 320 321 1 0.6 1598.0 1.0X +Deserialization 542 549 7 0.4 2709.0 0.6X -Compressed Serialized MapStatus sizes: 14 MB +Compressed Serialized MapStatus sizes: 13 MB Compressed Serialized Broadcast MapStatus sizes: 0 bytes @@ -46,21 +46,21 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15. 
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 1000 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 4612 4945 471 0.0 23061.0 1.0X -Deserialization 1493 1495 2 0.1 7466.3 3.1X +Serialization 1671 1877 290 0.1 8357.3 1.0X +Deserialization 943 970 32 0.2 4715.8 1.8X -Compressed Serialized MapStatus sizes: 546 bytes -Compressed Serialized Broadcast MapStatus sizes: 123 MB +Compressed Serialized MapStatus sizes: 556 bytes +Compressed Serialized Broadcast MapStatus sizes: 121 MB OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.0-1044-aws Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 200000 MapOutputs, 1000 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Serialization 4452 4595 202 0.0 22261.4 1.0X -Deserialization 1464 1477 18 0.1 7321.4 3.0X +Serialization 1373 1436 89 0.1 6865.0 1.0X +Deserialization 940 970 37 0.2 4699.1 1.5X -Compressed Serialized MapStatus sizes: 123 MB +Compressed Serialized MapStatus sizes: 121 MB Compressed Serialized Broadcast MapStatus sizes: 0 bytes diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6f4a6239a09e..873efa76468e 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -28,13 +28,12 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal -import com.github.luben.zstd.ZstdInputStream -import com.github.luben.zstd.ZstdOutputStream import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream} import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus} import org.apache.spark.shuffle.MetadataFetchFailedException @@ -195,7 +194,8 @@ private class ShuffleStatus(numPartitions: Int) { def serializedMapStatus( broadcastManager: BroadcastManager, isLocal: Boolean, - minBroadcastSize: Int): Array[Byte] = { + minBroadcastSize: Int, + conf: SparkConf): Array[Byte] = { var result: Array[Byte] = null withReadLock { @@ -207,7 +207,7 @@ private class ShuffleStatus(numPartitions: Int) { if (result == null) withWriteLock { if (cachedSerializedMapStatus == null) { val serResult = MapOutputTracker.serializeMapStatuses( - mapStatuses, broadcastManager, isLocal, minBroadcastSize) + mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) cachedSerializedMapStatus = serResult._1 cachedSerializedBroadcast = serResult._2 } @@ -450,7 +450,8 @@ private[spark] class MapOutputTrackerMaster( " to " + hostPort) val shuffleStatus = shuffleStatuses.get(shuffleId).head context.reply( - shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, + conf)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -799,7 +800,7 @@ private[spark] class 
MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) + val statuses = getStatuses(shuffleId, conf) try { MapOutputTracker.convertMapStatuses( shuffleId, startPartition, endPartition, statuses) @@ -818,7 +819,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" + s"partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) + val statuses = getStatuses(shuffleId, conf) try { MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses, Some(mapIndex)) @@ -836,7 +837,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr * * (It would be nice to remove this restriction in the future.) */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { + private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") @@ -846,7 +847,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr if (fetchedStatuses == null) { logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) } @@ -890,16 +891,20 @@ private[spark] object MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager, - isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = { + def serializeMapStatuses( + statuses: Array[MapStatus], + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): (Array[Byte], Broadcast[Array[Byte]]) = { // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard one // This implementation doesn't reallocate the whole memory block but allocates // additional buffers. This way no buffers need to be garbage collected and // the contents don't have to be copied to the new buffer. 
val out = new ApacheByteArrayOutputStream() - val compressedOut = new ApacheByteArrayOutputStream() - - val objOut = new ObjectOutputStream(out) + out.write(DIRECT) + val codec = CompressionCodec.createCodec(conf, "zstd") + val objOut = new ObjectOutputStream(codec.compressedOutputStream(out)) Utils.tryWithSafeFinally { // Since statuses can be modified in parallel, sync on it statuses.synchronized { @@ -908,42 +913,21 @@ private[spark] object MapOutputTracker extends Logging { } { objOut.close() } - - val arr: Array[Byte] = { - val zos = new ZstdOutputStream(compressedOut) - Utils.tryWithSafeFinally { - compressedOut.write(DIRECT) - // `out.writeTo(zos)` will write the uncompressed data from `out` to `zos` - // without copying to avoid unnecessary allocation and copy of byte[]. - out.writeTo(zos) - } { - zos.close() - } - compressedOut.toByteArray - } + val arr = out.toByteArray if (arr.length >= minBroadcastSize) { // Use broadcast instead. // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! val bcast = broadcastManager.newBroadcast(arr, isLocal) // toByteArray creates copy, so we can reuse out out.reset() - val oos = new ObjectOutputStream(out) + out.write(BROADCAST) + val oos = new ObjectOutputStream(codec.compressedOutputStream(out)) Utils.tryWithSafeFinally { oos.writeObject(bcast) } { oos.close() } - val outArr = { - compressedOut.reset() - val zos = new ZstdOutputStream(compressedOut) - Utils.tryWithSafeFinally { - compressedOut.write(BROADCAST) - out.writeTo(zos) - } { - zos.close() - } - compressedOut.toByteArray - } + val outArr = out.toByteArray logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) (outArr, bcast) } else { @@ -952,11 +936,15 @@ private[spark] object MapOutputTracker extends Logging { } // Opposite of serializeMapStatuses. 
- def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { + def deserializeMapStatuses(bytes: Array[Byte], conf: SparkConf): Array[MapStatus] = { assert (bytes.length > 0) def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { - val objIn = new ObjectInputStream(new ZstdInputStream( + val codec = CompressionCodec.createCodec(conf, "zstd") + // The ZStd codec is wrapped in a `BufferedInputStream` which avoids overhead excessive + // of JNI call while trying to decompress small amount of data for each element + // of `MapStatuses` + val objIn = new ObjectInputStream(codec.compressedInputStream( new ByteArrayInputStream(arr, off, len))) Utils.tryWithSafeFinally { objIn.readObject() diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala index 53afe141981f..5dbef88e73a9 100644 --- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala @@ -67,19 +67,20 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { var serializedBroadcastSizes = 0 val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeMapStatuses( - shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize) + shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize, + sc.getConf) serializedMapStatusSizes = serializedMapStatus.length if (serializedBroadcast != null) { serializedBroadcastSizes = serializedBroadcast.value.length } benchmark.addCase("Serialization") { _ => - MapOutputTracker.serializeMapStatuses( - shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize) + MapOutputTracker.serializeMapStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager, + tracker.isLocal, minBroadcastSize, sc.getConf) } benchmark.addCase("Deserialization") { _ => - val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus) + val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus, sc.getConf) assert(result.length == numMaps) } From 55ced9c148a0c47134f3218ebd2f7fb5bea944eb Mon Sep 17 00:00:00 2001 From: 07ARB Date: Thu, 24 Oct 2019 15:57:16 +0900 Subject: [PATCH 29/58] [SPARK-29571][SQL][TESTS][FOLLOWUP] Fix UT in AllExecutionsPageSuite ### What changes were proposed in this pull request? This is a follow-up of #24052 to correct assert condition. ### Why are the changes needed? To test IllegalArgumentException condition.. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manual Test (during fixing of SPARK-29453 find this issue) Closes #26234 from 07ARB/SPARK-29571. 
Authored-by: 07ARB Signed-off-by: HyukjinKwon --- .../apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala index 9e42056c19a0..298afa880c93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala @@ -73,7 +73,7 @@ class AllExecutionsPageSuite extends SharedSparkSession with BeforeAndAfter { map.put("failed.sort", Array("duration")) when(request.getParameterMap()).thenReturn(map) val html = renderSQLPage(request, tab, statusStore).toString().toLowerCase(Locale.ROOT) - assert(!html.contains("IllegalArgumentException")) + assert(!html.contains("illegalargumentexception")) assert(html.contains("duration")) } From 177bf672e47977cbb6ccfd88f3ec77687c1fdebe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 24 Oct 2019 15:00:21 +0800 Subject: [PATCH 30/58] [SPARK-29522][SQL] CACHE TABLE should look up catalog/table like v2 commands ### What changes were proposed in this pull request? Add CacheTableStatement and make CACHE TABLE go through the same catalog/table resolution framework of v2 commands. ### Why are the changes needed? It's important to make all the commands have the same table resolution behavior, to avoid confusing end-users. e.g. ``` USE my_catalog DESC t // success and describe the table t from my_catalog CACHE TABLE t // report table not found as there is no table t in the session catalog ``` ### Does this PR introduce any user-facing change? yes. When running CACHE TABLE, Spark fails the command if the current catalog is set to a v2 catalog, or the table name specified a v2 catalog. ### How was this patch tested? Unit tests. Closes #26179 from viirya/SPARK-29522. Lead-authored-by: Liang-Chi Hsieh Co-authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 24 +++++++++++++++++++ .../catalyst/plans/logical/statements.scala | 9 +++++++ .../sql/catalyst/parser/DDLParserSuite.scala | 17 +++++++++++++ .../analysis/ResolveSessionCatalog.scala | 6 ++++- .../spark/sql/execution/SparkSqlParser.scala | 15 ------------ .../sql/connector/DataSourceV2SQLSuite.scala | 14 +++++++++++ .../spark/sql/hive/CachedTableSuite.scala | 2 +- 8 files changed, 71 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 970d244071e0..01cd181010f9 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -205,7 +205,7 @@ statement | (DESC | DESCRIBE) QUERY? query #describeQuery | REFRESH TABLE multipartIdentifier #refreshTable | REFRESH (STRING | .*?) #refreshResource - | CACHE LAZY? TABLE tableIdentifier + | CACHE LAZY? TABLE multipartIdentifier (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable | UNCACHE TABLE (IF EXISTS)? 
tableIdentifier #uncacheTable | CLEAR CACHE #clearCache diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 940dfd0fc333..99e5c9feb8fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2769,6 +2769,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging RepairTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier())) } + /** + * Create a [[CacheTableStatement]]. + * + * For example: + * {{{ + * CACHE [LAZY] TABLE multi_part_name + * [OPTIONS tablePropertyList] [[AS] query] + * }}} + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + val query = Option(ctx.query).map(plan) + val tableName = visitMultipartIdentifier(ctx.multipartIdentifier) + if (query.isDefined && tableName.length > 1) { + val catalogAndNamespace = tableName.init + throw new ParseException("It is not allowed to add catalog/namespace " + + s"prefix ${catalogAndNamespace.quoted} to " + + "the table name in CACHE TABLE AS SELECT", ctx) + } + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + CacheTableStatement(tableName, query, ctx.LAZY != null, options) + } + /** * Create a [[TruncateTableStatement]] command. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 127d9026f802..4a91ee6d52d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -330,6 +330,15 @@ case class AnalyzeColumnStatement( */ case class RepairTableStatement(tableName: Seq[String]) extends ParsedStatement +/** + * A CACHE TABLE statement, as parsed from SQL + */ +case class CacheTableStatement( + tableName: Seq[String], + plan: Option[LogicalPlan], + isLazy: Boolean, + options: Map[String, String]) extends ParsedStatement + /** * A TRUNCATE TABLE statement, as parsed from SQL */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 8e605bd15f69..37349f7a3342 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1045,6 +1045,23 @@ class DDLParserSuite extends AnalysisTest { RepairTableStatement(Seq("a", "b", "c"))) } + test("CACHE table") { + comparePlans( + parsePlan("CACHE TABLE a.b.c"), + CacheTableStatement(Seq("a", "b", "c"), None, false, Map.empty)) + + comparePlans( + parsePlan("CACHE LAZY TABLE a.b.c"), + CacheTableStatement(Seq("a", "b", "c"), None, true, Map.empty)) + + comparePlans( + parsePlan("CACHE LAZY TABLE a.b.c OPTIONS('storageLevel' 'DISK_ONLY')"), + CacheTableStatement(Seq("a", "b", "c"), None, true, Map("storageLevel" -> "DISK_ONLY"))) + + intercept("CACHE TABLE a.b.c AS SELECT * FROM testData", + "It is not allowed to add catalog/namespace prefix a.b") + } + test("TRUNCATE table") { comparePlans( parsePlan("TRUNCATE TABLE 
a.b.c"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 230b8f3906bd..65d95b600eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand} +import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CacheTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, RefreshTable} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf @@ -299,6 +299,10 @@ class ResolveSessionCatalog( v1TableName.asTableIdentifier, "MSCK REPAIR TABLE") + case CacheTableStatement(tableName, plan, isLazy, options) => + val v1TableName = parseV1Table(tableName, "CACHE TABLE") + CacheTableCommand(v1TableName.asTableIdentifier, plan, isLazy, options) + case TruncateTableStatement(tableName, partitionSpec) => val v1TableName = parseV1Table(tableName, "TRUNCATE TABLE") TruncateTableCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 2439621f7725..fb13d01bd91d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -165,21 +165,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { unquotedPath } - /** - * Create a [[CacheTableCommand]] logical plan. - */ - override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { - val query = Option(ctx.query).map(plan) - val tableIdent = visitTableIdentifier(ctx.tableIdentifier) - if (query.isDefined && tableIdent.database.isDefined) { - val database = tableIdent.database.get - throw new ParseException(s"It is not allowed to add database prefix `$database` to " + - s"the table name in CACHE TABLE AS SELECT", ctx) - } - val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) - CacheTableCommand(tableIdent, query, ctx.LAZY != null, options) - } - /** * Create an [[UncacheTableCommand]] logical plan. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 463147903c92..4d1e70f68ba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1268,6 +1268,20 @@ class DataSourceV2SQLSuite } } + test("CACHE TABLE") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") + + testV1Command("CACHE TABLE", t) + + val e = intercept[AnalysisException] { + sql(s"CACHE LAZY TABLE $t") + } + assert(e.message.contains("CACHE TABLE is only supported with v1 tables")) + } + } + private def testV1Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 4cbc03d05c9e..7b3fb6817423 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -300,7 +300,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val e = intercept[ParseException] { sql(s"CACHE TABLE $db.cachedTable AS SELECT 1") }.getMessage - assert(e.contains("It is not allowed to add database prefix ") && + assert(e.contains("It is not allowed to add catalog/namespace prefix ") && e.contains("to the table name in CACHE TABLE AS SELECT")) } } From 9e77d483158a6e9edc0b5b4c642e1231773be9ee Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 24 Oct 2019 15:43:13 +0800 Subject: [PATCH 31/58] [SPARK-21492][SQL][FOLLOW UP] Reimplement UnsafeExternalRowSorter in database style iterator ### What changes were proposed in this pull request? Reimplement the iterator in UnsafeExternalRowSorter in database style. This can be done by reusing the `RowIterator` in our code base. ### Why are the changes needed? During the job in #26164, after involving a var `isReleased` in `hasNext`, there's possible that `isReleased` is false when calling `hasNext`, but it becomes true before calling `next`. A safer way is using database-style iterator: `advanceNext` and `getRow`. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing UT. Closes #26229 from xuanyuanking/SPARK-21492-follow-up. 
Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- .../spark/sql/execution/RowIterator.scala | 0 .../execution/UnsafeExternalRowSorter.java | 48 ++++++++++--------- 2 files changed, 26 insertions(+), 22 deletions(-) rename sql/{core => catalyst}/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3123f2187da8..90b55a8586de 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.function.Supplier; -import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; @@ -168,7 +167,7 @@ public void cleanupResources() { sorter.cleanupResources(); } - public Iterator sort() throws IOException { + public Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -176,31 +175,32 @@ public Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractIterator() { + return new RowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(numFields); @Override - public boolean hasNext() { - return !isReleased && sortedIterator.hasNext(); - } - - @Override - public UnsafeRow next() { + public boolean advanceNext() { try { - sortedIterator.loadNext(); - row.pointTo( - sortedIterator.getBaseObject(), - sortedIterator.getBaseOffset(), - sortedIterator.getRecordLength()); - if (!hasNext()) { - UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page - row = null; // so that we don't keep references to the base object - cleanupResources(); - return copy; + if (!isReleased && sortedIterator.hasNext()) { + sortedIterator.loadNext(); + row.pointTo( + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + sortedIterator.getRecordLength()); + // Here is the initial bug fix in SPARK-9364: the bug fix of use-after-free bug + // when returning the last row from an iterator. 
For example, in + // [[GroupedIterator]], we still use the last row after traversing the iterator + // in `fetchNextGroupIterator` + if (!sortedIterator.hasNext()) { + row = row.copy(); // so that we don't have dangling pointers to freed page + cleanupResources(); + } + return true; } else { - return row; + row = null; // so that we don't keep references to the base object + return false; } } catch (IOException e) { cleanupResources(); @@ -210,14 +210,18 @@ public UnsafeRow next() { } throw new RuntimeException("Exception should have been re-thrown in next()"); } - }; + + @Override + public UnsafeRow getRow() { return row; } + + }.toScala(); } catch (IOException e) { cleanupResources(); throw e; } } - public Iterator sort(Iterator inputIterator) throws IOException { + public Iterator sort(Iterator inputIterator) throws IOException { while (inputIterator.hasNext()) { insertRow(inputIterator.next()); } From 1296bbb8ac7f582f8689f3e2f36614cf541b80d4 Mon Sep 17 00:00:00 2001 From: Pavithra Ramachandran Date: Thu, 24 Oct 2019 11:14:31 +0200 Subject: [PATCH 32/58] [SPARK-29504][WEBUI] Toggle full job description on click ### What changes were proposed in this pull request? On clicking job description in jobs page, the description was not shown fully. Add the function for the click event on description. ### Why are the changes needed? when there is a long description of a job, it cannot be seen fully in the UI. The feature was added in https://github.com/apache/spark/pull/24145 But it is missed after https://github.com/apache/spark/pull/25374 Before change: ![Screenshot from 2019-10-23 11-23-00](https://user-images.githubusercontent.com/51401130/67361914-827b0080-f587-11e9-9181-d49a6a836046.png) After change: on Double click over decription ![Screenshot from 2019-10-23 11-20-02](https://user-images.githubusercontent.com/51401130/67361936-932b7680-f587-11e9-9e59-d290abed4b70.png) ### Does this PR introduce any user-facing change? No ### How was this patch tested? Manually test Closes #26222 from PavithraRamachandran/jobs_description_tooltip. Authored-by: Pavithra Ramachandran Signed-off-by: Gengliang Wang --- .../main/resources/org/apache/spark/ui/static/webui.js | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 89622106ff1f..cf04db28804c 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -87,4 +87,11 @@ $(function() { collapseTablePageLoad('collapse-aggregated-runningExecutions','aggregated-runningExecutions'); collapseTablePageLoad('collapse-aggregated-completedExecutions','aggregated-completedExecutions'); collapseTablePageLoad('collapse-aggregated-failedExecutions','aggregated-failedExecutions'); -}); \ No newline at end of file +}); + +$(function() { + // Show/hide full job description on click event. + $(".description-input").click(function() { + $(this).toggleClass("description-input-full"); + }); +}); From 67cf0433ee4e4a7c33e0092e887bba53ad35627e Mon Sep 17 00:00:00 2001 From: angerszhu Date: Thu, 24 Oct 2019 21:55:03 +0900 Subject: [PATCH 33/58] [SPARK-29145][SQL] Support sub-queries in join conditions ### What changes were proposed in this pull request? Support SparkSQL use iN/EXISTS with subquery in JOIN condition. ### Why are the changes needed? Support SQL use iN/EXISTS with subquery in JOIN condition. 
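(Editor's illustration, not part of the original commit message: in addition to the `IN` example given below, an `EXISTS` predicate can likewise appear in a join's `ON` clause once this change is in place; the table names here are hypothetical.)
```sql
-- Hypothetical query: correlated EXISTS sub-query used directly in the JOIN condition
SELECT A.id
FROM A JOIN B
  ON A.id = B.id
 AND EXISTS (SELECT 1 FROM C WHERE C.id = A.id)
```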
### Does this PR introduce any user-facing change? This PR is for enable user use subquery in `JOIN`'s ON condition. such as we have create three table ``` CREATE TABLE A(id String); CREATE TABLE B(id String); CREATE TABLE C(id String); ``` we can do query like : ``` SELECT A.id from A JOIN B ON A.id = B.id and A.id IN (select C.id from C) ``` ### How was this patch tested? ADDED UT Closes #25854 from AngersZhuuuu/SPARK-29145. Lead-authored-by: angerszhu Co-authored-by: AngersZhuuuu Signed-off-by: Takeshi Yamamuro --- .../sql/catalyst/analysis/Analyzer.scala | 2 + .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../analysis/ResolveSubquerySuite.scala | 14 +- .../org/apache/spark/sql/SubquerySuite.scala | 148 ++++++++++++++++++ 4 files changed, 165 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b913a9618d6e..21bf926af50d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1697,6 +1697,8 @@ class Analyzer( // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. case q: UnaryNode if q.childrenResolved => resolveSubQueries(q, q.children) + case j: Join if j.childrenResolved => + resolveSubQueries(j, Seq(j, j.left, j.right)) case s: SupportsSubquery if s.childrenResolved => resolveSubQueries(s, s.children) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6a5d938f0fdc..d9dc9ebbcaf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -601,10 +601,10 @@ trait CheckAnalysis extends PredicateHelper { case inSubqueryOrExistsSubquery => plan match { - case _: Filter | _: SupportsSubquery => // Ok + case _: Filter | _: SupportsSubquery | _: Join => // Ok case _ => failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" + - s" Filter and a few commands: $plan") + s" Filter/Join and a few commands: $plan") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 74a8590b5eef..5aa80e1a9bd7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical._ /** * Unit tests for [[ResolveSubquery]]. 
@@ -29,8 +30,10 @@ class ResolveSubquerySuite extends AnalysisTest { val a = 'a.int val b = 'b.int + val c = 'c.int val t1 = LocalRelation(a) val t2 = LocalRelation(b) + val t3 = LocalRelation(c) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { val expr = Filter( @@ -41,4 +44,13 @@ class ResolveSubquerySuite extends AnalysisTest { assert(m.contains( "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses")) } + + test("SPARK-29145 Support subquery in join condition") { + val expr = Join(t1, + t2, + Inner, + Some(InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("c")), t3)))), + JoinHint.NONE) + assertAnalysisSuccess(expr) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index a1d7792941ed..266f8e23712d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -204,6 +204,154 @@ class SubquerySuite extends QueryTest with SharedSparkSession { } } + test("SPARK-29145: JOIN Condition use QueryList") { + withTempView("s1", "s2", "s3") { + Seq(1, 3, 5, 7, 9).toDF("id").createOrReplaceTempView("s1") + Seq(1, 3, 4, 6, 9).toDF("id").createOrReplaceTempView("s2") + Seq(3, 4, 6, 9).toDF("id").createOrReplaceTempView("s3") + + checkAnswer( + sql( + """ + | SELECT s1.id FROM s1 + | JOIN s2 ON s1.id = s2.id + | AND s1.id IN (SELECT 9) + """.stripMargin), + Row(9) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id FROM s1 + | JOIN s2 ON s1.id = s2.id + | AND s1.id NOT IN (SELECT 9) + """.stripMargin), + Row(1) :: Row(3) :: Nil) + + // case `IN` + checkAnswer( + sql( + """ + | SELECT s1.id FROM s1 + | JOIN s2 ON s1.id = s2.id + | AND s1.id IN (SELECT id FROM s3) + """.stripMargin), + Row(3) :: Row(9) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id AS id2 FROM s1 + | LEFT SEMI JOIN s2 + | ON s1.id = s2.id + | AND s1.id IN (SELECT id FROM s3) + """.stripMargin), + Row(3) :: Row(9) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id as id2 FROM s1 + | LEFT ANTI JOIN s2 + | ON s1.id = s2.id + | AND s1.id IN (SELECT id FROM s3) + """.stripMargin), + Row(1) :: Row(5) :: Row(7) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id, s2.id as id2 FROM s1 + | LEFT OUTER JOIN s2 + | ON s1.id = s2.id + | AND s1.id IN (SELECT id FROM s3) + """.stripMargin), + Row(1, null) :: Row(3, 3) :: Row(5, null) :: Row(7, null) :: Row(9, 9) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id, s2.id as id2 FROM s1 + | RIGHT OUTER JOIN s2 + | ON s1.id = s2.id + | AND s1.id IN (SELECT id FROM s3) + """.stripMargin), + Row(null, 1) :: Row(3, 3) :: Row(null, 4) :: Row(null, 6) :: Row(9, 9) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id, s2.id AS id2 FROM s1 + | FULL OUTER JOIN s2 + | ON s1.id = s2.id + | AND s1.id IN (SELECT id FROM s3) + """.stripMargin), + Row(1, null) :: Row(3, 3) :: Row(5, null) :: Row(7, null) :: Row(9, 9) :: + Row(null, 1) :: Row(null, 4) :: Row(null, 6) :: Nil) + + // case `NOT IN` + checkAnswer( + sql( + """ + | SELECT s1.id FROM s1 + | JOIN s2 ON s1.id = s2.id + | AND s1.id NOT IN (SELECT id FROM s3) + """.stripMargin), + Row(1) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id AS id2 FROM s1 + | LEFT SEMI JOIN s2 + | ON s1.id = s2.id + | AND s1.id NOT IN (SELECT id FROM s3) + """.stripMargin), + Row(1) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id AS id2 FROM s1 + | LEFT ANTI JOIN s2 + | ON s1.id = s2.id + | 
AND s1.id NOT IN (SELECT id FROM s3) + """.stripMargin), + Row(3) :: Row(5) :: Row(7) :: Row(9) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id, s2.id AS id2 FROM s1 + | LEFT OUTER JOIN s2 + | ON s1.id = s2.id + | AND s1.id NOT IN (SELECT id FROM s3) + """.stripMargin), + Row(1, 1) :: Row(3, null) :: Row(5, null) :: Row(7, null) :: Row(9, null) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id, s2.id AS id2 FROM s1 + | RIGHT OUTER JOIN s2 + | ON s1.id = s2.id + | AND s1.id NOT IN (SELECT id FROM s3) + """.stripMargin), + Row(1, 1) :: Row(null, 3) :: Row(null, 4) :: Row(null, 6) :: Row(null, 9) :: Nil) + + checkAnswer( + sql( + """ + | SELECT s1.id, s2.id AS id2 FROM s1 + | FULL OUTER JOIN s2 + | ON s1.id = s2.id + | AND s1.id NOT IN (SELECT id FROM s3) + """.stripMargin), + Row(1, 1) :: Row(3, null) :: Row(5, null) :: Row(7, null) :: Row(9, null) :: + Row(null, 3) :: Row(null, 4) :: Row(null, 6) :: Row(null, 9) :: Nil) + } + } + test("SPARK-14791: scalar subquery inside broadcast join") { val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)") val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil From 1ec1b2bd17ac5f6040336680c79e2aa9765b353d Mon Sep 17 00:00:00 2001 From: Pavithra Ramachandran Date: Thu, 24 Oct 2019 08:19:03 -0500 Subject: [PATCH 34/58] [SPARK-28791][DOC] Documentation for Alter table Command What changes were proposed in this pull request? Document ALTER TABLE statement in SQL Reference Guide. Why are the changes needed? Adding documentation for SQL reference. Does this PR introduce any user-facing change? yes Before: There was no documentation for this. After. ![1](https://user-images.githubusercontent.com/51401130/65674372-1087c800-e06a-11e9-9155-ac70b419b069.png) ![2](https://user-images.githubusercontent.com/51401130/65674384-14b3e580-e06a-11e9-9c57-bca566dfdbc2.png) ![3](https://user-images.githubusercontent.com/51401130/65674391-18e00300-e06a-11e9-950a-6cc948dedd7d.png) ![4](https://user-images.githubusercontent.com/51401130/65674397-1bdaf380-e06a-11e9-87b0-b1523a745f83.png) ![5](https://user-images.githubusercontent.com/51401130/65674406-209fa780-e06a-11e9-8440-7e8105a77117.png) ![6](https://user-images.githubusercontent.com/51401130/65674417-23020180-e06a-11e9-8fff-30511836bb08.png) How was this patch tested? Used jekyll build and serve to verify. Closes #25590 from PavithraRamachandran/alter_doc. Authored-by: Pavithra Ramachandran Signed-off-by: Sean Owen --- docs/sql-ref-syntax-ddl-alter-table.md | 238 ++++++++++++++++++++++++- 1 file changed, 237 insertions(+), 1 deletion(-) diff --git a/docs/sql-ref-syntax-ddl-alter-table.md b/docs/sql-ref-syntax-ddl-alter-table.md index 7fcd39791582..e311691c6b80 100644 --- a/docs/sql-ref-syntax-ddl-alter-table.md +++ b/docs/sql-ref-syntax-ddl-alter-table.md @@ -19,4 +19,240 @@ license: | limitations under the License. --- -**This page is under construction** +### Description +`ALTER TABLE` statement changes the schema or properties of a table. + +### RENAME +`ALTER TABLE RENAME` statement changes the table name of an existing table in the database. + +#### Syntax +{% highlight sql %} +ALTER TABLE [db_name.]old_table_name RENAME TO [db_name.]new_table_name + +ALTER TABLE table_name PARTITION partition_spec RENAME TO PARTITION partition_spec; + +{% endhighlight %} + +#### Parameters +
+<dl>
+  <dt><code><em>old_table_name</em></code></dt>
+  <dd>Name of an existing table.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>db_name</em></code></dt>
+  <dd>Name of the existing database.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>new_table_name</em></code></dt>
+  <dd>New name using which the table has to be renamed.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>partition_spec</em></code></dt>
+  <dd>Partition to be renamed.</dd>
+</dl>
+ + +### ADD COLUMNS +`ALTER TABLE ADD COLUMNS` statement adds mentioned columns to an existing table. + +#### Syntax +{% highlight sql %} +ALTER TABLE table_name ADD COLUMNS (col_spec[, col_spec ...]) +{% endhighlight %} + +#### Parameters +
+<dl>
+  <dt><code><em>table_name</em></code></dt>
+  <dd>The name of an existing table.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>COLUMNS (col_spec)</em></code></dt>
+  <dd>Specifies the columns to be added.</dd>
+</dl>
+ + +### SET AND UNSET + +#### SET TABLE PROPERTIES +`ALTER TABLE SET` command is used for setting the table properties. If a particular property was already set, +this overrides the old value with the new one. + +`ALTER TABLE UNSET` is used to drop the table property. + +##### Syntax +{% highlight sql %} + +--Set Table Properties +ALTER TABLE table_name SET TBLPROPERTIES (key1=val1, key2=val2, ...) + +--Unset Table Properties +ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] (key1, key2, ...) + +{% endhighlight %} + +#### SET SERDE +`ALTER TABLE SET` command is used for setting the SERDE or SERDE properties in Hive tables. If a particular property was already set, +this overrides the old value with the new one. + +##### Syntax +{% highlight sql %} + +--Set SERDE Propeties +ALTER TABLE table_name [PARTITION part_spec] + SET SERDEPROPERTIES (key1=val1, key2=val2, ...) + +ALTER TABLE table_name [PARTITION part_spec] SET SERDE serde_class_name + [WITH SERDEPROPERTIES (key1=val1, key2=val2, ...)] + +{% endhighlight %} + +#### SET LOCATION And SET FILE FORMAT +`ALTER TABLE SET` command can also be used for changing the file location and file format for +exsisting tables. + +##### Syntax +{% highlight sql %} + +--Changing File Format +ALTER TABLE table_name [PARTITION partition_spec] SET FILEFORMAT file_format; + +--Changing File Location +ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION 'new_location'; + +{% endhighlight %} + +#### Parameters +
+<dl>
+  <dt><code><em>table_name</em></code></dt>
+  <dd>The name of an existing table.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>PARTITION (part_spec)</em></code></dt>
+  <dd>Specifies the partition on which the property has to be set.</dd>
+</dl>
+
+<dl>
+  <dt><code><em>SERDEPROPERTIES (key1=val1, key2=val2, ...)</em></code></dt>
+  <dd>Specifies the SERDE properties to be set.</dd>
+</dl>
+ + +### Examples +{% highlight sql %} + +--RENAME table +DESC student; ++--------------------------+------------+----------+--+ +| col_name | data_type | comment | ++--------------------------+------------+----------+--+ +| name | string | NULL | +| rollno | int | NULL | +| age | int | NULL | +| # Partition Information | | | +| # col_name | data_type | comment | +| age | int | NULL | ++--------------------------+------------+----------+--+ + +ALTER TABLE Student RENAME TO StudentInfo; + +--After Renaming the table + +DESC StudentInfo; ++--------------------------+------------+----------+--+ +| col_name | data_type | comment | ++--------------------------+------------+----------+--+ +| name | string | NULL | +| rollno | int | NULL | +| age | int | NULL | +| # Partition Information | | | +| # col_name | data_type | comment | +| age | int | NULL | ++--------------------------+------------+----------+--+ + +--RENAME partition + +SHOW PARTITIONS StudentInfo; ++------------+--+ +| partition | ++------------+--+ +| age=10 | +| age=11 | +| age=12 | ++------------+--+ + +ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15'); + +--After renaming Partition +SHOW PARTITIONS StudentInfo; ++------------+--+ +| partition | ++------------+--+ +| age=11 | +| age=12 | +| age=15 | ++------------+--+ + +-- Add new column to a table + +DESC StudentInfo; ++--------------------------+------------+----------+--+ +| col_name | data_type | comment | ++--------------------------+------------+----------+--+ +| name | string | NULL | +| rollno | int | NULL | +| age | int | NULL | +| # Partition Information | | | +| # col_name | data_type | comment | +| age | int | NULL | ++--------------------------+------------+----------+ + +ALTER TABLE StudentInfo ADD columns (LastName string, DOB timestamp); + +--After Adding New columns to the table +DESC StudentInfo; ++--------------------------+------------+----------+--+ +| col_name | data_type | comment | ++--------------------------+------------+----------+--+ +| name | string | NULL | +| rollno | int | NULL | +| LastName | string | NULL | +| DOB | timestamp | NULL | +| age | int | NULL | +| # Partition Information | | | +| # col_name | data_type | comment | +| age | int | NULL | ++--------------------------+------------+----------+--+ + + +--Change the fileformat +ALTER TABLE loc_orc SET fileformat orc; + +ALTER TABLE p1 partition (month=2, day=2) SET fileformat parquet; + +--Change the file Location +ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways' + +-- SET SERDE/ SERDE Properties +ALTER TABLE test_tab SET SERDE 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'; + +ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee') + +--SET TABLE PROPERTIES +ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('winner' = 'loser') + +--DROP TABLE PROPERTIES +ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('winner') + +{% endhighlight %} + + +### Related Statements +- [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +- [DROP TABLE](sql-ref-syntax-ddl-drop-table.html) + + From 76d4bebb54d7b3960d48fcf473d2d3db0f5fbcde Mon Sep 17 00:00:00 2001 From: shahid Date: Thu, 24 Oct 2019 08:29:05 -0500 Subject: [PATCH 35/58] [SPARK-29559][WEBUI] Support pagination for JDBC/ODBC Server page ### What changes were proposed in this pull request? Supports pagination for SQL Statisitcs table in the JDBC/ODBC tab using existing Spark pagination framework. ### Why are the changes needed? 
It will easier for user to analyse the table and it may fix the potential issues like oom while loading the page, that may occur similar to the SQL page (refer https://github.com/apache/spark/pull/22645) ### Does this PR introduce any user-facing change? There will be no change in the `SQLStatistics` table in JDBC/ODBC server page execpt pagination support. ### How was this patch tested? Manually verified. Before PR: ![Screenshot 2019-10-22 at 11 37 29 PM](https://user-images.githubusercontent.com/23054875/67316080-73636680-f525-11e9-91bc-ff7e06e3736d.png) After PR: ![Screenshot 2019-10-22 at 10 33 00 PM](https://user-images.githubusercontent.com/23054875/67316092-778f8400-f525-11e9-93f8-1e2815abd66f.png) Closes #26215 from shahidki31/jdbcPagination. Authored-by: shahid Signed-off-by: Sean Owen --- .../thriftserver/ui/ThriftServerPage.scala | 364 +++++++++++++++--- 1 file changed, 302 insertions(+), 62 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 4056be4769d2..e472aaad5bdc 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.hive.thriftserver.ui +import java.net.URLEncoder +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Calendar import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.collection.JavaConverters._ +import scala.xml.{Node, Unparsed} import org.apache.commons.text.StringEscapeUtils @@ -29,7 +32,7 @@ import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._ import org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ - +import org.apache.spark.util.Utils /** Page for Spark Web UI that shows statistics of the thrift server */ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("") with Logging { @@ -69,45 +72,56 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { + val numStatement = listener.getExecutionList.size + val table = if (numStatement > 0) { - val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time", - "Execution Time", "Duration", "Statement", "State", "Detail") - val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME), - Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION), - Some(THRIFT_SERVER_DURATION), None, None, None) - assert(headerRow.length == tooltips.length) - val dataRows = listener.getExecutionList.sortBy(_.startTimestamp).reverse - - def generateDataRow(info: ExecutionInfo): Seq[Node] = { - val jobLink = info.jobId.map { id: String => - - [{id}] - + + val sqlTableTag = "sqlstat" + + val parameterOtherTable = request.getParameterMap().asScala + .filterNot(_._1.startsWith(sqlTableTag)) + .map { case (name, vals) => + name + "=" + vals(0) } - val detail = Option(info.detail).filter(!_.isEmpty).getOrElse(info.executePlan) - - {info.userName} - - {jobLink} - - {info.groupId} - {formatDate(info.startTimestamp)} - {if (info.finishTimestamp > 0) 
formatDate(info.finishTimestamp)} - {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)} - - {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} - - {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} - {info.statement} - {info.state} - {errorMessageCell(detail)} - - } - Some(UIUtils.listingTable(headerRow, generateDataRow, - dataRows, false, None, Seq(null), false, tooltipHeaders = tooltips)) + val parameterSqlTablePage = request.getParameter(s"$sqlTableTag.page") + val parameterSqlTableSortColumn = request.getParameter(s"$sqlTableTag.sort") + val parameterSqlTableSortDesc = request.getParameter(s"$sqlTableTag.desc") + val parameterSqlPageSize = request.getParameter(s"$sqlTableTag.pageSize") + + val sqlTablePage = Option(parameterSqlTablePage).map(_.toInt).getOrElse(1) + val sqlTableSortColumn = Option(parameterSqlTableSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse("Start Time") + val sqlTableSortDesc = Option(parameterSqlTableSortDesc).map(_.toBoolean).getOrElse( + // New executions should be shown above old executions by default. + sqlTableSortColumn == "Start Time" + ) + val sqlTablePageSize = Option(parameterSqlPageSize).map(_.toInt).getOrElse(100) + + try { + Some(new SqlStatsPagedTable( + request, + parent, + listener.getExecutionList, + "sqlserver", + UIUtils.prependBaseUri(request, parent.basePath), + parameterOtherTable, + sqlTableTag, + pageSize = sqlTablePageSize, + sortColumn = sqlTableSortColumn, + desc = sqlTableSortDesc + ).table(sqlTablePage)) + } catch { + case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) => + Some(
<div class="alert alert-error">
+            <p>Error while rendering job table:</p>
+            <pre>
+              {Utils.exceptionString(e)}
+            </pre>
+          </div>
) + } } else { None } @@ -123,30 +137,6 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" content } - private def errorMessageCell(errorMessage: String): Seq[Node] = { - val isMultiline = errorMessage.indexOf('\n') >= 0 - val errorSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - errorMessage.substring(0, errorMessage.indexOf('\n')) - } else { - errorMessage - }) - val details = if (isMultiline) { - // scalastyle:off - - + details - ++ - - // scalastyle:on - } else { - "" - } - {errorSummary}{details} - } - /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { val sessionList = listener.getSessionList @@ -185,7 +175,6 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" content } - /** * Returns a human-readable string representing a duration such as "5 second 35 ms" */ @@ -202,3 +191,254 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } } +private[ui] class SqlStatsPagedTable( + request: HttpServletRequest, + parent: ThriftServerTab, + data: Seq[ExecutionInfo], + subPath: String, + basePath: String, + parameterOtherTable: Iterable[String], + sqlStatsTableTag: String, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedTable[SqlStatsTableRow] { + + override val dataSource = new SqlStatsTableDataSource(data, pageSize, sortColumn, desc) + + private val parameterPath = s"$basePath/$subPath/?${parameterOtherTable.mkString("&")}" + + override def tableId: String = sqlStatsTableTag + + override def tableCssClass: String = + "table table-bordered table-condensed table-striped " + + "table-head-clickable table-cell-width-limited" + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + parameterPath + + s"&$pageNumberFormField=$page" + + s"&$sqlStatsTableTag.sort=$encodedSortColumn" + + s"&$sqlStatsTableTag.desc=$desc" + + s"&$pageSizeFormField=$pageSize" + } + + override def pageSizeFormField: String = s"$sqlStatsTableTag.pageSize" + + override def pageNumberFormField: String = s"$sqlStatsTableTag.page" + + override def goButtonFormPath: String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + s"$parameterPath&$sqlStatsTableTag.sort=$encodedSortColumn&$sqlStatsTableTag.desc=$desc" + } + + override def headers: Seq[Node] = { + val sqlTableHeaders = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", + "Close Time", "Execution Time", "Duration", "Statement", "State", "Detail") + + val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME), + Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION), + Some(THRIFT_SERVER_DURATION), None, None, None) + + assert(sqlTableHeaders.length == tooltips.length) + + val headerRow: Seq[Node] = { + sqlTableHeaders.zip(tooltips).map { case (header, tooltip) => + if (header == sortColumn) { + val headerLink = Unparsed( + parameterPath + + s"&$sqlStatsTableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + + s"&$sqlStatsTableTag.desc=${!desc}" + + s"&$sqlStatsTableTag.pageSize=$pageSize" + + s"#$sqlStatsTableTag") + val arrow = if (desc) "▾" else "▴" // UP or DOWN + + if (tooltip.nonEmpty) { + + + + {header} {Unparsed(arrow)} + + + + } else { + + + {header} {Unparsed(arrow)} + + + } + } else { + val headerLink = Unparsed( + parameterPath + + s"&$sqlStatsTableTag.sort=${URLEncoder.encode(header, UTF_8.name())}" + + 
s"&$sqlStatsTableTag.pageSize=$pageSize" + + s"#$sqlStatsTableTag") + + if(tooltip.nonEmpty) { + + + + {header} + + + + } else { + + + {header} + + + } + } + } + } + + {headerRow} + + } + + override def row(sqlStatsTableRow: SqlStatsTableRow): Seq[Node] = { + val info = sqlStatsTableRow.executionInfo + val startTime = info.startTimestamp + val executionTime = sqlStatsTableRow.executionTime + val duration = sqlStatsTableRow.duration + + def jobLinks(jobData: Seq[String]): Seq[Node] = { + jobData.map { jobId => + [{jobId.toString}] + } + } + + + + {info.userName} + + + {jobLinks(sqlStatsTableRow.jobId)} + + + {info.groupId} + + + {UIUtils.formatDate(startTime)} + + + {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} + + + {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)} + + + {UIUtils.formatDuration(executionTime)} + + + {UIUtils.formatDuration(duration)} + + + {info.statement} + + + {info.state} + + {errorMessageCell(sqlStatsTableRow.detail)} + + } + + + private def errorMessageCell(errorMessage: String): Seq[Node] = { + val isMultiline = errorMessage.indexOf('\n') >= 0 + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + errorMessage.substring(0, errorMessage.indexOf('\n')) + } else { + errorMessage + }) + val details = if (isMultiline) { + // scalastyle:off + + + details + ++ + + // scalastyle:on + } else { + "" + } + + {errorSummary}{details} + + } + + private def jobURL(request: HttpServletRequest, jobId: String): String = + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) +} + + private[ui] class SqlStatsTableRow( + val jobId: Seq[String], + val duration: Long, + val executionTime: Long, + val executionInfo: ExecutionInfo, + val detail: String) + + private[ui] class SqlStatsTableDataSource( + info: Seq[ExecutionInfo], + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedDataSource[SqlStatsTableRow](pageSize) { + + // Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in + // the table so that we can avoid creating duplicate contents during sorting the data + private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) + + private var _slicedStartTime: Set[Long] = null + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = { + val r = data.slice(from, to) + r.map(x => x) + _slicedStartTime = r.map(_.executionInfo.startTimestamp).toSet + r + } + + private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { + val duration = executionInfo.totalTime(executionInfo.closeTimestamp) + val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) + val detail = Option(executionInfo.detail).filter(!_.isEmpty) + .getOrElse(executionInfo.executePlan) + val jobId = executionInfo.jobId.toSeq.sorted + + new SqlStatsTableRow(jobId, duration, executionTime, executionInfo, detail) + + } + + /** + * Return Ordering according to sortColumn and desc. 
+ */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = { + val ordering: Ordering[SqlStatsTableRow] = sortColumn match { + case "User" => Ordering.by(_.executionInfo.userName) + case "JobID" => Ordering by (_.jobId.headOption) + case "GroupID" => Ordering.by(_.executionInfo.groupId) + case "Start Time" => Ordering.by(_.executionInfo.startTimestamp) + case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp) + case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp) + case "Execution Time" => Ordering.by(_.executionTime) + case "Duration" => Ordering.by(_.duration) + case "Statement" => Ordering.by(_.executionInfo.statement) + case "State" => Ordering.by(_.executionInfo.state) + case "Detail" => Ordering.by(_.detail) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } + + } From a35fb4fd504b9ac14384eb63fc4c993fd53cd667 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 24 Oct 2019 08:30:27 -0500 Subject: [PATCH 36/58] [SPARK-29578][TESTS] Add "8634" as another skipped day for Kwajalein timzeone due to more recent timezone updates in later JDK 8 ### What changes were proposed in this pull request? Recent timezone definition changes in very new JDK 8 (and beyond) releases cause test failures. The below was observed on JDK 1.8.0_232. As before, the easy fix is to allow for these inconsequential variations in test results due to differing definition of timezones. ### Why are the changes needed? Keeps test passing on the latest JDK releases. ### Does this PR introduce any user-facing change? None ### How was this patch tested? Existing tests Closes #26236 from srowen/SPARK-29578. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 1abbca6c8cd2..10642b3ca8a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -555,12 +555,12 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers { // There are some days are skipped entirely in some timezone, skip them here. val skipped_days = Map[String, Set[Int]]( - "Kwajalein" -> Set(8632, 8633), + "Kwajalein" -> Set(8632, 8633, 8634), "Pacific/Apia" -> Set(15338), "Pacific/Enderbury" -> Set(9130, 9131), "Pacific/Fakaofo" -> Set(15338), "Pacific/Kiritimati" -> Set(9130, 9131), - "Pacific/Kwajalein" -> Set(8632, 8633), + "Pacific/Kwajalein" -> Set(8632, 8633, 8634), "MIT" -> Set(15338)) for (tz <- ALL_TIMEZONES) { val skipped = skipped_days.getOrElse(tz.getID, Set.empty) From cdea520ff8954cf415fd98d034d9b674d6ca4f67 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 24 Oct 2019 09:15:59 -0700 Subject: [PATCH 37/58] [SPARK-29532][SQL] Simplify interval string parsing ### What changes were proposed in this pull request? Only use antlr4 to parse the interval string, and remove the duplicated parsing logic from `CalendarInterval`. ### Why are the changes needed? Simplify the code and fix inconsistent behaviors. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Pass the Jenkins with the updated test cases. 
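
For illustration only (not part of the diff; the object name `IntervalParsingSketch` and the literal values are my own, inferred from `fromUnitStrings` above), the new helpers roughly behave as follows:

```scala
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK

object IntervalParsingSketch extends App {
  // Interval strings now go through the ANTLR grammar; the "interval" prefix
  // is optional and units may appear in any order or in plural form.
  val i = IntervalUtils.fromString("3 years 1 month 2 weeks")
  assert(i.months == 3 * 12 + 1)
  assert(i.microseconds == 2 * MICROS_PER_WEEK)

  // Malformed input: fromString throws IllegalArgumentException, while
  // safeFromString returns null instead.
  assert(IntervalUtils.safeFromString("foo 1 day") == null)
}
```

The `safeFromString` variant is what `Cast` now uses, so `CAST(... AS INTERVAL)` still yields null on malformed input, matching the behavior of the removed `CalendarInterval.fromString`.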
Closes #26190 from cloud-fan/parser. Lead-authored-by: Wenchen Fan Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/unsafe/types/CalendarInterval.java | 156 +++------------- .../unsafe/types/CalendarIntervalSuite.java | 174 ++---------------- .../spark/sql/catalyst/parser/SqlBase.g4 | 4 + .../spark/sql/catalyst/expressions/Cast.scala | 9 +- .../sql/catalyst/expressions/TimeWindow.scala | 4 +- .../sql/catalyst/parser/AstBuilder.scala | 40 +++- .../sql/catalyst/parser/ParseDriver.scala | 9 + .../sql/catalyst/util/IntervalUtils.scala | 51 ++++- .../CollectionExpressionsSuite.scala | 36 ++-- .../expressions/DateExpressionsSuite.scala | 18 +- .../expressions/HashExpressionsSuite.scala | 6 +- .../IntervalExpressionsSuite.scala | 4 +- .../expressions/MutableProjectionSuite.scala | 4 +- .../expressions/ObjectExpressionsSuite.scala | 5 +- .../expressions/UnsafeRowConverterSuite.scala | 2 +- .../parser/ExpressionParserSuite.scala | 6 +- .../catalyst/util/IntervalUtilsSuite.scala | 90 +++++++++ .../IntervalBenchmark-jdk11-results.txt | 40 ++-- .../benchmarks/IntervalBenchmark-results.txt | 41 +++-- .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../execution/streaming/GroupStateImpl.scala | 3 +- .../sql/execution/streaming/Triggers.scala | 4 +- .../resources/sql-tests/inputs/literals.sql | 13 ++ .../sql-tests/results/literals.sql.out | 164 ++++++++++++++--- .../benchmark/IntervalBenchmark.scala | 5 +- .../sql/hive/execution/SQLQuerySuite.scala | 46 ----- 26 files changed, 464 insertions(+), 473 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 28fb64f7cd0e..184ddac9a71a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -32,94 +32,11 @@ public final class CalendarInterval implements Serializable { public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; - /** - * A function to generate regex which matches interval string's unit part like "3 years". - * - * First, we can leave out some units in interval string, and we only care about the value of - * unit, so here we use non-capturing group to wrap the actual regex. - * At the beginning of the actual regex, we should match spaces before the unit part. - * Next is the number part, starts with an optional "-" to represent negative value. We use - * capturing group to wrap this part as we need the value later. - * Finally is the unit name, ends with an optional "s". 
- */ - private static String unitRegex(String unit) { - return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?"; - } - - private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") + - unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + - unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"), - Pattern.CASE_INSENSITIVE); - - private static Pattern yearMonthPattern = - Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); + private static Pattern yearMonthPattern = Pattern.compile( + "^([+|-])?(\\d+)-(\\d+)$"); private static Pattern dayTimePattern = Pattern.compile( - "^(?:['|\"])?([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$"); - - private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); - - private static long toLong(String s) { - if (s == null) { - return 0; - } else { - return Long.parseLong(s); - } - } - - /** - * Convert a string to CalendarInterval. Return null if the input string is not a valid interval. - * This method is case-insensitive. - */ - public static CalendarInterval fromString(String s) { - try { - return fromCaseInsensitiveString(s); - } catch (IllegalArgumentException e) { - return null; - } - } - - /** - * Convert a string to CalendarInterval. This method can handle - * strings without the `interval` prefix and throws IllegalArgumentException - * when the input string is not a valid interval. - * - * @throws IllegalArgumentException if the string is not a valid internal. - */ - public static CalendarInterval fromCaseInsensitiveString(String s) { - if (s == null) { - throw new IllegalArgumentException("Interval cannot be null"); - } - String trimmed = s.trim(); - if (trimmed.isEmpty()) { - throw new IllegalArgumentException("Interval cannot be blank"); - } - String prefix = "interval"; - String intervalStr = trimmed; - // Checks the given interval string does not start with the `interval` prefix - if (!intervalStr.regionMatches(true, 0, prefix, 0, prefix.length())) { - // Prepend `interval` if it does not present because - // the regular expression strictly require it. 
- intervalStr = prefix + " " + trimmed; - } else if (intervalStr.length() == prefix.length()) { - throw new IllegalArgumentException("Interval string must have time units"); - } - - Matcher m = p.matcher(intervalStr); - if (!m.matches()) { - throw new IllegalArgumentException("Invalid interval: " + s); - } - - long months = toLong(m.group(1)) * 12 + toLong(m.group(2)); - long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK; - microseconds += toLong(m.group(4)) * MICROS_PER_DAY; - microseconds += toLong(m.group(5)) * MICROS_PER_HOUR; - microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE; - microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; - microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; - microseconds += toLong(m.group(9)); - return new CalendarInterval((int) months, microseconds); - } + "^([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?$"); public static long toLongWithRange(String fieldName, String s, long minValue, long maxValue) throws IllegalArgumentException { @@ -242,72 +159,59 @@ public static CalendarInterval fromDayTimeString(String s, String from, String t return result; } - public static CalendarInterval fromSingleUnitString(String unit, String s) + public static CalendarInterval fromUnitStrings(String[] units, String[] values) throws IllegalArgumentException { + assert units.length == values.length; + int months = 0; + long microseconds = 0; - CalendarInterval result = null; - if (s == null) { - throw new IllegalArgumentException(String.format("Interval %s string was null", unit)); - } - s = s.trim(); - Matcher m = quoteTrimPattern.matcher(s); - if (!m.matches()) { - throw new IllegalArgumentException( - "Interval string does not match day-time format of 'd h:m:s.n': " + s); - } else { + for (int i = 0; i < units.length; i++) { try { - switch (unit) { + switch (units[i]) { case "year": - int year = (int) toLongWithRange("year", m.group(1), - Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); - result = new CalendarInterval(year * 12, 0L); + months = Math.addExact(months, Math.multiplyExact(Integer.parseInt(values[i]), 12)); break; case "month": - int month = (int) toLongWithRange("month", m.group(1), - Integer.MIN_VALUE, Integer.MAX_VALUE); - result = new CalendarInterval(month, 0L); + months = Math.addExact(months, Integer.parseInt(values[i])); break; case "week": - long week = toLongWithRange("week", m.group(1), - Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK); - result = new CalendarInterval(0, week * MICROS_PER_WEEK); + microseconds = Math.addExact( + microseconds, + Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_WEEK)); break; case "day": - long day = toLongWithRange("day", m.group(1), - Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); - result = new CalendarInterval(0, day * MICROS_PER_DAY); + microseconds = Math.addExact( + microseconds, + Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_DAY)); break; case "hour": - long hour = toLongWithRange("hour", m.group(1), - Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); - result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + microseconds = Math.addExact( + microseconds, + Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_HOUR)); break; case "minute": - long minute = toLongWithRange("minute", m.group(1), - Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); - result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + microseconds = Math.addExact( + microseconds, + 
Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_MINUTE)); break; case "second": { - long micros = parseSecondNano(m.group(1)); - result = new CalendarInterval(0, micros); + microseconds = Math.addExact(microseconds, parseSecondNano(values[i])); break; } case "millisecond": - long millisecond = toLongWithRange("millisecond", m.group(1), - Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI); - result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI); + microseconds = Math.addExact( + microseconds, + Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_MILLI)); break; - case "microsecond": { - long micros = Long.parseLong(m.group(1)); - result = new CalendarInterval(0, micros); + case "microsecond": + microseconds = Math.addExact(microseconds, Long.parseLong(values[i])); break; - } } } catch (Exception e) { throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); } } - return result; + return new CalendarInterval(months, microseconds); } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 587071332ce4..9f3262bf2aaa 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -19,8 +19,6 @@ import org.junit.Test; -import java.util.Arrays; - import static org.junit.Assert.*; import static org.apache.spark.unsafe.types.CalendarInterval.*; @@ -62,72 +60,6 @@ public void toStringTest() { assertEquals("interval 2 years 10 months 3 weeks 13 hours 123 microseconds", i.toString()); } - @Test - public void fromStringTest() { - testSingleUnit("year", 3, 36, 0); - testSingleUnit("month", 3, 3, 0); - testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK); - testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY); - testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR); - testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE); - testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND); - testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI); - testSingleUnit("microsecond", 3, 0, 3); - - CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); - Arrays.asList( - "interval -5 years 23 month", - " -5 years 23 month", - "interval -5 years 23 month ", - " -5 years 23 month ", - " interval -5 years 23 month ").forEach(input -> - assertEquals(fromString(input), result) - ); - - // Error cases - Arrays.asList( - "interval 3month 1 hour", - "3month 1 hour", - "interval 3 moth 1 hour", - "3 moth 1 hour", - "interval", - "int", - "", - null).forEach(input -> assertNull(fromString(input))); - } - - @Test - public void fromCaseInsensitiveStringTest() { - for (String input : new String[]{"5 MINUTES", "5 minutes", "5 Minutes"}) { - assertEquals(fromCaseInsensitiveString(input), new CalendarInterval(0, 5L * 60 * 1_000_000)); - } - - for (String input : new String[]{null, "", " "}) { - try { - fromCaseInsensitiveString(input); - fail("Expected to throw an exception for the invalid input"); - } catch (IllegalArgumentException e) { - String msg = e.getMessage(); - if (input == null) assertTrue(msg.contains("cannot be null")); - else assertTrue(msg.contains("cannot be blank")); - } - } - - for (String input : new String[]{"interval", "interval1 day", "foo", "foo 1 day"}) { - try { - fromCaseInsensitiveString(input); - fail("Expected to throw an exception for the invalid input"); - } catch 
(IllegalArgumentException e) { - String msg = e.getMessage(); - if (input.trim().equalsIgnoreCase("interval")) { - assertTrue(msg.contains("Interval string must have time units")); - } else { - assertTrue(msg.contains("Invalid interval:")); - } - } - } - } - @Test public void fromYearMonthStringTest() { String input; @@ -194,107 +126,25 @@ public void fromDayTimeStringTest() { } } - @Test - public void fromSingleUnitStringTest() { - String input; - CalendarInterval i; - - input = "12"; - i = new CalendarInterval(12 * 12, 0L); - assertEquals(fromSingleUnitString("year", input), i); - - input = "100"; - i = new CalendarInterval(0, 100 * MICROS_PER_DAY); - assertEquals(fromSingleUnitString("day", input), i); - - input = "1999.38888"; - i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); - assertEquals(fromSingleUnitString("second", input), i); - - try { - input = String.valueOf(Integer.MAX_VALUE); - fromSingleUnitString("year", input); - fail("Expected to throw an exception for the invalid input"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("outside range")); - } - - try { - input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); - fromSingleUnitString("hour", input); - fail("Expected to throw an exception for the invalid input"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("outside range")); - } - } - @Test public void addTest() { - String input = "interval 3 month 1 hour"; - String input2 = "interval 2 month 100 hour"; - - CalendarInterval interval = fromString(input); - CalendarInterval interval2 = fromString(input2); - - assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); + CalendarInterval input1 = new CalendarInterval(3, 1 * MICROS_PER_HOUR); + CalendarInterval input2 = new CalendarInterval(2, 100 * MICROS_PER_HOUR); + assertEquals(input1.add(input2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); - input = "interval -10 month -81 hour"; - input2 = "interval 75 month 200 hour"; - - interval = fromString(input); - interval2 = fromString(input2); - - assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); + input1 = new CalendarInterval(-10, -81 * MICROS_PER_HOUR); + input2 = new CalendarInterval(75, 200 * MICROS_PER_HOUR); + assertEquals(input1.add(input2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); } @Test public void subtractTest() { - String input = "interval 3 month 1 hour"; - String input2 = "interval 2 month 100 hour"; - - CalendarInterval interval = fromString(input); - CalendarInterval interval2 = fromString(input2); - - assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); - - input = "interval -10 month -81 hour"; - input2 = "interval 75 month 200 hour"; - - interval = fromString(input); - interval2 = fromString(input2); - - assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); - } - - private static void testSingleUnit(String unit, int number, int months, long microseconds) { - Arrays.asList("interval ", "").forEach(prefix -> { - String input1 = prefix + number + " " + unit; - String input2 = prefix + number + " " + unit + "s"; - CalendarInterval result = new CalendarInterval(months, microseconds); - assertEquals(fromString(input1), result); - assertEquals(fromString(input2), result); - }); - } - - @Test - public void fromStringCaseSensitivityTest() { - testSingleUnit("YEAR", 3, 36, 0); - testSingleUnit("Month", 3, 3, 0); - 
testSingleUnit("Week", 3, 0, 3 * MICROS_PER_WEEK); - testSingleUnit("DAY", 3, 0, 3 * MICROS_PER_DAY); - testSingleUnit("HouR", 3, 0, 3 * MICROS_PER_HOUR); - testSingleUnit("MiNuTe", 3, 0, 3 * MICROS_PER_MINUTE); - testSingleUnit("Second", 3, 0, 3 * MICROS_PER_SECOND); - testSingleUnit("MilliSecond", 3, 0, 3 * MICROS_PER_MILLI); - testSingleUnit("MicroSecond", 3, 0, 3); - - String input; - - input = "INTERVAL -5 YEARS 23 MONTHS"; - CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); - assertEquals(fromString(input), result); + CalendarInterval input1 = new CalendarInterval(3, 1 * MICROS_PER_HOUR); + CalendarInterval input2 = new CalendarInterval(2, 100 * MICROS_PER_HOUR); + assertEquals(input1.subtract(input2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); - assertNull(fromString("INTERVAL")); - assertNull(fromString(" Interval ")); + input1 = new CalendarInterval(-10, -81 * MICROS_PER_HOUR); + input2 = new CalendarInterval(75, 200 * MICROS_PER_HOUR); + assertEquals(input1.subtract(input2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); } } diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 01cd181010f9..82401f91e31d 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -79,6 +79,10 @@ singleTableSchema : colTypeList EOF ; +singleInterval + : INTERVAL? (intervalValue intervalUnit)+ EOF + ; + statement : query #statementDefault | ctes? dmlStatementNoWith #dmlStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d71f300dd26d..862b2bb515a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,7 +23,7 @@ import java.util.Locale import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.{InternalRow, WalkedTypePath} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} object Cast { @@ -466,7 +466,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit // IntervalConverter private[this] def castToInterval(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => CalendarInterval.fromString(s.toString)) + buildCast[UTF8String](_, s => IntervalUtils.safeFromString(s.toString)) } // LongConverter @@ -1213,8 +1213,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => + val util = 
IntervalUtils.getClass.getCanonicalName.stripSuffix("$") (c, evPrim, evNull) => - code"""$evPrim = CalendarInterval.fromString($c.toString()); + code"""$evPrim = $util.safeFromString($c.toString()); if(${evPrim} == null) { ${evNull} = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 9aae678deb4b..b9ec933f3149 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval case class TimeWindow( timeColumn: Expression, @@ -102,7 +102,7 @@ object TimeWindow { * precision. */ private def getIntervalInMicroSeconds(interval: String): Long = { - val cal = CalendarInterval.fromCaseInsensitiveString(interval) + val cal = IntervalUtils.fromString(interval) if (cal.months > 0) { throw new IllegalArgumentException( s"Intervals greater than a month is not supported ($interval).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 99e5c9feb8fa..d8e1a0cdcb10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -100,6 +101,23 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList))) } + override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val units = ctx.intervalUnit().asScala.map { + u => normalizeInternalUnit(u.getText.toLowerCase(Locale.ROOT)) + }.toArray + val values = ctx.intervalValue().asScala.map(getIntervalValue).toArray + try { + CalendarInterval.fromUnitStrings(units, values) + } catch { + case i: IllegalArgumentException => + val e = new ParseException(i.getMessage, ctx) + e.setStackTrace(i.getStackTrace) + throw e + } + } + } + /* ******************************************************************************************** * Plan parsing * ******************************************************************************************** */ @@ -1770,7 +1788,7 @@ class AstBuilder(conf: SQLConf) extends 
SqlBaseBaseVisitor[AnyRef] with Logging toLiteral(stringToTimestamp(_, zoneId), TimestampType) case "INTERVAL" => val interval = try { - CalendarInterval.fromCaseInsensitiveString(value) + IntervalUtils.fromString(value) } catch { case e: IllegalArgumentException => val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx) @@ -1930,15 +1948,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { import ctx._ - val s = value.getText + val s = getIntervalValue(value) try { val unitText = unit.getText.toLowerCase(Locale.ROOT) val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match { - case (u, None) if u.endsWith("s") => - // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... - CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) case (u, None) => - CalendarInterval.fromSingleUnitString(u, s) + CalendarInterval.fromUnitStrings(Array(normalizeInternalUnit(u)), Array(s)) case ("year", Some("month")) => CalendarInterval.fromYearMonthString(s) case ("day", Some("hour")) => @@ -1967,6 +1982,19 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + private def getIntervalValue(value: IntervalValueContext): String = { + if (value.STRING() != null) { + string(value.STRING()) + } else { + value.getText + } + } + + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... + private def normalizeInternalUnit(s: String): String = { + if (s.endsWith("s")) s.substring(0, s.length - 1) else s + } + /* ******************************************************************************************** * DataType parsing * ******************************************************************************************** */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index a84d29b71ac4..b66cae797941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -29,12 +29,21 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.unsafe.types.CalendarInterval /** * Base SQL parsing infrastructure. */ abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging { + /** + * Creates [[CalendarInterval]] for a given SQL String. Throws [[ParseException]] if the SQL + * string is not a valid interval format. + */ + def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser => + astBuilder.visitSingleInterval(parser.singleInterval()) + } + /** Creates/Resolves DataType for a given SQL string. 
*/ override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 78d188f81f62..14fd153e15f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -17,21 +17,24 @@ package org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.CalendarInterval object IntervalUtils { - val MONTHS_PER_YEAR: Int = 12 - val MONTHS_PER_QUARTER: Byte = 3 - val YEARS_PER_MILLENNIUM: Int = 1000 - val YEARS_PER_CENTURY: Int = 100 - val YEARS_PER_DECADE: Int = 10 - val MICROS_PER_HOUR: Long = DateTimeUtils.MILLIS_PER_HOUR * DateTimeUtils.MICROS_PER_MILLIS - val MICROS_PER_MINUTE: Long = DateTimeUtils.MILLIS_PER_MINUTE * DateTimeUtils.MICROS_PER_MILLIS - val DAYS_PER_MONTH: Byte = 30 - val MICROS_PER_MONTH: Long = DAYS_PER_MONTH * DateTimeUtils.SECONDS_PER_DAY + final val MONTHS_PER_YEAR: Int = 12 + final val MONTHS_PER_QUARTER: Byte = 3 + final val YEARS_PER_MILLENNIUM: Int = 1000 + final val YEARS_PER_CENTURY: Int = 100 + final val YEARS_PER_DECADE: Int = 10 + final val MICROS_PER_HOUR: Long = + DateTimeUtils.MILLIS_PER_HOUR * DateTimeUtils.MICROS_PER_MILLIS + final val MICROS_PER_MINUTE: Long = + DateTimeUtils.MILLIS_PER_MINUTE * DateTimeUtils.MICROS_PER_MILLIS + final val DAYS_PER_MONTH: Byte = 30 + final val MICROS_PER_MONTH: Long = DAYS_PER_MONTH * DateTimeUtils.SECONDS_PER_DAY /* 365.25 days per year assumes leap year every four years */ - val MICROS_PER_YEAR: Long = (36525L * DateTimeUtils.MICROS_PER_DAY) / 100 + final val MICROS_PER_YEAR: Long = (36525L * DateTimeUtils.MICROS_PER_DAY) / 100 def getYears(interval: CalendarInterval): Int = { interval.months / MONTHS_PER_YEAR @@ -88,4 +91,32 @@ object IntervalUtils { result += MICROS_PER_MONTH * (interval.months % MONTHS_PER_YEAR) Decimal(result, 18, 6) } + + /** + * Converts a string to [[CalendarInterval]] case-insensitively. + * + * @throws IllegalArgumentException if the input string is not in valid interval format. + */ + def fromString(str: String): CalendarInterval = { + if (str == null) throw new IllegalArgumentException("Interval string cannot be null") + try { + CatalystSqlParser.parseInterval(str) + } catch { + case e: ParseException => + val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message) + ex.setStackTrace(e.getStackTrace) + throw ex + } + } + + /** + * A safe version of `fromString`. It returns null for invalid input string. 
+ */ + def safeFromString(str: String): CalendarInterval = { + try { + fromString(str) + } catch { + case _: IllegalArgumentException => null + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 603073b40d7a..e10aa60d52cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH @@ -720,7 +720,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-01 00:00:00")), Literal(Timestamp.valueOf("2018-01-02 00:00:00")), - Literal(CalendarInterval.fromString("interval 12 hours"))), + Literal(IntervalUtils.fromString("interval 12 hours"))), Seq( Timestamp.valueOf("2018-01-01 00:00:00"), Timestamp.valueOf("2018-01-01 12:00:00"), @@ -729,7 +729,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-01 00:00:00")), Literal(Timestamp.valueOf("2018-01-02 00:00:01")), - Literal(CalendarInterval.fromString("interval 12 hours"))), + Literal(IntervalUtils.fromString("interval 12 hours"))), Seq( Timestamp.valueOf("2018-01-01 00:00:00"), Timestamp.valueOf("2018-01-01 12:00:00"), @@ -738,7 +738,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-02 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Literal(IntervalUtils.fromString("interval 12 hours").negate())), Seq( Timestamp.valueOf("2018-01-02 00:00:00"), Timestamp.valueOf("2018-01-01 12:00:00"), @@ -747,7 +747,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-02 00:00:00")), Literal(Timestamp.valueOf("2017-12-31 23:59:59")), - Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Literal(IntervalUtils.fromString("interval 12 hours").negate())), Seq( Timestamp.valueOf("2018-01-02 00:00:00"), Timestamp.valueOf("2018-01-01 12:00:00"), @@ -756,7 +756,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-01 00:00:00")), Literal(Timestamp.valueOf("2018-03-01 00:00:00")), - Literal(CalendarInterval.fromString("interval 1 month"))), + Literal(IntervalUtils.fromString("interval 1 month"))), Seq( Timestamp.valueOf("2018-01-01 00:00:00"), Timestamp.valueOf("2018-02-01 00:00:00"), @@ -765,7 +765,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-03-01 
00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(CalendarInterval.fromString("interval 1 month").negate())), + Literal(IntervalUtils.fromString("interval 1 month").negate())), Seq( Timestamp.valueOf("2018-03-01 00:00:00"), Timestamp.valueOf("2018-02-01 00:00:00"), @@ -774,7 +774,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-03-03 00:00:00")), Literal(Timestamp.valueOf("2018-01-01 00:00:00")), - Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())), + Literal(IntervalUtils.fromString("interval 1 month 1 day").negate())), Seq( Timestamp.valueOf("2018-03-03 00:00:00"), Timestamp.valueOf("2018-02-02 00:00:00"), @@ -783,7 +783,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-31 00:00:00")), Literal(Timestamp.valueOf("2018-04-30 00:00:00")), - Literal(CalendarInterval.fromString("interval 1 month"))), + Literal(IntervalUtils.fromString("interval 1 month"))), Seq( Timestamp.valueOf("2018-01-31 00:00:00"), Timestamp.valueOf("2018-02-28 00:00:00"), @@ -793,7 +793,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-01 00:00:00")), Literal(Timestamp.valueOf("2018-03-01 00:00:00")), - Literal(CalendarInterval.fromString("interval 1 month 1 second"))), + Literal(IntervalUtils.fromString("interval 1 month 1 second"))), Seq( Timestamp.valueOf("2018-01-01 00:00:00"), Timestamp.valueOf("2018-02-01 00:00:01"))) @@ -801,7 +801,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-01-01 00:00:00")), Literal(Timestamp.valueOf("2018-03-01 00:04:06")), - Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))), + Literal(IntervalUtils.fromString("interval 1 month 2 minutes 3 seconds"))), Seq( Timestamp.valueOf("2018-01-01 00:00:00"), Timestamp.valueOf("2018-02-01 00:02:03"), @@ -839,7 +839,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-03-25 01:30:00")), Literal(Timestamp.valueOf("2018-03-25 03:30:00")), - Literal(CalendarInterval.fromString("interval 30 minutes"))), + Literal(IntervalUtils.fromString("interval 30 minutes"))), Seq( Timestamp.valueOf("2018-03-25 01:30:00"), Timestamp.valueOf("2018-03-25 03:00:00"), @@ -849,7 +849,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Timestamp.valueOf("2018-10-28 01:30:00")), Literal(Timestamp.valueOf("2018-10-28 03:30:00")), - Literal(CalendarInterval.fromString("interval 30 minutes"))), + Literal(IntervalUtils.fromString("interval 30 minutes"))), Seq( Timestamp.valueOf("2018-10-28 01:30:00"), noDST(Timestamp.valueOf("2018-10-28 02:00:00")), @@ -866,7 +866,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Date.valueOf("2018-01-01")), Literal(Date.valueOf("2018-01-05")), - Literal(CalendarInterval.fromString("interval 2 days"))), + Literal(IntervalUtils.fromString("interval 2 days"))), Seq( Date.valueOf("2018-01-01"), Date.valueOf("2018-01-03"), @@ -875,7 +875,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new 
Sequence( Literal(Date.valueOf("2018-01-01")), Literal(Date.valueOf("2018-03-01")), - Literal(CalendarInterval.fromString("interval 1 month"))), + Literal(IntervalUtils.fromString("interval 1 month"))), Seq( Date.valueOf("2018-01-01"), Date.valueOf("2018-02-01"), @@ -884,7 +884,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new Sequence( Literal(Date.valueOf("2018-01-31")), Literal(Date.valueOf("2018-04-30")), - Literal(CalendarInterval.fromString("interval 1 month"))), + Literal(IntervalUtils.fromString("interval 1 month"))), Seq( Date.valueOf("2018-01-31"), Date.valueOf("2018-02-28"), @@ -905,14 +905,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper new Sequence( Literal(Date.valueOf("1970-01-02")), Literal(Date.valueOf("1970-01-01")), - Literal(CalendarInterval.fromString("interval 1 day"))), + Literal(IntervalUtils.fromString("interval 1 day"))), EmptyRow, "sequence boundaries: 1 to 0 by 1") checkExceptionInExpression[IllegalArgumentException]( new Sequence( Literal(Date.valueOf("1970-01-01")), Literal(Date.valueOf("1970-02-01")), - Literal(CalendarInterval.fromString("interval 1 month").negate())), + Literal(IntervalUtils.fromString("interval 1 month").negate())), EmptyRow, s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index e893e863b367..6abadd77bd41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT import org.apache.spark.sql.internal.SQLConf @@ -1075,16 +1075,16 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(SubtractTimestamps(Literal(end), Literal(end)), new CalendarInterval(0, 0)) checkEvaluation(SubtractTimestamps(Literal(end), Literal(Instant.EPOCH)), - CalendarInterval.fromString("interval 18173 days " + + IntervalUtils.fromString("interval 18173 days " + "11 hours 4 minutes 1 seconds 123 milliseconds 456 microseconds")) checkEvaluation(SubtractTimestamps(Literal(Instant.EPOCH), Literal(end)), - CalendarInterval.fromString("interval -18173 days " + + IntervalUtils.fromString("interval -18173 days " + "-11 hours -4 minutes -1 seconds -123 milliseconds -456 microseconds")) checkEvaluation( SubtractTimestamps( Literal(Instant.parse("9999-12-31T23:59:59.999999Z")), Literal(Instant.parse("0001-01-01T00:00:00Z"))), - CalendarInterval.fromString("interval 521722 weeks 4 days " + + IntervalUtils.fromString("interval 521722 weeks 4 days " + "23 hours 59 minutes 59 seconds 999 milliseconds 999 microseconds")) } @@ -1093,18 +1093,18 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { 
checkEvaluation(SubtractDates(Literal(end), Literal(end)), new CalendarInterval(0, 0)) checkEvaluation(SubtractDates(Literal(end.plusDays(1)), Literal(end)), - CalendarInterval.fromString("interval 1 days")) + IntervalUtils.fromString("interval 1 days")) checkEvaluation(SubtractDates(Literal(end.minusDays(1)), Literal(end)), - CalendarInterval.fromString("interval -1 days")) + IntervalUtils.fromString("interval -1 days")) val epochDate = Literal(LocalDate.ofEpochDay(0)) checkEvaluation(SubtractDates(Literal(end), epochDate), - CalendarInterval.fromString("interval 49 years 9 months 4 days")) + IntervalUtils.fromString("interval 49 years 9 months 4 days")) checkEvaluation(SubtractDates(epochDate, Literal(end)), - CalendarInterval.fromString("interval -49 years -9 months -4 days")) + IntervalUtils.fromString("interval -49 years -9 months -4 days")) checkEvaluation( SubtractDates( Literal(LocalDate.of(10000, 1, 1)), Literal(LocalDate.of(1, 1, 1))), - CalendarInterval.fromString("interval 9999 years")) + IntervalUtils.fromString("interval 9999 years")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index f90c98be0b3f..4b2da73abe56 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.types.{ArrayType, StructType, _} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.UTF8String class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val random = new scala.util.Random @@ -252,7 +252,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("hive-hash for CalendarInterval type") { def checkHiveHashForIntervalType(interval: String, expected: Long): Unit = { - checkHiveHash(CalendarInterval.fromString(interval), CalendarIntervalType, expected) + checkHiveHash(IntervalUtils.fromString(interval), CalendarIntervalType, expected) } // ----- MICROSEC ----- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 078ec8880021..818ee239dbbf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import scala.language.implicitConversions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.types.Decimal -import org.apache.spark.unsafe.types.CalendarInterval class IntervalExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { implicit def interval(s: String): Literal = { - Literal(CalendarInterval.fromString("interval " + s)) + Literal(IntervalUtils.fromString("interval " + s)) } test("millenniums") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 0d594eb10962..23ba9c6ec738 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -56,7 +56,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { testBothCodegenAndInterpreted("variable-length types") { val proj = createMutableProjection(variableLengthTypes) - val scalaValues = Seq("abc", BigDecimal(10), CalendarInterval.fromString("interval 1 day"), + val scalaValues = Seq("abc", BigDecimal(10), IntervalUtils.fromString("interval 1 day"), Array[Byte](1, 2), Array("123", "456"), Map(1 -> "a", 2 -> "b"), Row(1, "a"), new java.lang.Integer(5)) val inputRow = InternalRow.fromSeq(scalaValues.zip(variableLengthTypes).map { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index a171885471a3..4ccd4f7ce798 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -35,8 +35,7 @@ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -486,7 +485,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ("abcd".getBytes, BinaryType), ("abcd", StringType), (BigDecimal.valueOf(10), DecimalType.IntDecimal), - (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (IntervalUtils.fromString("interval 3 day"), CalendarIntervalType), (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), (Array(3, 2, 1), ArrayType(IntegerType)) ).foreach { case (input, dt) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 323a3a901689..20e77254ecda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -531,7 +531,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB // Simple tests val inputRow = InternalRow.fromSeq(Seq( false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"), - Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2) + Decimal(255), IntervalUtils.fromString("interval 1 day"), Array[Byte](1, 2) )) val fields1 = Array( BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index e6eabcc1f302..86b3aa8190b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -432,7 +432,7 @@ class ExpressionParserSuite extends AnalysisTest { intercept("timestamP '2016-33-11 20:54:00.000'", "Cannot parse the TIMESTAMP value") // Interval. - val intervalLiteral = Literal(CalendarInterval.fromString("interval 3 month 1 hour")) + val intervalLiteral = Literal(IntervalUtils.fromString("interval 3 month 1 hour")) assertEqual("InterVal 'interval 3 month 1 hour'", intervalLiteral) assertEqual("INTERVAL '3 month 1 hour'", intervalLiteral) intercept("Interval 'interval 3 monthsss 1 hoursss'", "Cannot parse the INTERVAL value") @@ -597,7 +597,7 @@ class ExpressionParserSuite extends AnalysisTest { "microsecond") def intervalLiteral(u: String, s: String): Literal = { - Literal(CalendarInterval.fromSingleUnitString(u, s)) + Literal(CalendarInterval.fromUnitStrings(Array(u), Array(s))) } test("intervals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala new file mode 100644 index 000000000000..e48779af3c9a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.CalendarInterval._ + +class IntervalUtilsSuite extends SparkFunSuite { + + test("fromString: basic") { + testSingleUnit("YEAR", 3, 36, 0) + testSingleUnit("Month", 3, 3, 0) + testSingleUnit("Week", 3, 0, 3 * MICROS_PER_WEEK) + testSingleUnit("DAY", 3, 0, 3 * MICROS_PER_DAY) + testSingleUnit("HouR", 3, 0, 3 * MICROS_PER_HOUR) + testSingleUnit("MiNuTe", 3, 0, 3 * MICROS_PER_MINUTE) + testSingleUnit("Second", 3, 0, 3 * MICROS_PER_SECOND) + testSingleUnit("MilliSecond", 3, 0, 3 * MICROS_PER_MILLI) + testSingleUnit("MicroSecond", 3, 0, 3) + + for (input <- Seq(null, "", " ")) { + try { + fromString(input) + fail("Expected to throw an exception for the invalid input") + } catch { + case e: IllegalArgumentException => + val msg = e.getMessage + if (input == null) { + assert(msg.contains("cannot be null")) + } + } + } + + for (input <- Seq("interval", "interval1 day", "foo", "foo 1 day")) { + try { + fromString(input) + fail("Expected to throw an exception for the invalid input") + } catch { + case e: IllegalArgumentException => + val msg = e.getMessage + assert(msg.contains("Invalid interval string")) + } + } + } + + test("fromString: random order field") { + val input = "1 day 1 year" + val result = new CalendarInterval(12, MICROS_PER_DAY) + assert(fromString(input) == result) + } + + test("fromString: duplicated fields") { + val input = "1 day 1 day" + val result = new CalendarInterval(0, 2 * MICROS_PER_DAY) + assert(fromString(input) == result) + } + + test("fromString: value with +/-") { + val input = "+1 year -1 day" + val result = new CalendarInterval(12, -MICROS_PER_DAY) + assert(fromString(input) == result) + } + + private def testSingleUnit(unit: String, number: Int, months: Int, microseconds: Long): Unit = { + for (prefix <- Seq("interval ", "")) { + val input1 = prefix + number + " " + unit + val input2 = prefix + number + " " + unit + "s" + val result = new CalendarInterval(months, microseconds) + assert(fromString(input1) == result) + assert(fromString(input2) == result) + } + } +} diff --git a/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt b/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt index 6605bd2c13b4..221ac42022a1 100644 --- a/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt @@ -2,24 +2,24 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast strings to intervals: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -string w/ interval 663 758 99 1.5 663.2 1.0X -string w/o interval 563 582 19 1.8 563.2 1.2X -1 units w/ interval 891 953 97 1.1 891.2 0.7X -1 units w/o interval 894 905 15 1.1 893.6 0.7X -2 units w/ interval 
1142 1169 24 0.9 1141.8 0.6X -2 units w/o interval 1195 1200 7 0.8 1194.7 0.6X -3 units w/ interval 1449 1452 3 0.7 1449.5 0.5X -3 units w/o interval 1489 1491 1 0.7 1489.3 0.4X -4 units w/ interval 1688 1690 1 0.6 1688.4 0.4X -4 units w/o interval 1711 1722 11 0.6 1710.9 0.4X -5 units w/ interval 1961 1983 23 0.5 1961.3 0.3X -5 units w/o interval 1983 1991 10 0.5 1983.4 0.3X -6 units w/ interval 2217 2228 11 0.5 2216.9 0.3X -6 units w/o interval 2240 2244 5 0.4 2239.6 0.3X -7 units w/ interval 2454 2468 16 0.4 2454.1 0.3X -7 units w/o interval 2480 2491 15 0.4 2479.5 0.3X -8 units w/ interval 2762 2792 26 0.4 2761.9 0.2X -8 units w/o interval 2763 2778 14 0.4 2762.9 0.2X -9 units w/ interval 3036 3060 21 0.3 3036.4 0.2X -9 units w/o interval 3095 3111 15 0.3 3094.8 0.2X +prepare string w/ interval 672 728 64 1.5 672.1 1.0X +prepare string w/o interval 580 602 19 1.7 580.4 1.2X +1 units w/ interval 9450 9575 138 0.1 9449.6 0.1X +1 units w/o interval 8948 8968 19 0.1 8948.3 0.1X +2 units w/ interval 10947 10966 19 0.1 10947.1 0.1X +2 units w/o interval 10470 10489 26 0.1 10469.5 0.1X +3 units w/ interval 12265 12333 72 0.1 12264.5 0.1X +3 units w/o interval 12001 12004 3 0.1 12000.6 0.1X +4 units w/ interval 13749 13828 69 0.1 13748.5 0.0X +4 units w/o interval 13467 13479 15 0.1 13467.3 0.0X +5 units w/ interval 15392 15446 51 0.1 15392.1 0.0X +5 units w/o interval 15090 15107 29 0.1 15089.7 0.0X +6 units w/ interval 16696 16714 20 0.1 16695.9 0.0X +6 units w/o interval 16361 16366 5 0.1 16361.4 0.0X +7 units w/ interval 18190 18270 71 0.1 18190.2 0.0X +7 units w/o interval 17757 17767 9 0.1 17756.7 0.0X +8 units w/ interval 19821 19870 43 0.1 19820.7 0.0X +8 units w/o interval 19479 19555 97 0.1 19479.5 0.0X +9 units w/ interval 21417 21481 56 0.0 21417.1 0.0X +9 units w/o interval 21058 21131 86 0.0 21058.2 0.0X diff --git a/sql/core/benchmarks/IntervalBenchmark-results.txt b/sql/core/benchmarks/IntervalBenchmark-results.txt index 40169826cc62..60e8e5198353 100644 --- a/sql/core/benchmarks/IntervalBenchmark-results.txt +++ b/sql/core/benchmarks/IntervalBenchmark-results.txt @@ -2,24 +2,25 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15. 
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz cast strings to intervals: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -string w/ interval 600 641 37 1.7 600.3 1.0X -string w/o interval 536 544 12 1.9 536.4 1.1X -1 units w/ interval 1017 1027 9 1.0 1016.8 0.6X -1 units w/o interval 1055 1073 16 0.9 1054.7 0.6X -2 units w/ interval 1272 1292 29 0.8 1272.2 0.5X -2 units w/o interval 1309 1314 9 0.8 1309.0 0.5X -3 units w/ interval 1545 1566 20 0.6 1544.8 0.4X -3 units w/o interval 1606 1610 5 0.6 1605.8 0.4X -4 units w/ interval 1820 1826 6 0.5 1819.7 0.3X -4 units w/o interval 1882 1885 3 0.5 1881.5 0.3X -5 units w/ interval 2039 2043 7 0.5 2038.9 0.3X -5 units w/o interval 2131 2133 3 0.5 2130.6 0.3X -6 units w/ interval 2269 2272 4 0.4 2269.5 0.3X -6 units w/o interval 2327 2333 6 0.4 2327.2 0.3X -7 units w/ interval 2477 2485 10 0.4 2476.8 0.2X -7 units w/o interval 2536 2538 3 0.4 2536.0 0.2X -8 units w/ interval 2764 2781 27 0.4 2763.8 0.2X -8 units w/o interval 2843 2847 5 0.4 2842.9 0.2X -9 units w/ interval 2983 2997 12 0.3 2982.5 0.2X -9 units w/o interval 3071 3072 1 0.3 3071.1 0.2X +prepare string w/ interval 596 647 61 1.7 596.0 1.0X +prepare string w/o interval 530 554 22 1.9 530.2 1.1X +1 units w/ interval 9168 9243 66 0.1 9167.8 0.1X +1 units w/o interval 8740 8744 5 0.1 8740.2 0.1X +2 units w/ interval 10815 10874 52 0.1 10815.0 0.1X +2 units w/o interval 10413 10419 11 0.1 10412.8 0.1X +3 units w/ interval 12490 12530 37 0.1 12490.3 0.0X +3 units w/o interval 12173 12180 9 0.1 12172.8 0.0X +4 units w/ interval 13788 13834 43 0.1 13788.0 0.0X +4 units w/o interval 13445 13456 10 0.1 13445.5 0.0X +5 units w/ interval 15313 15330 15 0.1 15312.7 0.0X +5 units w/o interval 14928 14942 16 0.1 14928.0 0.0X +6 units w/ interval 16959 17003 42 0.1 16959.1 0.0X +6 units w/o interval 16623 16627 5 0.1 16623.3 0.0X +7 units w/ interval 18955 18972 21 0.1 18955.4 0.0X +7 units w/o interval 18454 18462 7 0.1 18454.1 0.0X +8 units w/ interval 20835 20843 8 0.0 20835.4 0.0X +8 units w/o interval 20446 20463 19 0.0 20445.7 0.0X +9 units w/ interval 22981 23031 43 0.0 22981.4 0.0X +9 units w/o interval 22581 22603 25 0.0 22581.1 0.0X + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 076270a9f1c6..5f6e0a82be4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ @@ -724,7 +725,7 @@ class Dataset[T] private[sql]( def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { val parsedDelay = try { - CalendarInterval.fromCaseInsensitiveString(delayThreshold) + IntervalUtils.fromString(delayThreshold) } catch { case e: IllegalArgumentException => throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index dda9d41f630e..d191a79187f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -21,6 +21,7 @@ import java.sql.Date import java.util.concurrent.TimeUnit import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.execution.streaming.GroupStateImpl._ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} import org.apache.spark.unsafe.types.CalendarInterval @@ -159,7 +160,7 @@ private[sql] class GroupStateImpl[S] private( def getTimeoutTimestamp: Long = timeoutTimestamp private def parseDuration(duration: String): Long = { - val cal = CalendarInterval.fromCaseInsensitiveString(duration) + val cal = IntervalUtils.fromString(duration) if (cal.milliseconds < 0 || cal.months < 0) { throw new IllegalArgumentException(s"Provided duration ($duration) is not positive") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 2bdb3402c14b..daa70a12ba0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,8 +21,8 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.Duration +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.unsafe.types.CalendarInterval private object Triggers { def validate(intervalMs: Long): Unit = { @@ -30,7 +30,7 @@ private object Triggers { } def convert(interval: String): Long = { - val cal = CalendarInterval.fromCaseInsensitiveString(interval) + val cal = IntervalUtils.fromString(interval) if (cal.months > 0) { throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 816386c48320..0f95f8523782 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -85,6 +85,19 @@ select timestamp '2016-33-11 20:54:00.000'; -- interval select interval 13.123456789 seconds, interval -13.123456789 second; select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond; +select interval '30' year '25' month '-100' day '40' hour '80' minute '299.889987299' second; +select interval '0 0:0:0.1' day to second; +select interval '10-9' year to month; +select interval '20 15:40:32.99899999' day to hour; +select interval '20 15:40:32.99899999' day to minute; +select interval '20 15:40:32.99899999' day to second; +select interval '15:40:32.99899999' hour to minute; +select interval '15:40.99899999' hour to second; +select interval '15:40' hour to second; +select interval '15:40:32.99899999' hour to second; +select interval '20 40:32.99899999' minute to second; +select interval '40:32.99899999' minute to second; +select interval '40:32' minute to second; -- ns is not supported select interval 10 nanoseconds; diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out 
b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index aef23963da37..fd6e51b2385d 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 46 +-- Number of queries: 59 -- !query 0 @@ -337,10 +337,114 @@ interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseco -- !query 36 -select interval 10 nanoseconds +select interval '30' year '25' month '-100' day '40' hour '80' minute '299.889987299' second -- !query 36 schema -struct<> +struct -- !query 36 output +interval 32 years 1 months -14 weeks -6 hours -35 minutes -110 milliseconds -13 microseconds + + +-- !query 37 +select interval '0 0:0:0.1' day to second +-- !query 37 schema +struct +-- !query 37 output +interval 100 milliseconds + + +-- !query 38 +select interval '10-9' year to month +-- !query 38 schema +struct +-- !query 38 output +interval 10 years 9 months + + +-- !query 39 +select interval '20 15:40:32.99899999' day to hour +-- !query 39 schema +struct +-- !query 39 output +interval 2 weeks 6 days 15 hours + + +-- !query 40 +select interval '20 15:40:32.99899999' day to minute +-- !query 40 schema +struct +-- !query 40 output +interval 2 weeks 6 days 15 hours 40 minutes + + +-- !query 41 +select interval '20 15:40:32.99899999' day to second +-- !query 41 schema +struct +-- !query 41 output +interval 2 weeks 6 days 15 hours 40 minutes 32 seconds 998 milliseconds 999 microseconds + + +-- !query 42 +select interval '15:40:32.99899999' hour to minute +-- !query 42 schema +struct +-- !query 42 output +interval 15 hours 40 minutes + + +-- !query 43 +select interval '15:40.99899999' hour to second +-- !query 43 schema +struct +-- !query 43 output +interval 15 minutes 40 seconds 998 milliseconds 999 microseconds + + +-- !query 44 +select interval '15:40' hour to second +-- !query 44 schema +struct +-- !query 44 output +interval 15 hours 40 minutes + + +-- !query 45 +select interval '15:40:32.99899999' hour to second +-- !query 45 schema +struct +-- !query 45 output +interval 15 hours 40 minutes 32 seconds 998 milliseconds 999 microseconds + + +-- !query 46 +select interval '20 40:32.99899999' minute to second +-- !query 46 schema +struct +-- !query 46 output +interval 2 weeks 6 days 40 minutes 32 seconds 998 milliseconds 999 microseconds + + +-- !query 47 +select interval '40:32.99899999' minute to second +-- !query 47 schema +struct +-- !query 47 output +interval 40 minutes 32 seconds 998 milliseconds 999 microseconds + + +-- !query 48 +select interval '40:32' minute to second +-- !query 48 schema +struct +-- !query 48 output +interval 40 minutes 32 seconds + + +-- !query 49 +select interval 10 nanoseconds +-- !query 49 schema +struct<> +-- !query 49 output org.apache.spark.sql.catalyst.parser.ParseException no viable alternative at input 'interval 10 nanoseconds'(line 1, pos 19) @@ -350,11 +454,11 @@ select interval 10 nanoseconds -------------------^^^ --- !query 37 +-- !query 50 select GEO '(10,-6)' --- !query 37 schema +-- !query 50 schema struct<> --- !query 37 output +-- !query 50 output org.apache.spark.sql.catalyst.parser.ParseException Literals of type 'GEO' are currently not supported.(line 1, pos 7) @@ -364,19 +468,19 @@ select GEO '(10,-6)' -------^^^ --- !query 38 +-- !query 51 select 90912830918230182310293801923652346786BD, 123.0E-28BD, 123.08BD --- !query 38 schema +-- !query 51 schema 
struct<90912830918230182310293801923652346786:decimal(38,0),1.230E-26:decimal(29,29),123.08:decimal(5,2)> --- !query 38 output +-- !query 51 output 90912830918230182310293801923652346786 0.0000000000000000000000000123 123.08 --- !query 39 +-- !query 52 select 1.20E-38BD --- !query 39 schema +-- !query 52 schema struct<> --- !query 39 output +-- !query 52 output org.apache.spark.sql.catalyst.parser.ParseException decimal can only support precision up to 38(line 1, pos 7) @@ -386,19 +490,19 @@ select 1.20E-38BD -------^^^ --- !query 40 +-- !query 53 select x'2379ACFe' --- !query 40 schema +-- !query 53 schema struct --- !query 40 output +-- !query 53 output #y�� --- !query 41 +-- !query 54 select X'XuZ' --- !query 41 schema +-- !query 54 schema struct<> --- !query 41 output +-- !query 54 output org.apache.spark.sql.catalyst.parser.ParseException contains illegal character for hexBinary: 0XuZ(line 1, pos 7) @@ -408,33 +512,33 @@ select X'XuZ' -------^^^ --- !query 42 +-- !query 55 SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 --- !query 42 schema +-- !query 55 schema struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> --- !query 42 output +-- !query 55 output 3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 --- !query 43 +-- !query 56 select map(1, interval 1 day, 2, interval 3 week) --- !query 43 schema +-- !query 56 schema struct> --- !query 43 output +-- !query 56 output {1:interval 1 days,2:interval 3 weeks} --- !query 44 +-- !query 57 select interval 'interval 3 year 1 hour' --- !query 44 schema +-- !query 57 schema struct --- !query 44 output +-- !query 57 output interval 3 years 1 hours --- !query 45 +-- !query 58 select interval '3 year 1 hour' --- !query 45 schema +-- !query 58 schema struct --- !query 45 output +-- !query 58 output interval 3 years 1 hours diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala index 4c1c75b815a0..d75cb1040f31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala @@ -88,8 +88,9 @@ object IntervalBenchmark extends SqlBasedBenchmark { val intervalToTest = ListBuffer[String]() val benchmark = new Benchmark("cast strings to intervals", N, output = output) - addCase(benchmark, N, "string w/ interval", buildString(true, timeUnits)) - addCase(benchmark, N, "string w/o interval", buildString(false, timeUnits)) + // The first 2 cases are used to show the overhead of preparing the interval string. 
+ addCase(benchmark, N, "prepare string w/ interval", buildString(true, timeUnits)) + addCase(benchmark, N, "prepare string w/o interval", buildString(false, timeUnits)) addCase(benchmark, N, intervalToTest) // Only years for (unit <- timeUnits) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 116a62b2f6dc..3e5834f33ea5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -42,7 +42,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -1183,51 +1182,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } - test("Convert hive interval term into Literal of CalendarIntervalType") { - checkAnswer(sql("select interval '0 0:0:0.1' day to second"), - Row(CalendarInterval.fromString("interval 100 milliseconds"))) - checkAnswer(sql("select interval '10-9' year to month"), - Row(CalendarInterval.fromString("interval 10 years 9 months"))) - checkAnswer(sql("select interval '20 15:40:32.99899999' day to hour"), - Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours"))) - checkAnswer(sql("select interval '20 15:40:32.99899999' day to minute"), - Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes"))) - checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), - Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + - "32 seconds 998 milliseconds 999 microseconds"))) - checkAnswer(sql("select interval '15:40:32.99899999' hour to minute"), - Row(CalendarInterval.fromString("interval 15 hours 40 minutes"))) - checkAnswer(sql("select interval '15:40.99899999' hour to second"), - Row(CalendarInterval.fromString("interval 15 minutes 40 seconds 998 milliseconds " + - "999 microseconds"))) - checkAnswer(sql("select interval '15:40' hour to second"), - Row(CalendarInterval.fromString("interval 15 hours 40 minutes"))) - checkAnswer(sql("select interval '15:40:32.99899999' hour to second"), - Row(CalendarInterval.fromString("interval 15 hours 40 minutes 32 seconds 998 milliseconds " + - "999 microseconds"))) - checkAnswer(sql("select interval '20 40:32.99899999' minute to second"), - Row(CalendarInterval.fromString("interval 2 weeks 6 days 40 minutes 32 seconds " + - "998 milliseconds 999 microseconds"))) - checkAnswer(sql("select interval '40:32.99899999' minute to second"), - Row(CalendarInterval.fromString("interval 40 minutes 32 seconds 998 milliseconds " + - "999 microseconds"))) - checkAnswer(sql("select interval '40:32' minute to second"), - Row(CalendarInterval.fromString("interval 40 minutes 32 seconds"))) - checkAnswer(sql("select interval '30' year"), - Row(CalendarInterval.fromString("interval 30 years"))) - checkAnswer(sql("select interval '25' month"), - Row(CalendarInterval.fromString("interval 25 months"))) - checkAnswer(sql("select interval '-100' day"), - Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) - checkAnswer(sql("select interval '40' hour"), - 
Row(CalendarInterval.fromString("interval 1 days 16 hours"))) - checkAnswer(sql("select interval '80' minute"), - Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) - checkAnswer(sql("select interval '299.889987299' second"), - Row(CalendarInterval.fromString( - "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) - } - test("specifying database name for a temporary view is not allowed") { withTempPath { dir => withTempView("db.t") { From dcf5eaf1a6c0330a9460e168c1c3fee21998ba65 Mon Sep 17 00:00:00 2001 From: stczwd Date: Thu, 24 Oct 2019 10:25:04 -0700 Subject: [PATCH 38/58] [SPARK-29444][FOLLOWUP] add doc and python parameter for ignoreNullFields in json generating # What changes were proposed in this pull request? Add description for ignoreNullFields, which is commited in #26098 , in DataFrameWriter and readwriter.py. Enable user to use ignoreNullFields in pyspark. ### Does this PR introduce any user-facing change? No ### How was this patch tested? run unit tests Closes #26227 from stczwd/json-generator-doc. Authored-by: stczwd Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/readwriter.py | 6 ++++-- .../apache/spark/sql/catalyst/json/JSONOptions.scala | 4 ++-- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 ++++++---- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 2 ++ 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f92face2d057..18fd7de7ee54 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -788,7 +788,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - lineSep=None, encoding=None): + lineSep=None, encoding=None, ignoreNullFields=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -817,13 +817,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm the default UTF-8 charset will be used. :param lineSep: defines the line separator that should be used for writing. If None is set, it uses the default value, ``\\n``. + :param ignoreNullFields: Whether to ignore null fields when generating JSON objects. + If None is set, it uses the default value, ``true``. 
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - lineSep=lineSep, encoding=encoding) + lineSep=lineSep, encoding=encoding, ignoreNullFields=ignoreNullFields) self._jwrite.json(path) @since(1.4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index e7bfb77e46c2..4952540f1132 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -78,8 +78,8 @@ private[sql] class JSONOptions( val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) // Whether to ignore null fields during json generating - val ignoreNullFields = parameters.getOrElse("ignoreNullFields", - SQLConf.get.jsonGeneratorIgnoreNullFields).toBoolean + val ignoreNullFields = parameters.get("ignoreNullFields").map(_.toBoolean) + .getOrElse(SQLConf.get.jsonGeneratorIgnoreNullFields) // A language tag in IETF BCP 47 format val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7f75bf84d65a..a228d9f064a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1196,9 +1196,11 @@ object SQLConf { val JSON_GENERATOR_IGNORE_NULL_FIELDS = buildConf("spark.sql.jsonGenerator.ignoreNullFields") - .doc("If false, JacksonGenerator will generate null for null fields in Struct.") - .stringConf - .createWithDefault("true") + .doc("Whether to ignore null fields when generating JSON objects in JSON data source and " + + "JSON functions such as to_json. " + + "If false, it generates null for null fields in JSON objects.") + .booleanConf + .createWithDefault(true) val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion") .internal() @@ -2392,7 +2394,7 @@ class SQLConf extends Serializable with Logging { def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) - def jsonGeneratorIgnoreNullFields: String = getConf(SQLConf.JSON_GENERATOR_IGNORE_NULL_FIELDS) + def jsonGeneratorIgnoreNullFields: Boolean = getConf(SQLConf.JSON_GENERATOR_IGNORE_NULL_FIELDS) def parallelFileListingInStatsComputation: Boolean = getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 4f88cc6daa33..68127c27a8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -687,6 +687,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
 * <li>`encoding` (by default it is not set): specifies encoding (charset) of saved json
 * files. If it is not set, the UTF-8 charset will be used.</li>
 * <li>`lineSep` (default `\n`): defines the line separator that should be used for writing.</li>
+ * <li>`ignoreNullFields` (default `true`): Whether to ignore null fields
+ * when generating JSON objects.</li>
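As a usage illustration of the option documented above, here is a minimal Scala sketch. The `spark` session, the sample data, and the output path are assumptions for the example, not part of this patch; only the `ignoreNullFields` option key comes from the change itself.

```scala
// Minimal sketch: keep null fields in the generated JSON instead of dropping them.
// The session, data, and path below are assumed; only "ignoreNullFields" is from the patch.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("json-null-fields").getOrCreate()
import spark.implicits._

val df = Seq((1, Option("a")), (2, Option.empty[String])).toDF("id", "name")
df.write
  .mode("overwrite")
  .option("ignoreNullFields", "false") // default is true: null fields are omitted
  .json("/tmp/json-with-nulls")        // hypothetical output path
```

With `false`, the second row should be written as `{"id":2,"name":null}` rather than `{"id":2}`.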
  • * * * @since 1.4.0 From 92b25295ca0dc5b80aaddb1c8f8d5ef0a250d11e Mon Sep 17 00:00:00 2001 From: fuwhu Date: Thu, 24 Oct 2019 12:35:32 -0700 Subject: [PATCH 39/58] [SPARK-21287][SQL] Remove requirement of fetch_size>=0 from JDBCOptions ### What changes were proposed in this pull request? Remove the requirement of fetch_size>=0 from JDBCOptions to allow negative fetch size. ### Why are the changes needed? Namely, to allow data fetch in stream manner (row-by-row fetch) against MySQL database. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Unit test (JDBCSuite) This closes #26230 . Closes #26244 from fuwhu/SPARK-21287-FIX. Authored-by: fuwhu Signed-off-by: Dongjoon Hyun --- .../sql/execution/datasources/jdbc/JDBCOptions.scala | 9 +-------- .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 --------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d184f3cb71b1..5d1feaed81a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -147,14 +147,7 @@ class JDBCOptions( """.stripMargin ) - val fetchSize = { - val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt - require(size >= 0, - s"Invalid value `${size.toString}` for parameter " + - s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " + - "the JDBC driver ignores the value and does the estimates.") - size - } + val fetchSize = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt // ------------------------------------------------------------ // Optional parameters only for writing diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3c8ce0a3fc3e..715534b0458d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -450,15 +450,6 @@ class JDBCSuite extends QueryTest urlWithUserAndPass, "TEST.PEOPLE", new Properties()).collect().length === 3) } - test("Basic API with illegal fetchsize") { - val properties = new Properties() - properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "-1") - val e = intercept[IllegalArgumentException] { - spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() - }.getMessage - assert(e.contains("Invalid value `-1` for parameter `fetchsize`")) - } - test("Missing partition columns") { withView("tempPeople") { val e = intercept[IllegalArgumentException] { From dec99d8ac5aeda045e611fe2f9e27facd4cecef4 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 24 Oct 2019 14:51:23 -0700 Subject: [PATCH 40/58] [SPARK-29526][SQL] UNCACHE TABLE should look up catalog/table like v2 commands ### What changes were proposed in this pull request? Add UncacheTableStatement and make UNCACHE TABLE go through the same catalog/table resolution framework of v2 commands. ### Why are the changes needed? It's important to make all the commands have the same table resolution behavior, to avoid confusing end-users. e.g. 
``` USE my_catalog DESC t // success and describe the table t from my_catalog UNCACHE TABLE t // report table not found as there is no table t in the session catalog ``` ### Does this PR introduce any user-facing change? yes. When running UNCACHE TABLE, Spark fails the command if the current catalog is set to a v2 catalog, or the table name specified a v2 catalog. ### How was this patch tested? New unit tests Closes #26237 from imback82/uncache_table. Authored-by: Terry Kim Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 7 +++++++ .../sql/catalyst/plans/logical/statements.scala | 7 +++++++ .../sql/catalyst/parser/DDLParserSuite.scala | 16 +++++++++++++--- .../analysis/ResolveSessionCatalog.scala | 6 +++++- .../spark/sql/execution/SparkSqlParser.scala | 7 ------- .../sql/connector/DataSourceV2SQLSuite.scala | 10 ++++++++++ 7 files changed, 43 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 82401f91e31d..1e89507411ad 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -211,7 +211,7 @@ statement | REFRESH (STRING | .*?) #refreshResource | CACHE LAZY? TABLE multipartIdentifier (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable - | UNCACHE TABLE (IF EXISTS)? tableIdentifier #uncacheTable + | UNCACHE TABLE (IF EXISTS)? multipartIdentifier #uncacheTable | CLEAR CACHE #clearCache | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE tableIdentifier partitionSpec? #loadData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d8e1a0cdcb10..b030227b4881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2821,6 +2821,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging CacheTableStatement(tableName, query, ctx.LAZY != null, options) } + /** + * Create an [[UncacheTableStatement]] logical plan. + */ + override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { + UncacheTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier), ctx.EXISTS != null) + } + /** * Create a [[TruncateTableStatement]] command. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 4a91ee6d52d9..ef8c92269434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -339,6 +339,13 @@ case class CacheTableStatement( isLazy: Boolean, options: Map[String, String]) extends ParsedStatement +/** + * An UNCACHE TABLE statement, as parsed from SQL + */ +case class UncacheTableStatement( + tableName: Seq[String], + ifExists: Boolean) extends ParsedStatement + /** * A TRUNCATE TABLE statement, as parsed from SQL */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 37349f7a3342..f4375956f0af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1039,13 +1039,13 @@ class DDLParserSuite extends AnalysisTest { "missing 'COLUMNS' at ''") } - test("MSCK REPAIR table") { + test("MSCK REPAIR TABLE") { comparePlans( parsePlan("MSCK REPAIR TABLE a.b.c"), RepairTableStatement(Seq("a", "b", "c"))) } - test("CACHE table") { + test("CACHE TABLE") { comparePlans( parsePlan("CACHE TABLE a.b.c"), CacheTableStatement(Seq("a", "b", "c"), None, false, Map.empty)) @@ -1062,6 +1062,16 @@ class DDLParserSuite extends AnalysisTest { "It is not allowed to add catalog/namespace prefix a.b") } + test("UNCACHE TABLE") { + comparePlans( + parsePlan("UNCACHE TABLE a.b.c"), + UncacheTableStatement(Seq("a", "b", "c"), ifExists = false)) + + comparePlans( + parsePlan("UNCACHE TABLE IF EXISTS a.b.c"), + UncacheTableStatement(Seq("a", "b", "c"), ifExists = true)) + } + test("TRUNCATE table") { comparePlans( parsePlan("TRUNCATE TABLE a.b.c"), @@ -1098,7 +1108,7 @@ class DDLParserSuite extends AnalysisTest { comparePlans(parsed5, expected5) } - test("REFRESH TABLE table") { + test("REFRESH TABLE") { comparePlans( parsePlan("REFRESH TABLE a.b.c"), RefreshTableStatement(Seq("a", "b", "c"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 65d95b600eaa..f91686cb544c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CacheTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand} +import 
org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CacheTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand, UncacheTableCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, RefreshTable} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf @@ -303,6 +303,10 @@ class ResolveSessionCatalog( val v1TableName = parseV1Table(tableName, "CACHE TABLE") CacheTableCommand(v1TableName.asTableIdentifier, plan, isLazy, options) + case UncacheTableStatement(tableName, ifExists) => + val v1TableName = parseV1Table(tableName, "UNCACHE TABLE") + UncacheTableCommand(v1TableName.asTableIdentifier, ifExists) + case TruncateTableStatement(tableName, partitionSpec) => val v1TableName = parseV1Table(tableName, "TRUNCATE TABLE") TruncateTableCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index fb13d01bd91d..aef0a2d2e595 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -165,13 +165,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { unquotedPath } - /** - * Create an [[UncacheTableCommand]] logical plan. - */ - override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { - UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier), ctx.EXISTS != null) - } - /** * Create a [[ClearCacheCommand]] logical plan. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 4d1e70f68ba0..4f2c1af8f7b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1282,6 +1282,16 @@ class DataSourceV2SQLSuite } } + test("UNCACHE TABLE") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, data string) USING foo") + + testV1Command("UNCACHE TABLE", t) + testV1Command("UNCACHE TABLE", s"IF EXISTS $t") + } + } + private def testV1Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") From 40df9d246e4049d7f858d4ff98801935fa9ee861 Mon Sep 17 00:00:00 2001 From: wenxuanguan Date: Fri, 25 Oct 2019 10:02:54 +0900 Subject: [PATCH 41/58] [SPARK-29227][SS] Track rule info in optimization phase ### What changes were proposed in this pull request? Track timing info for each rule in optimization phase using `QueryPlanningTracker` in Structured Streaming ### Why are the changes needed? In Structured Streaming we only track rule info in analysis phase, not in optimization phase. ### Does this PR introduce any user-facing change? No Closes #25914 from wenxuanguan/spark-29227. 
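A small sketch of what becomes observable after this change, modelled on the test in the diff below; `stream` is assumed to be the `StreamExecution` of a running query, obtained as in a `StreamTest`-based suite (an assumption for illustration, not new API).

```scala
// Sketch only: inspect per-phase and per-rule timings for a micro-batch.
// `stream: StreamExecution` is assumed (e.g. from a StreamTest harness).
stream.processAllAvailable()
val tracker = stream.lastExecution.tracker
// Before this patch, streaming tracked rules only for analysis; optimization is now included.
assert(tracker.phases.keys == Set("analysis", "optimization", "planning"))
tracker.rules.toSeq.sortBy { case (_, s) => -s.totalTimeNs }.take(5).foreach {
  case (rule, s) => println(s"$rule: ${s.totalTimeNs} ns in ${s.numInvocations} invocations")
}
```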
Authored-by: wenxuanguan Signed-off-by: HyukjinKwon --- .../streaming/IncrementalExecution.scala | 3 ++- .../QueryPlanningTrackerEndToEndSuite.scala | 24 +++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index af52af0d1d7e..b8e18b89b54b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -77,7 +77,8 @@ class IncrementalExecution( */ override lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) { - sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions { + sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, + tracker) transformAllExpressions { case ts @ CurrentBatchTimestamp(timestamp, _, _) => logInfo(s"Current batch timestamp = $timestamp") ts.toLiteral diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala index 76006efda992..987338cf6cbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamExecution} +import org.apache.spark.sql.streaming.StreamTest -class QueryPlanningTrackerEndToEndSuite extends SharedSparkSession { +class QueryPlanningTrackerEndToEndSuite extends StreamTest { + import testImplicits._ test("programmatic API") { val df = spark.range(1000).selectExpr("count(*)") @@ -38,4 +40,22 @@ class QueryPlanningTrackerEndToEndSuite extends SharedSparkSession { assert(tracker.rules.nonEmpty) } + test("SPARK-29227: Track rule info in optimization phase in streaming") { + val inputData = MemoryStream[Int] + val df = inputData.toDF() + + def assertStatus(stream: StreamExecution): Unit = { + stream.processAllAvailable() + val tracker = stream.lastExecution.tracker + assert(tracker.phases.keys == Set("analysis", "optimization", "planning")) + assert(tracker.rules.nonEmpty) + } + + testStream(df)( + StartStream(), + AddData(inputData, 1, 2, 3), + Execute(assertStatus), + StopStream) + } + } From 7417c3e7d5a890b93420e6b4c507e6805e633cca Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 24 Oct 2019 20:51:31 -0700 Subject: [PATCH 42/58] [SPARK-29597][DOCS] Deprecate old Java 8 versions prior to 8u92 ### What changes were proposed in this pull request? This PR aims to deprecate old Java 8 versions prior to 8u92. ### Why are the changes needed? This is a preparation to use JVM Option `ExitOnOutOfMemoryError`. - https://www.oracle.com/technetwork/java/javase/8u92-relnotes-2949471.html ### Does this PR introduce any user-facing change? Yes. It's highly recommended for users to use the latest JDK versions of Java 8/11. ### How was this patch tested? NA (This is a doc change). Closes #26249 from dongjoon-hyun/SPARK-29597. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- docs/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/index.md b/docs/index.md index edb1c421fb79..9e8af0d5f8e2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,6 +47,7 @@ locally on one machine --- all you need is to have `java` installed on your syst or the `JAVA_HOME` environment variable pointing to a Java installation. Spark runs on Java 8/11, Scala 2.12, Python 2.7+/3.4+ and R 3.1+. +Java 8 prior to version 8u92 support is deprecated as of Spark 3.0.0. Python 2 support is deprecated as of Spark 3.0.0. R prior to version 3.4 support is deprecated as of Spark 3.0.0. For the Scala API, Spark {{site.SPARK_VERSION}} From 1474ed05fb2d3e9324f17e4bf4f5702037d0be62 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 24 Oct 2019 22:18:10 -0700 Subject: [PATCH 43/58] [SPARK-29562][SQL] Speed up and slim down metric aggregation in SQL listener First, a bit of background on the code being changed. The current code tracks metric updates for each task, recording which metrics the task is monitoring and the last update value. Once a SQL execution finishes, then the metrics for all the stages are aggregated, by building a list with all (metric ID, value) pairs collected for all tasks in the stages related to the execution, then grouping by metric ID, and then calculating the values shown in the UI. That is full of inefficiencies: - in normal operation, all tasks will be tracking and updating the same metrics. So recording the metric IDs per task is wasteful. - tracking by task means we might be double-counting values if you have speculative tasks (as a comment in the code mentions). - creating a list of (metric ID, value) is extremely inefficient, because now you have a huge map in memory storing boxed versions of the metric IDs and values. - same thing for the aggregation part, where now a Seq is built with the values for each metric ID. The end result is that for large queries, this code can become both really slow, thus affecting the processing of events, and memory hungry. The updated code changes the approach to the following: - stages track metrics by their ID; this means the stage tracking code naturally groups values, making aggregation later simpler. - each metric ID being tracked uses a long array matching the number of partitions of the stage; this means that it's cheap to update the value of the metric once a task ends. - when aggregating, custom code just concatenates the arrays corresponding to the matching metric IDs; this is cheaper than the previous, boxing-heavy approach. The end result is that the listener uses about half as much memory as before for tracking metrics, since it doesn't need to track metric IDs per task. I captured heap dumps with the old and the new code during metric aggregation in the listener, for an execution with 3 stages, 100k tasks per stage, 50 metrics updated per task. The dumps contained just reachable memory - so data kept by the listener plus the variables in the aggregateMetrics() method. With the old code, the thread doing aggregation references >1G of memory - and that does not include temporary data created by the "groupBy" transformation (for which the intermediate state is not referenced in the aggregation method). The same thread with the new code references ~250M of memory. The old code uses about ~250M to track all the metric values for that execution, while the new code uses about ~130M. 
(Note the per-thread numbers include the amount used to track the metrics - so, e.g., in the old case, aggregation was referencing about ~750M of temporary data.) I'm also including a small benchmark (based on the Benchmark class) so that we can measure how much changes to this code affect performance. The benchmark contains some extra code to measure things the normal Benchmark class does not, given that the code under test does not really map that well to the expectations of that class. Running with the old code (I removed results that don't make much sense for this benchmark): ``` [info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic [info] Intel(R) Core(TM) i7-6820HQ CPU 2.70GHz [info] metrics aggregation (50 metrics, 100k tasks per stage): Best Time(ms) Avg Time(ms) [info] -------------------------------------------------------------------------------------- [info] 1 stage(s) 2113 2118 [info] 2 stage(s) 4172 4392 [info] 3 stage(s) 7755 8460 [info] [info] Stage Count Stage Proc. Time Aggreg. Time [info] 1 614 1187 [info] 2 620 2480 [info] 3 718 5069 ``` With the new code: ``` [info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic [info] Intel(R) Core(TM) i7-6820HQ CPU 2.70GHz [info] metrics aggregation (50 metrics, 100k tasks per stage): Best Time(ms) Avg Time(ms) [info] -------------------------------------------------------------------------------------- [info] 1 stage(s) 727 886 [info] 2 stage(s) 1722 1983 [info] 3 stage(s) 2752 3013 [info] [info] Stage Count Stage Proc. Time Aggreg. Time [info] 1 408 177 [info] 2 389 423 [info] 3 372 660 ``` So the new code is faster than the old when processing task events, and about an order of maginute faster when aggregating metrics. Note this still leaves room for improvement; for example, using the above measurements, 600ms is still a huge amount of time to spend in an event handler. But I'll leave further enhancements for a separate change. Tested with benchmarking code + existing unit tests. Closes #26218 from vanzin/SPARK-29562. 
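The bookkeeping strategy described above (one long array per tracked metric ID, one slot per task index, plain array concatenation at aggregation time) can be sketched independently of the listener plumbing. The class and method names here are illustrative only, not the patch's actual `LiveStageMetrics`/`SQLAppStatusListener` code.

```scala
// Illustrative sketch of the approach, not the real implementation.
import scala.collection.mutable

final class StageMetricsSketch(numTasks: Int, trackedIds: Set[Long]) {
  // One unboxed slot per task index and tracked metric; a task-end update is a single write.
  private val values = mutable.HashMap.empty[Long, Array[Long]]

  def update(metricId: Long, taskIndex: Int, value: Long): Unit =
    if (trackedIds.contains(metricId)) {
      values.getOrElseUpdate(metricId, new Array[Long](numTasks))(taskIndex) = value
    }

  def metricValues: Map[Long, Array[Long]] = values.toMap
}

// Aggregating an execution is then a per-ID concatenation of primitive arrays,
// with no boxed (metricId, value) pairs and no groupBy over all task updates.
def aggregate(stages: Seq[StageMetricsSketch]): Map[Long, Array[Long]] = {
  val merged = mutable.HashMap.empty[Long, Array[Long]]
  stages.iterator.flatMap(_.metricValues).foreach { case (id, arr) =>
    merged(id) = merged.get(id).map(_ ++ arr).getOrElse(arr)
  }
  merged.toMap
}
```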
Authored-by: Marcelo Vanzin Signed-off-by: Dongjoon Hyun --- ...ricsAggregationBenchmark-jdk11-results.txt | 12 + .../MetricsAggregationBenchmark-results.txt | 12 + .../sql/execution/metric/SQLMetrics.scala | 14 +- .../execution/ui/SQLAppStatusListener.scala | 190 ++++++++++----- .../metric/SQLMetricsTestUtils.scala | 3 +- .../ui/MetricsAggregationBenchmark.scala | 219 ++++++++++++++++++ .../ui/SQLAppStatusListenerSuite.scala | 19 +- 7 files changed, 397 insertions(+), 72 deletions(-) create mode 100644 sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt create mode 100644 sql/core/benchmarks/MetricsAggregationBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala diff --git a/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt b/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt new file mode 100644 index 000000000000..e33ed30eaa55 --- /dev/null +++ b/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt @@ -0,0 +1,12 @@ +OpenJDK 64-Bit Server VM 11.0.4+11 on Linux 4.15.0-66-generic +Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz +metrics aggregation (50 metrics, 100000 tasks per stage): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +1 stage(s) 672 841 179 0.0 671888474.0 1.0X +2 stage(s) 1700 1842 201 0.0 1699591662.0 0.4X +3 stage(s) 2601 2776 247 0.0 2601465786.0 0.3X + +Stage Count Stage Proc. Time Aggreg. Time + 1 436 164 + 2 537 354 + 3 480 602 diff --git a/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt b/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt new file mode 100644 index 000000000000..4fae928258d3 --- /dev/null +++ b/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt @@ -0,0 +1,12 @@ +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic +Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz +metrics aggregation (50 metrics, 100000 tasks per stage): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +1 stage(s) 740 883 147 0.0 740089816.0 1.0X +2 stage(s) 1661 1943 399 0.0 1660649192.0 0.4X +3 stage(s) 2711 2967 362 0.0 2711110178.0 0.3X + +Stage Count Stage Proc. Time Aggreg. Time + 1 405 179 + 2 375 414 + 3 364 644 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 19809b07508d..b7f0ab2969e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat -import java.util.Locale +import java.util.{Arrays, Locale} import scala.concurrent.duration._ @@ -150,7 +150,7 @@ object SQLMetrics { * A function that defines how we aggregate the final accumulator results among all tasks, * and represent it in string for a SQL physical operator. 
*/ - def stringValue(metricsType: String, values: Seq[Long]): String = { + def stringValue(metricsType: String, values: Array[Long]): String = { if (metricsType == SUM_METRIC) { val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) @@ -162,8 +162,9 @@ object SQLMetrics { val metric = if (validValues.isEmpty) { Seq.fill(3)(0L) } else { - val sorted = validValues.sorted - Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Arrays.sort(validValues) + Seq(validValues(0), validValues(validValues.length / 2), + validValues(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -184,8 +185,9 @@ object SQLMetrics { val metric = if (validValues.isEmpty) { Seq.fill(4)(0L) } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Arrays.sort(validValues) + Seq(validValues.sum, validValues(0), validValues(validValues.length / 2), + validValues(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2c4a7eacdf10..da526612e7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -16,10 +16,11 @@ */ package org.apache.spark.sql.execution.ui -import java.util.{Date, NoSuchElementException} +import java.util.{Arrays, Date, NoSuchElementException} import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.internal.Logging @@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric._ import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity} +import org.apache.spark.util.collection.OpenHashMap class SQLAppStatusListener( conf: SparkConf, @@ -103,8 +105,10 @@ class SQLAppStatusListener( // Record the accumulator IDs for the stages of this job, so that the code that keeps // track of the metrics knows which accumulators to look at. val accumIds = exec.metrics.map(_.accumulatorId).toSet - event.stageIds.foreach { id => - stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds, new ConcurrentHashMap())) + if (accumIds.nonEmpty) { + event.stageInfos.foreach { stage => + stageMetrics.put(stage.stageId, new LiveStageMetrics(0, stage.numTasks, accumIds)) + } } exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) @@ -118,9 +122,11 @@ class SQLAppStatusListener( } // Reset the metrics tracking object for the new attempt. 
- Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => - metrics.taskMetrics.clear() - metrics.attemptId = event.stageInfo.attemptNumber + Option(stageMetrics.get(event.stageInfo.stageId)).foreach { stage => + if (stage.attemptId != event.stageInfo.attemptNumber) { + stageMetrics.put(event.stageInfo.stageId, + new LiveStageMetrics(event.stageInfo.attemptNumber, stage.numTasks, stage.accumulatorIds)) + } } } @@ -140,7 +146,16 @@ class SQLAppStatusListener( override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) => - updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false) + updateStageMetrics(stageId, attemptId, taskId, SQLAppStatusListener.UNKNOWN_INDEX, + accumUpdates, false) + } + } + + override def onTaskStart(event: SparkListenerTaskStart): Unit = { + Option(stageMetrics.get(event.stageId)).foreach { stage => + if (stage.attemptId == event.stageAttemptId) { + stage.registerTask(event.taskInfo.taskId, event.taskInfo.index) + } } } @@ -165,7 +180,7 @@ class SQLAppStatusListener( } else { info.accumulables } - updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums, + updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, info.index, accums, info.successful) } @@ -181,17 +196,40 @@ class SQLAppStatusListener( private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = { val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap - val metrics = exec.stages.toSeq + + val taskMetrics = exec.stages.toSeq .flatMap { stageId => Option(stageMetrics.get(stageId)) } - .flatMap(_.taskMetrics.values().asScala) - .flatMap { metrics => metrics.ids.zip(metrics.values) } - - val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq) - .filter { case (id, _) => metricTypes.contains(id) } - .groupBy(_._1) - .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) + .flatMap(_.metricValues()) + + val allMetrics = new mutable.HashMap[Long, Array[Long]]() + + taskMetrics.foreach { case (id, values) => + val prev = allMetrics.getOrElse(id, null) + val updated = if (prev != null) { + prev ++ values + } else { + values } + allMetrics(id) = updated + } + + exec.driverAccumUpdates.foreach { case (id, value) => + if (metricTypes.contains(id)) { + val prev = allMetrics.getOrElse(id, null) + val updated = if (prev != null) { + val _copy = Arrays.copyOf(prev, prev.length + 1) + _copy(prev.length) = value + _copy + } else { + Array(value) + } + allMetrics(id) = updated + } + } + + val aggregatedMetrics = allMetrics.map { case (id, values) => + id -> SQLMetrics.stringValue(metricTypes(id), values) + }.toMap // Check the execution again for whether the aggregated metrics data has been calculated. 
// This can happen if the UI is requesting this data, and the onExecutionEnd handler is @@ -208,43 +246,13 @@ class SQLAppStatusListener( stageId: Int, attemptId: Int, taskId: Long, + taskIdx: Int, accumUpdates: Seq[AccumulableInfo], succeeded: Boolean): Unit = { Option(stageMetrics.get(stageId)).foreach { metrics => - if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) { - return - } - - val oldTaskMetrics = metrics.taskMetrics.get(taskId) - if (oldTaskMetrics != null && oldTaskMetrics.succeeded) { - return + if (metrics.attemptId == attemptId) { + metrics.updateTaskMetrics(taskId, taskIdx, succeeded, accumUpdates) } - - val updates = accumUpdates - .filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) } - .sortBy(_.id) - - if (updates.isEmpty) { - return - } - - val ids = new Array[Long](updates.size) - val values = new Array[Long](updates.size) - updates.zipWithIndex.foreach { case (acc, idx) => - ids(idx) = acc.id - // In a live application, accumulators have Long values, but when reading from event - // logs, they have String values. For now, assume all accumulators are Long and covert - // accordingly. - values(idx) = acc.update.get match { - case s: String => s.toLong - case l: Long => l - case o => throw new IllegalArgumentException(s"Unexpected: $o") - } - } - - // TODO: storing metrics by task ID can cause metrics for the same task index to be - // counted multiple times, for example due to speculation or re-attempts. - metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded)) } } @@ -425,12 +433,76 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity { } private class LiveStageMetrics( - val stageId: Int, - var attemptId: Int, - val accumulatorIds: Set[Long], - val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics]) - -private class LiveTaskMetrics( - val ids: Array[Long], - val values: Array[Long], - val succeeded: Boolean) + val attemptId: Int, + val numTasks: Int, + val accumulatorIds: Set[Long]) { + + /** + * Mapping of task IDs to their respective index. Note this may contain more elements than the + * stage's number of tasks, if speculative execution is on. + */ + private val taskIndices = new OpenHashMap[Long, Int]() + + /** Bit set tracking which indices have been successfully computed. */ + private val completedIndices = new mutable.BitSet() + + /** + * Task metrics values for the stage. Maps the metric ID to the metric values for each + * index. For each metric ID, there will be the same number of values as the number + * of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value, + * independent of the actual metric type. + */ + private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]() + + def registerTask(taskId: Long, taskIdx: Int): Unit = { + taskIndices.update(taskId, taskIdx) + } + + def updateTaskMetrics( + taskId: Long, + eventIdx: Int, + finished: Boolean, + accumUpdates: Seq[AccumulableInfo]): Unit = { + val taskIdx = if (eventIdx == SQLAppStatusListener.UNKNOWN_INDEX) { + if (!taskIndices.contains(taskId)) { + // We probably missed the start event for the task, just ignore it. + return + } + taskIndices(taskId) + } else { + // Here we can recover from a missing task start event. Just register the task again. 
+ registerTask(taskId, eventIdx) + eventIdx + } + + if (completedIndices.contains(taskIdx)) { + return + } + + accumUpdates + .filter { acc => acc.update.isDefined && accumulatorIds.contains(acc.id) } + .foreach { acc => + // In a live application, accumulators have Long values, but when reading from event + // logs, they have String values. For now, assume all accumulators are Long and convert + // accordingly. + val value = acc.update.get match { + case s: String => s.toLong + case l: Long => l + case o => throw new IllegalArgumentException(s"Unexpected: $o") + } + + val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks)) + metricValues(taskIdx) = value + } + + if (finished) { + completedIndices += taskIdx + } + } + + def metricValues(): Seq[(Long, Array[Long])] = taskMetrics.asScala.toSeq +} + +private object SQLAppStatusListener { + val UNKNOWN_INDEX = -1 +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 5ab9b6f5fc2d..57731e5f4920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -232,7 +232,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils { val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) assert(expectedNodeName === actualNodeName) for ((metricName, metricPredicate) <- expectedMetricsPredicatesMap) { - assert(metricPredicate(actualMetricsMap(metricName))) + assert(metricPredicate(actualMetricsMap(metricName)), + s"$nodeId / '$metricName' (= ${actualMetricsMap(metricName)}) did not match predicate.") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala new file mode 100644 index 000000000000..a88abc8209a8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.ui + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable +import scala.concurrent.duration._ + +import org.apache.spark.{SparkConf, TaskState} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.executor.ExecutorMetrics +import org.apache.spark.internal.config.Status._ +import org.apache.spark.scheduler._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.status.ElementTrackingStore +import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator, Utils} +import org.apache.spark.util.kvstore.InMemoryStore + +/** + * Benchmark for metrics aggregation in the SQL listener. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class --jars + * 2. build/sbt "core/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain " + * Results will be written to "benchmarks/MetricsAggregationBenchmark-results.txt". + * }}} + */ +object MetricsAggregationBenchmark extends BenchmarkBase { + + private def metricTrackingBenchmark( + timer: Benchmark.Timer, + numMetrics: Int, + numTasks: Int, + numStages: Int): Measurements = { + val conf = new SparkConf() + .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + .set(ASYNC_TRACKING_ENABLED, false) + val kvstore = new ElementTrackingStore(new InMemoryStore(), conf) + val listener = new SQLAppStatusListener(conf, kvstore, live = true) + val store = new SQLAppStatusStore(kvstore, Some(listener)) + + val metrics = (0 until numMetrics).map { i => + new SQLMetricInfo(s"metric$i", i.toLong, "average") + } + + val planInfo = new SparkPlanInfo( + getClass().getName(), + getClass().getName(), + Nil, + Map.empty, + metrics) + + val idgen = new AtomicInteger() + val executionId = idgen.incrementAndGet() + val executionStart = SparkListenerSQLExecutionStart( + executionId, + getClass().getName(), + getClass().getName(), + getClass().getName(), + planInfo, + System.currentTimeMillis()) + + val executionEnd = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis()) + + val properties = new Properties() + properties.setProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString) + + timer.startTiming() + listener.onOtherEvent(executionStart) + + val taskEventsTime = (0 until numStages).map { _ => + val stageInfo = new StageInfo(idgen.incrementAndGet(), 0, getClass().getName(), + numTasks, Nil, Nil, getClass().getName()) + + val jobId = idgen.incrementAndGet() + val jobStart = SparkListenerJobStart( + jobId = jobId, + time = System.currentTimeMillis(), + stageInfos = Seq(stageInfo), + properties) + + val stageStart = SparkListenerStageSubmitted(stageInfo) + + val taskOffset = idgen.incrementAndGet().toLong + val taskEvents = (0 until numTasks).map { i => + val info = new TaskInfo( + taskId = taskOffset + i.toLong, + index = i, + attemptNumber = 0, + // The following fields are not used. 
+ launchTime = 0, + executorId = "", + host = "", + taskLocality = null, + speculative = false) + info.markFinished(TaskState.FINISHED, 1L) + + val accumulables = (0 until numMetrics).map { mid => + val acc = new LongAccumulator + acc.metadata = AccumulatorMetadata(mid, None, false) + acc.toInfo(Some(i.toLong), None) + } + + info.setAccumulables(accumulables) + + val start = SparkListenerTaskStart(stageInfo.stageId, stageInfo.attemptNumber, info) + val end = SparkListenerTaskEnd(stageInfo.stageId, stageInfo.attemptNumber, + taskType = "", + reason = null, + info, + new ExecutorMetrics(), + null) + + (start, end) + } + + val jobEnd = SparkListenerJobEnd( + jobId = jobId, + time = System.currentTimeMillis(), + JobSucceeded) + + listener.onJobStart(jobStart) + listener.onStageSubmitted(stageStart) + + val (_, _taskEventsTime) = Utils.timeTakenMs { + taskEvents.foreach { case (start, end) => + listener.onTaskStart(start) + listener.onTaskEnd(end) + } + } + + listener.onJobEnd(jobEnd) + _taskEventsTime + } + + val (_, aggTime) = Utils.timeTakenMs { + listener.onOtherEvent(executionEnd) + val metrics = store.executionMetrics(executionId) + assert(metrics.size == numMetrics, s"${metrics.size} != $numMetrics") + } + + timer.stopTiming() + kvstore.close() + + Measurements(taskEventsTime, aggTime) + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val metricCount = 50 + val taskCount = 100000 + val stageCounts = Seq(1, 2, 3) + + val benchmark = new Benchmark( + s"metrics aggregation ($metricCount metrics, $taskCount tasks per stage)", 1, + warmupTime = 0.seconds, output = output) + + // Run this outside the measurement code so that classes are loaded and JIT is triggered, + // otherwise the first run tends to be much slower than others. Also because this benchmark is a + // bit weird and doesn't really map to what the Benchmark class expects, so it's a bit harder + // to use warmupTime and friends effectively. + stageCounts.foreach { count => + metricTrackingBenchmark(new Benchmark.Timer(-1), metricCount, taskCount, count) + } + + val measurements = mutable.HashMap[Int, Seq[Measurements]]() + + stageCounts.foreach { count => + benchmark.addTimerCase(s"$count stage(s)") { timer => + val m = metricTrackingBenchmark(timer, metricCount, taskCount, count) + val all = measurements.getOrElse(count, Nil) + measurements(count) = all ++ Seq(m) + } + } + + benchmark.run() + + benchmark.out.printf("Stage Count Stage Proc. Time Aggreg. Time\n") + stageCounts.foreach { count => + val data = measurements(count) + val eventsTimes = data.flatMap(_.taskEventsTimes) + val aggTimes = data.map(_.aggregationTime) + + val msg = " %d %d %d\n".format( + count, + eventsTimes.sum / eventsTimes.size, + aggTimes.sum / aggTimes.size) + benchmark.out.printf(msg) + } + } + + /** + * Finer-grained measurements of how long it takes to run some parts of the benchmark. This is + * collected by the benchmark method, so this collection slightly affects the overall benchmark + * results, but this data helps with seeing where the time is going, since this benchmark is + * triggering a whole lot of code in the listener class. 
+ */ + case class Measurements( + taskEventsTimes: Seq[Long], + aggregationTime: Long) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 88864ccec752..b8c0935b33a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -79,9 +79,9 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = { new StageInfo(stageId = stageId, attemptId = attemptId, + numTasks = 8, // The following fields are not used in tests name = "", - numTasks = 0, rddInfos = Nil, parentIds = Nil, details = "") @@ -94,8 +94,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils val info = new TaskInfo( taskId = taskId, attemptNumber = attemptNumber, + index = taskId.toInt, // The following fields are not used in tests - index = 0, launchTime = 0, executorId = "", host = "", @@ -190,6 +190,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils ), createProperties(executionId))) listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) + listener.onTaskStart(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0))) + listener.onTaskStart(SparkListenerTaskStart(0, 0, createTaskInfo(1, 0))) assert(statusStore.executionMetrics(executionId).isEmpty) @@ -217,6 +219,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils // Retrying a stage should reset the metrics listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) + listener.onTaskStart(SparkListenerTaskStart(0, 1, createTaskInfo(0, 0))) + listener.onTaskStart(SparkListenerTaskStart(0, 1, createTaskInfo(1, 0))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) @@ -260,6 +264,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils // Summit a new stage listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) + listener.onTaskStart(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0))) + listener.onTaskStart(SparkListenerTaskStart(1, 0, createTaskInfo(1, 0))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) @@ -490,8 +496,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils val statusStore = spark.sharedState.statusStore val oldCount = statusStore.executionsList().size - val expectedAccumValue = 12345 - val expectedAccumValue2 = 54321 + val expectedAccumValue = 12345L + val expectedAccumValue2 = 54321L val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan @@ -517,8 +523,9 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") val driverMetric2 = physicalPlan.metrics("dummy2") - val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) - val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) + 
val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Array(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, + Array(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) From 091cbc3be0ab8678e9a6d21eb29a14dd554c9b39 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Fri, 25 Oct 2019 13:48:09 +0800 Subject: [PATCH 44/58] [SPARK-9612][ML] Add instance weight support for GBTs ### What changes were proposed in this pull request? add weight support for GBTs by sampling data before passing it to trees and then passing weights to trees in summary: 1, add setters of `minWeightFractionPerNode` & `weightCol` 2, update input types in private methods from `RDD[LabeledPoint]` to `RDD[Instance]`: `DecisionTreeRegressor.train`, `GradientBoostedTrees.run`, `GradientBoostedTrees.runWithValidation`, `GradientBoostedTrees.computeInitialPredictionAndError`, `GradientBoostedTrees.computeError`, `GradientBoostedTrees.evaluateEachIteration`, `GradientBoostedTrees.boost`, `GradientBoostedTrees.updatePredictionError` 3, add new private method `GradientBoostedTrees.computeError(data, predError)` to compute average error, since original `predError.values.mean()` do not take weights into account. 4, add new tests ### Why are the changes needed? GBTs should support sample weights like other algs ### Does this PR introduce any user-facing change? yes, new setters are added ### How was this patch tested? existing & added testsuites Closes #25926 from zhengruifeng/gbt_add_weight. Authored-by: zhengruifeng Signed-off-by: zhengruifeng --- .../spark/ml/classification/Classifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 68 ++++----- .../apache/spark/ml/feature/Instance.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 7 +- .../spark/ml/regression/GBTRegressor.scala | 48 +++--- .../ml/tree/impl/GradientBoostedTrees.scala | 144 ++++++++++-------- .../mllib/tree/GradientBoostedTrees.scala | 17 ++- .../classification/GBTClassifierSuite.scala | 61 ++++++-- .../LogisticRegressionSuite.scala | 2 - .../ml/regression/GBTRegressorSuite.scala | 51 ++++++- .../tree/impl/GradientBoostedTreesSuite.scala | 25 ++- 11 files changed, 261 insertions(+), 166 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 9ac673078d4a..3bff236677e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -53,7 +53,7 @@ private[spark] trait ClassifierParams val validateInstance = (instance: Instance) => { val label = instance.label require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" + - s" dataset with invalid label $label. Labels must be integers in range" + + s" dataset with invalid label $label. 
Labels must be integers in range" + s" [0, $numClasses).") } extractInstances(dataset, validateInstance) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 74624be360c6..5bc45f2b02a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -23,7 +23,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -79,6 +79,10 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + /** @group setParam */ + @Since("3.0.0") + def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value) + /** @group setParam */ @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) @@ -152,36 +156,34 @@ class GBTClassifier @Since("1.4.0") ( set(validationIndicatorCol, value) } + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * By default the weightCol is not set, so all instances have weight 1.0. + * + * @group setParam + */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train( dataset: Dataset[_]): GBTClassificationModel = instrumented { instr => - val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty - // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports - // 2 classes now. This lets us provide a more precise error message. - val convert2LabeledPoint = (dataset: Dataset[_]) => { - dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { - case Row(label: Double, features: Vector) => - require(label == 0 || label == 1, s"GBTClassifier was given" + - s" dataset with invalid label $label. Labels must be in {0,1}; note that" + - s" GBTClassifier currently only supports binary classification.") - LabeledPoint(label, features) - } + val validateInstance = (instance: Instance) => { + val label = instance.label + require(label == 0 || label == 1, s"GBTClassifier was given" + + s" dataset with invalid label $label. 
Labels must be in {0,1}; note that" + + s" GBTClassifier currently only supports binary classification.") } val (trainDataset, validationDataset) = if (withValidation) { - ( - convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))), - convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol)))) - ) + (extractInstances(dataset.filter(not(col($(validationIndicatorCol)))), validateInstance), + extractInstances(dataset.filter(col($(validationIndicatorCol))), validateInstance)) } else { - (convert2LabeledPoint(dataset), null) + (extractInstances(dataset, validateInstance), null) } - val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val numClasses = 2 if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -191,12 +193,14 @@ class GBTClassifier @Since("1.4.0") ( instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, - lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, - validationIndicatorCol, validationTol) + instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, leafCol, + impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, + minInstancesPerNode, minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, + checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol) instr.logNumClasses(numClasses) + val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val (baseLearners, learnerWeights) = if (withValidation) { GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) @@ -374,12 +378,9 @@ class GBTClassificationModel private[ml]( */ @Since("2.4.0") def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = { - val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { - case Row(label: Double, features: Vector) => LabeledPoint(label, features) - } + val data = extractInstances(dataset) GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss, - OldAlgo.Classification - ) + OldAlgo.Classification) } @Since("2.0.0") @@ -423,10 +424,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] - val trees: Array[DecisionTreeRegressionModel] = treesData.map { + val trees = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala index dd56fbbfa2b6..11d0c4689cbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.Vector * @param weight The weight of this instance. * @param features The vector of features for this data point. 
*/ -private[ml] case class Instance(label: Double, weight: Double, features: Vector) +private[spark] case class Instance(label: Double, weight: Double, features: Vector) /** * Case class that represents an instance of data point with diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 602b5fac20d3..05851d511675 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -23,7 +23,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ @@ -132,15 +132,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train( - data: RDD[LabeledPoint], + data: RDD[Instance], oldStrategy: OldStrategy, featureSubsetStrategy: String): DecisionTreeRegressionModel = instrumented { instr => instr.logPipelineStage(this) instr.logDataset(data) instr.logParams(this, params: _*) - val instances = data.map(_.toInstance) - val trees = RandomForest.run(instances, oldStrategy, numTrees = 1, + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 0cc06d82bf3f..9c38647642a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -24,7 +24,6 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ @@ -34,7 +33,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -78,6 +77,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + /** @group setParam */ + @Since("3.0.0") + def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value) + /** @group setParam */ @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) @@ -151,29 +154,35 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) set(validationIndicatorCol, value) } - override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr => 
- val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * By default the weightCol is not set, so all instances have weight 1.0. + * + * @group setParam + */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr => val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty val (trainDataset, validationDataset) = if (withValidation) { - ( - extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))), - extractLabeledPoints(dataset.filter(col($(validationIndicatorCol)))) - ) + (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))), + extractInstances(dataset.filter(col($(validationIndicatorCol))))) } else { - (extractLabeledPoints(dataset), null) + (extractInstances(dataset), null) } - val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, lossType, - maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, - validationIndicatorCol, validationTol) + instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, weightCol, impurity, + lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, + minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, + featureSubsetStrategy, validationIndicatorCol, validationTol) + val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) val (baseLearners, learnerWeights) = if (withValidation) { GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) @@ -323,9 +332,7 @@ class GBTRegressionModel private[ml]( */ @Since("2.4.0") def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = { - val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { - case Row(label: Double, features: Vector) => LabeledPoint(label, features) - } + val data = extractInstances(dataset) GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, convertToOldLossType(loss), OldAlgo.Regression) } @@ -368,10 +375,9 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] - val trees: Array[DecisionTreeRegressionModel] = treesData.map { + val trees = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index c31334c92e1c..744708258b0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -34,13 +34,13 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model - * @param input Training dataset: RDD of `LabeledPoint`. + * @param input Training dataset: RDD of `Instance`. * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) */ def run( - input: RDD[LabeledPoint], + input: RDD[Instance], boostingStrategy: OldBoostingStrategy, seed: Long, featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { @@ -51,7 +51,7 @@ private[spark] object GradientBoostedTrees extends Logging { seed, featureSubsetStrategy) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. - val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedInput = input.map(x => Instance((x.label * 2) - 1, x.weight, x.features)) GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, seed, featureSubsetStrategy) case _ => @@ -61,7 +61,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to validate a gradient boosting model - * @param input Training dataset: RDD of `LabeledPoint`. + * @param input Training dataset: RDD of `Instance`. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. @@ -72,8 +72,8 @@ private[spark] object GradientBoostedTrees extends Logging { * (array of decision tree models, array of model weights) */ def runWithValidation( - input: RDD[LabeledPoint], - validationInput: RDD[LabeledPoint], + input: RDD[Instance], + validationInput: RDD[Instance], boostingStrategy: OldBoostingStrategy, seed: Long, featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { @@ -85,9 +85,9 @@ private[spark] object GradientBoostedTrees extends Logging { case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( - x => new LabeledPoint((x.label * 2) - 1, x.features)) + x => Instance((x.label * 2) - 1, x.weight, x.features)) val remappedValidationInput = validationInput.map( - x => new LabeledPoint((x.label * 2) - 1, x.features)) + x => Instance((x.label * 2) - 1, x.weight, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, validate = true, seed, featureSubsetStrategy) case _ => @@ -106,13 +106,13 @@ private[spark] object GradientBoostedTrees extends Logging { * corresponding to every sample. 
*/ def computeInitialPredictionAndError( - data: RDD[LabeledPoint], + data: RDD[Instance], initTreeWeight: Double, initTree: DecisionTreeRegressionModel, loss: OldLoss): RDD[(Double, Double)] = { - data.map { lp => - val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight) - val error = loss.computeError(pred, lp.label) + data.map { case Instance(label, _, features) => + val pred = updatePrediction(features, 0.0, initTree, initTreeWeight) + val error = loss.computeError(pred, label) (pred, error) } } @@ -129,20 +129,17 @@ private[spark] object GradientBoostedTrees extends Logging { * corresponding to each sample. */ def updatePredictionError( - data: RDD[LabeledPoint], + data: RDD[Instance], predictionAndError: RDD[(Double, Double)], treeWeight: Double, tree: DecisionTreeRegressionModel, loss: OldLoss): RDD[(Double, Double)] = { - - val newPredError = data.zip(predictionAndError).mapPartitions { iter => - iter.map { case (lp, (pred, error)) => - val newPred = updatePrediction(lp.features, pred, tree, treeWeight) - val newError = loss.computeError(newPred, lp.label) + data.zip(predictionAndError).map { + case (Instance(label, _, features), (pred, _)) => + val newPred = updatePrediction(features, pred, tree, treeWeight) + val newError = loss.computeError(newPred, label) (newPred, newError) - } } - newPredError } /** @@ -166,29 +163,50 @@ private[spark] object GradientBoostedTrees extends Logging { * Method to calculate error of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param data Training dataset: RDD of `LabeledPoint`. + * @param data Training dataset: RDD of `Instance`. * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. * @return Measure of model error on data */ - def computeError( - data: RDD[LabeledPoint], + def computeWeightedError( + data: RDD[Instance], trees: Array[DecisionTreeRegressionModel], treeWeights: Array[Double], loss: OldLoss): Double = { - data.map { lp => + val (errSum, weightSum) = data.map { case Instance(label, weight, features) => val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) => - updatePrediction(lp.features, acc, model, weight) + updatePrediction(features, acc, model, weight) } - loss.computeError(predicted, lp.label) - }.mean() + (loss.computeError(predicted, label) * weight, weight) + }.treeReduce { case ((err1, weight1), (err2, weight2)) => + (err1 + err2, weight1 + weight2) + } + errSum / weightSum + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * @param data Training dataset: RDD of `Instance`. + * @param predError Prediction and error. + * @return Measure of model error on data + */ + def computeWeightedError( + data: RDD[Instance], + predError: RDD[(Double, Double)]): Double = { + val (errSum, weightSum) = data.zip(predError).map { + case (Instance(_, weight, _), (_, err)) => + (err * weight, weight) + }.treeReduce { case ((err1, weight1), (err2, weight2)) => + (err1 + err2, weight1 + weight2) + } + errSum / weightSum } /** * Method to compute error or loss for every iteration of gradient boosting. * - * @param data RDD of `LabeledPoint` + * @param data RDD of `Instance` * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. 
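As a quick numeric illustration of what the new `computeWeightedError` overloads return - a weighted mean of per-instance losses, rather than the old unweighted `predError.values.mean()` - here is a toy, non-RDD sketch; the values and the use of plain Scala collections are made up for illustration:

```scala
// Per-instance losses (what loss.computeError(pred, label) would produce) and instance weights.
val losses  = Seq(1.0, 4.0, 0.5)
val weights = Seq(1.0, 0.5, 2.0)

// Same reduction shape as computeWeightedError: sum of (loss * weight) over sum of weights.
val (errSum, weightSum) = losses.zip(weights)
  .map { case (err, w) => (err * w, w) }
  .reduce { case ((e1, w1), (e2, w2)) => (e1 + e2, w1 + w2) }

val weightedError = errSum / weightSum  // 4.0 / 3.5 ≈ 1.14, vs. an unweighted mean of ≈ 1.83
```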
@@ -197,41 +215,34 @@ private[spark] object GradientBoostedTrees extends Logging { * containing the first i+1 trees */ def evaluateEachIteration( - data: RDD[LabeledPoint], + data: RDD[Instance], trees: Array[DecisionTreeRegressionModel], treeWeights: Array[Double], loss: OldLoss, algo: OldAlgo.Value): Array[Double] = { - - val sc = data.sparkContext val remappedData = algo match { - case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + case OldAlgo.Classification => + data.map(x => Instance((x.label * 2) - 1, x.weight, x.features)) case _ => data } - val broadcastTrees = sc.broadcast(trees) - val localTreeWeights = treeWeights - val treesIndices = trees.indices - - val dataCount = remappedData.count() - val evaluation = remappedData.map { point => - treesIndices.map { idx => - val prediction = broadcastTrees.value(idx) - .rootNode - .predictImpl(point.features) - .prediction - prediction * localTreeWeights(idx) + val numTrees = trees.length + val (errSum, weightSum) = remappedData.mapPartitions { iter => + iter.map { case Instance(label, weight, features) => + val pred = Array.tabulate(numTrees) { i => + trees(i).rootNode.predictImpl(features) + .prediction * treeWeights(i) + } + val err = pred.scanLeft(0.0)(_ + _).drop(1) + .map(p => loss.computeError(p, label) * weight) + (err, weight) } - .scanLeft(0.0)(_ + _).drop(1) - .map(prediction => loss.computeError(prediction, point.label)) + }.treeReduce { case ((err1, weight1), (err2, weight2)) => + (0 until numTrees).foreach(i => err1(i) += err2(i)) + (err1, weight1 + weight2) } - .aggregate(treesIndices.map(_ => 0.0))( - (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)), - (a, b) => treesIndices.map(idx => a(idx) + b(idx))) - .map(_ / dataCount) - broadcastTrees.destroy() - evaluation.toArray + errSum.map(_ / weightSum) } /** @@ -245,8 +256,8 @@ private[spark] object GradientBoostedTrees extends Logging { * (array of decision tree models, array of model weights) */ def boost( - input: RDD[LabeledPoint], - validationInput: RDD[LabeledPoint], + input: RDD[Instance], + validationInput: RDD[Instance], boostingStrategy: OldBoostingStrategy, validate: Boolean, seed: Long, @@ -280,8 +291,10 @@ private[spark] object GradientBoostedTrees extends Logging { } // Prepare periodic checkpointers + // Note: this is checkpointing the unweighted training error val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( treeStrategy.getCheckpointInterval, input.sparkContext) + // Note: this is checkpointing the unweighted validation error val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( treeStrategy.getCheckpointInterval, input.sparkContext) @@ -299,26 +312,29 @@ private[spark] object GradientBoostedTrees extends Logging { baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight - var predError: RDD[(Double, Double)] = - computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) predErrorCheckpointer.update(predError) - logDebug("error of gbt = " + predError.values.mean()) + logDebug("error of gbt = " + computeWeightedError(input, predError)) // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - var validatePredError: RDD[(Double, Double)] = + var validatePredError = computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) if 
(validate) validatePredErrorCheckpointer.update(validatePredError) - var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 + var bestValidateError = if (validate) { + computeWeightedError(validationInput, validatePredError) + } else { + 0.0 + } var bestM = 1 var m = 1 var doneLearning = false while (m < numIterations && !doneLearning) { // Update data with pseudo-residuals - val data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) + val data = predError.zip(input).map { case ((pred, _), Instance(label, weight, features)) => + Instance(-loss.gradient(pred, label), weight, features) } timer.start(s"building tree $m") @@ -339,7 +355,7 @@ private[spark] object GradientBoostedTrees extends Logging { predError = updatePredictionError( input, predError, baseLearnerWeights(m), baseLearners(m), loss) predErrorCheckpointer.update(predError) - logDebug("error of gbt = " + predError.values.mean()) + logDebug("error of gbt = " + computeWeightedError(input, predError)) if (validate) { // Stop training early if @@ -350,7 +366,7 @@ private[spark] object GradientBoostedTrees extends Logging { validatePredError = updatePredictionError( validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) validatePredErrorCheckpointer.update(validatePredError) - val currentValidateError = validatePredError.values.mean() + val currentValidateError = computeWeightedError(validationInput, validatePredError) if (bestValidateError - currentValidateError < validationTol * Math.max( currentValidateError, 0.01)) { doneLearning = true diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index d24d8da0dab4..d57f1b36a572 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy @@ -67,8 +67,9 @@ class GradientBoostedTrees private[spark] ( @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - val (trees, treeWeights) = NewGBT.run(input.map { point => - NewLabeledPoint(point.label, point.features.asML) + val (trees, treeWeights) = NewGBT.run(input.map { + case LabeledPoint(label, features) => + Instance(label, 1.0, features.asML) }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } @@ -97,10 +98,12 @@ class GradientBoostedTrees private[spark] ( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - val (trees, treeWeights) = NewGBT.runWithValidation(input.map { point => - NewLabeledPoint(point.label, point.features.asML) - }, validationInput.map { point => - NewLabeledPoint(point.label, point.features.asML) + val (trees, treeWeights) = NewGBT.runWithValidation(input.map { + case LabeledPoint(label, 
features) => + Instance(label, 1.0, features.asML) + }, validationInput.map { + case LabeledPoint(label, features) => + Instance(label, 1.0, features.asML) }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 530ca20d0eb0..fdca71f8911c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.classification.LinearSVCSuite.generateSVMInput +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel @@ -52,8 +53,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { private var data: RDD[LabeledPoint] = _ private var trainData: RDD[LabeledPoint] = _ private var validationData: RDD[LabeledPoint] = _ + private var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 private val absEps: Double = 1e-8 + private val seed = 42 override def beforeAll(): Unit = { super.beforeAll() @@ -65,6 +68,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { validationData = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) .map(_.asML) + binaryDataset = generateSVMInput(0.01, Array[Double](-1.5, 1.0), 1000, seed).toDF() } test("params") { @@ -362,7 +366,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { test("Tests of feature subset strategy") { val numClasses = 2 val gbt = new GBTClassifier() - .setSeed(42) + .setSeed(seed) .setMaxDepth(3) .setMaxIter(5) .setFeatureSubsetStrategy("all") @@ -397,13 +401,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses) val evalArr = model3.evaluateEachIteration(validationData.toDF) - val remappedValidationData = validationData.map( - x => new LabeledPoint((x.label * 2) - 1, x.features)) - val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData, + val remappedValidationData = validationData.map { + case LabeledPoint(label, features) => + Instance(label * 2 - 1, 1.0, features) + } + val lossErr1 = GradientBoostedTrees.computeWeightedError(remappedValidationData, model1.trees, model1.treeWeights, model1.getOldLossType) - val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData, + val lossErr2 = GradientBoostedTrees.computeWeightedError(remappedValidationData, model2.trees, model2.treeWeights, model2.getOldLossType) - val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData, + val lossErr3 = GradientBoostedTrees.computeWeightedError(remappedValidationData, model3.trees, model3.treeWeights, model3.getOldLossType) assert(evalArr(0) ~== lossErr1 relTol 1E-3) @@ -433,16 +439,19 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(modelWithValidation.numTrees < numIter) val (errorWithoutValidation, errorWithValidation) = { - val remappedRdd = 
validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees, + val remappedRdd = validationData.map { + case LabeledPoint(label, features) => + Instance(label * 2 - 1, 1.0, features) + } + (GradientBoostedTrees.computeWeightedError(remappedRdd, modelWithoutValidation.trees, modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType), - GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees, + GradientBoostedTrees.computeWeightedError(remappedRdd, modelWithValidation.trees, modelWithValidation.treeWeights, modelWithValidation.getOldLossType)) } assert(errorWithValidation < errorWithoutValidation) val evaluationArray = GradientBoostedTrees - .evaluateEachIteration(validationData, modelWithoutValidation.trees, + .evaluateEachIteration(validationData.map(_.toInstance), modelWithoutValidation.trees, modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, OldAlgo.Classification) assert(evaluationArray.length === numIter) @@ -472,6 +481,36 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { }) } + test("training with sample weights") { + val df = binaryDataset + val numClasses = 2 + val predEquals = (x: Double, y: Double) => x == y + // (maxIter, maxDepth) + val testParams = Seq( + (5, 5), + (5, 10) + ) + + for ((maxIter, maxDepth) <- testParams) { + val estimator = new GBTClassifier() + .setMaxIter(maxIter) + .setMaxDepth(maxDepth) + .setSeed(seed) + .setMinWeightFractionPerNode(0.049) + + MLTestingUtils.testArbitrarilyScaledWeights[GBTClassificationModel, + GBTClassifier](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7)) + MLTestingUtils.testOutliersWithSmallWeights[GBTClassificationModel, + GBTClassifier](df.as[LabeledPoint], estimator, + numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8), + outlierRatio = 2) + MLTestingUtils.testOversamplingVsWeighting[GBTClassificationModel, + GBTClassifier](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 2b5a9a396eff..d2b8751360e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1425,8 +1425,6 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { } test("multinomial logistic regression with zero variance (SPARK-21681)") { - val sqlContext = multinomialDatasetWithZeroVar.sqlContext - import sqlContext.implicits._ val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true) .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index e2462af2ac1d..b772a3b7737d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala 
@@ -26,6 +26,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.lit @@ -46,6 +47,8 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { private var data: RDD[LabeledPoint] = _ private var trainData: RDD[LabeledPoint] = _ private var validationData: RDD[LabeledPoint] = _ + private var linearRegressionData: DataFrame = _ + private val seed = 42 override def beforeAll(): Unit = { super.beforeAll() @@ -57,6 +60,9 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { validationData = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) .map(_.asML) + linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF() } test("Regression with continuous features") { @@ -202,7 +208,7 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { val gbt = new GBTRegressor() .setMaxDepth(3) .setMaxIter(5) - .setSeed(42) + .setSeed(seed) .setFeatureSubsetStrategy("all") // In this data, feature 1 is very important. @@ -237,11 +243,11 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { for (evalLossType <- GBTRegressor.supportedLossTypes) { val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType) - val lossErr1 = GradientBoostedTrees.computeError(validationData, + val lossErr1 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance), model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType)) - val lossErr2 = GradientBoostedTrees.computeError(validationData, + val lossErr2 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance), model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType)) - val lossErr3 = GradientBoostedTrees.computeError(validationData, + val lossErr3 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance), model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType)) assert(evalArr(0) ~== lossErr1 relTol 1E-3) @@ -272,17 +278,19 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { // early stop assert(modelWithValidation.numTrees < numIter) - val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, + val errorWithoutValidation = GradientBoostedTrees.computeWeightedError( + validationData.map(_.toInstance), modelWithoutValidation.trees, modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType) - val errorWithValidation = GradientBoostedTrees.computeError(validationData, + val errorWithValidation = GradientBoostedTrees.computeWeightedError( + validationData.map(_.toInstance), modelWithValidation.trees, modelWithValidation.treeWeights, modelWithValidation.getOldLossType) assert(errorWithValidation < errorWithoutValidation) val evaluationArray = GradientBoostedTrees - .evaluateEachIteration(validationData, modelWithoutValidation.trees, + .evaluateEachIteration(validationData.map(_.toInstance), modelWithoutValidation.trees, 
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, OldAlgo.Regression) assert(evaluationArray.length === numIter) @@ -310,6 +318,35 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { }) } + test("training with sample weights") { + val df = linearRegressionData + val numClasses = 0 + // (maxIter, maxDepth) + val testParams = Seq( + (5, 5), + (5, 10) + ) + + for ((maxIter, maxDepth) <- testParams) { + val estimator = new GBTRegressor() + .setMaxIter(maxIter) + .setMaxDepth(maxDepth) + .setSeed(seed) + .setMinWeightFractionPerNode(0.1) + + MLTestingUtils.testArbitrarilyScaledWeights[GBTRegressionModel, + GBTRegressor](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95)) + MLTestingUtils.testOutliersWithSmallWeights[GBTRegressionModel, + GBTRegressor](df.as[LabeledPoint], estimator, numClasses, + MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95), + outlierRatio = 2) + MLTestingUtils.testOversamplingVsWeighting[GBTRegressionModel, + GBTRegressor](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 0.95), seed) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index 366d5ec3a53f..18fc1407557f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ @@ -32,15 +32,12 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext */ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - import testImplicits._ - test("runWithValidation stops early and performs better on a validation dataset") { // Set numIterations large enough so that it stops early. 
val numIterations = 20 - val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML) - val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML) - val trainDF = trainRdd.toDF() - val validateDF = validateRdd.toDF() + val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML.toInstance) + val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML.toInstance) + val seed = 42 val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) @@ -50,21 +47,21 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val boostingStrategy = new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) val (validateTrees, validateTreeWeights) = GradientBoostedTrees - .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all") + .runWithValidation(trainRdd, validateRdd, boostingStrategy, seed, "all") val numTrees = validateTrees.length assert(numTrees !== numIterations) // Test that it performs better on the validation dataset. - val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all") + val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, seed, "all") val (errorWithoutValidation, errorWithValidation) = { if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss), - GradientBoostedTrees.computeError(remappedRdd, validateTrees, + val remappedRdd = validateRdd.map(x => Instance(2 * x.label - 1, x.weight, x.features)) + (GradientBoostedTrees.computeWeightedError(remappedRdd, trees, treeWeights, loss), + GradientBoostedTrees.computeWeightedError(remappedRdd, validateTrees, validateTreeWeights, loss)) } else { - (GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss), - GradientBoostedTrees.computeError(validateRdd, validateTrees, + (GradientBoostedTrees.computeWeightedError(validateRdd, trees, treeWeights, loss), + GradientBoostedTrees.computeWeightedError(validateRdd, validateTrees, validateTreeWeights, loss)) } } From cfbdd9d2932d8d80ab679f45f146641579546855 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 25 Oct 2019 16:32:06 +0900 Subject: [PATCH 45/58] [SPARK-29461][SQL] Measure the number of records being updated for JDBC writer ### What changes were proposed in this pull request? This patch adds the functionality to measure records being written for JDBC writer. In reality, the value is meant to be a number of records being updated from queries, as per JDBC spec it will return updated count. ### Why are the changes needed? Output metrics for JDBC writer are missing now. The value of "bytesWritten" is also missing, but we can't measure it from JDBC API. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Unit test added. Closes #26109 from HeartSaVioR/SPARK-29461. 
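For readers who want to see the new metric in action, here is a hedged sketch (not taken from the patch; `df`, `url`, and the table name are placeholders) of observing `recordsWritten` through a `SparkListener`, the same pattern the added `JDBCWriteSuite` helper uses:

```scala
import java.util.Properties
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical observer for the new output metric; df, url and the table name
// are placeholders. Real code should guard the counter against concurrent updates.
var recordsWritten = 0L
val listener = new SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    recordsWritten += taskEnd.taskMetrics.outputMetrics.recordsWritten
  }
}
spark.sparkContext.addSparkListener(listener)

df.write.mode("append").jdbc(url, "TEST.METRICS_DEMO", new Properties())

// The suite drains the listener bus (a test-only API) before reading the sum;
// outside tests, allow the asynchronous listener events to arrive before checking.
spark.sparkContext.removeSparkListener(listener)
```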
Authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Takeshi Yamamuro --- .../datasources/jdbc/JdbcUtils.scala | 23 ++++++-- .../spark/sql/jdbc/JDBCWriteSuite.scala | 55 +++++++++++++++++++ 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 86a27b5afc25..55ca4e3624bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -605,6 +605,13 @@ object JdbcUtils extends Logging { * implementation changes elsewhere might easily render such a closure * non-Serializable. Instead, we explicitly close over all variables that * are used. + * + * Note that this method records task output metrics. It assumes the method is + * running in a task. For now, we only records the number of rows being written + * because there's no good way to measure the total bytes being written. Only + * effective outputs are taken into account: for example, metric will not be updated + * if it supports transaction and transaction is rolled back, but metric will be + * updated even with error if it doesn't support transaction, as there're dirty outputs. */ def savePartition( getConnection: () => Connection, @@ -615,7 +622,9 @@ object JdbcUtils extends Logging { batchSize: Int, dialect: JdbcDialect, isolationLevel: Int, - options: JDBCOptions): Iterator[Byte] = { + options: JDBCOptions): Unit = { + val outMetrics = TaskContext.get().taskMetrics().outputMetrics + val conn = getConnection() var committed = false @@ -643,7 +652,7 @@ object JdbcUtils extends Logging { } } val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE - + var totalRowCount = 0 try { if (supportsTransactions) { conn.setAutoCommit(false) // Everything in the same db transaction. @@ -672,6 +681,7 @@ object JdbcUtils extends Logging { } stmt.addBatch() rowCount += 1 + totalRowCount += 1 if (rowCount % batchSize == 0) { stmt.executeBatch() rowCount = 0 @@ -687,7 +697,6 @@ object JdbcUtils extends Logging { conn.commit() } committed = true - Iterator.empty } catch { case e: SQLException => val cause = e.getNextException @@ -715,9 +724,13 @@ object JdbcUtils extends Logging { // tell the user about another problem. if (supportsTransactions) { conn.rollback() + } else { + outMetrics.setRecordsWritten(totalRowCount) } conn.close() } else { + outMetrics.setRecordsWritten(totalRowCount) + // The stage must succeed. We cannot propagate any exception close() might throw. 
try { conn.close() @@ -840,10 +853,10 @@ object JdbcUtils extends Logging { case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n) case _ => df } - repartitionedDF.rdd.foreachPartition(iterator => savePartition( + repartitionedDF.rdd.foreachPartition { iterator => savePartition( getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options) - ) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index b28c6531d42b..8021ef1a17a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -21,10 +21,12 @@ import java.sql.DriverManager import java.util.Properties import scala.collection.JavaConverters.propertiesAsScalaMapConverter +import scala.collection.mutable.ArrayBuffer import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} @@ -543,4 +545,57 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { }.getMessage assert(errMsg.contains("Statement was canceled or the session timed out")) } + + test("metrics") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + runAndVerifyRecordsWritten(2) { + df.write.mode(SaveMode.Append).jdbc(url, "TEST.BASICCREATETEST", new Properties()) + } + + runAndVerifyRecordsWritten(1) { + df2.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", new Properties()) + } + + runAndVerifyRecordsWritten(1) { + df2.write.mode(SaveMode.Overwrite).option("truncate", true) + .jdbc(url, "TEST.BASICCREATETEST", new Properties()) + } + + runAndVerifyRecordsWritten(0) { + intercept[AnalysisException] { + df2.write.mode(SaveMode.ErrorIfExists).jdbc(url, "TEST.BASICCREATETEST", new Properties()) + } + } + + runAndVerifyRecordsWritten(0) { + df.write.mode(SaveMode.Ignore).jdbc(url, "TEST.BASICCREATETEST", new Properties()) + } + } + + private def runAndVerifyRecordsWritten(expected: Long)(job: => Unit): Unit = { + assert(expected === runAndReturnMetrics(job, _.taskMetrics.outputMetrics.recordsWritten)) + } + + private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Long): Long = { + val taskMetrics = new ArrayBuffer[Long]() + + // Avoid receiving earlier taskEnd events + sparkContext.listenerBus.waitUntilEmpty() + + val listener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + taskMetrics += collector(taskEnd) + } + } + sparkContext.addSparkListener(listener) + + job + + sparkContext.listenerBus.waitUntilEmpty() + + sparkContext.removeSparkListener(listener) + taskMetrics.sum + } } From 8bd8f492ea006ce03d215c3b272c31c1b8bc1858 Mon Sep 17 00:00:00 2001 From: redsk Date: Fri, 25 Oct 2019 08:06:36 -0500 Subject: [PATCH 46/58] [SPARK-29500][SQL][SS] Support partition column when writing to Kafka ### What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-29500 `KafkaRowWriter` now supports setting the Kafka partition by reading a "partition" column in the input dataframe. 
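As an illustration of the new capability (a hedged sketch, not part of the patch — the broker address, topic, and source column names are made up), a batch write that pins each row to an explicit partition could look like:

```scala
// Hypothetical usage of the optional "partition" column: it must be an INT, and
// null values fall back to the configured (or default) Kafka partitioner.
df.selectExpr(
    "CAST(id AS STRING) AS key",
    "CAST(payload AS STRING) AS value",
    "CAST(target_partition AS INT) AS partition")
  .write
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("topic", "events")
  .save()
```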
Code changes in commit nr. 1. Test changes in commit nr. 2. Doc changes in commit nr. 3. tcondie dongjinleekr srowen ### Why are the changes needed? While it is possible to configure a custom Kafka Partitioner with `.option("kafka.partitioner.class", "my.custom.Partitioner")`, this is not enough for certain use cases. See the Jira issue. ### Does this PR introduce any user-facing change? No, as this behaviour is optional. ### How was this patch tested? Two new UT were added and one was updated. Closes #26153 from redsk/feature/SPARK-29500. Authored-by: redsk Signed-off-by: Sean Owen --- .../structured-streaming-kafka-integration.md | 10 +++ .../spark/sql/kafka010/KafkaWriteTask.scala | 22 ++++- .../spark/sql/kafka010/KafkaWriter.scala | 11 ++- .../kafka010/KafkaContinuousSinkSuite.scala | 9 ++ .../spark/sql/kafka010/KafkaSinkSuite.scala | 88 ++++++++++++++++++- 5 files changed, 134 insertions(+), 6 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 89732d309aa2..badf0429545f 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -614,6 +614,10 @@ The Dataframe being written to Kafka should have the following columns in schema topic (*optional) string + + partition (optional) + int + \* The topic column is required if the "topic" configuration option is not specified.
    @@ -622,6 +626,12 @@ a ```null``` valued key column will be automatically added (see Kafka semantics how ```null``` valued key values are handled). If a topic column exists then its value is used as the topic when writing the given row to Kafka, unless the "topic" configuration option is set i.e., the "topic" configuration option overrides the topic column. +If a "partition" column is not specified (or its value is ```null```) +then the partition is calculated by the Kafka producer. +A Kafka partitioner can be specified in Spark by setting the +```kafka.partitioner.class``` option. If not present, Kafka default partitioner +will be used. + The following options must be set for the Kafka sink for both batch and streaming queries. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index b423ddc959c1..5bdc1b5fe9f3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.header.internals.RecordHeader import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.types.{BinaryType, StringType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} /** * Writes out data in a single Spark task, without any concerns about how @@ -92,8 +92,10 @@ private[kafka010] abstract class KafkaRowWriter( throw new NullPointerException(s"null topic present in the data. Use the " + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") } + val partition: Integer = + if (projectedRow.isNullAt(4)) null else projectedRow.getInt(4) val record = if (projectedRow.isNullAt(3)) { - new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value) + new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, partition, key, value) } else { val headerArray = projectedRow.getArray(3) val headers = (0 until headerArray.numElements()).map { i => @@ -101,7 +103,8 @@ private[kafka010] abstract class KafkaRowWriter( new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1)) .asInstanceOf[Header] } - new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava) + new ProducerRecord[Array[Byte], Array[Byte]]( + topic.toString, partition, key, value, headers.asJava) } producer.send(record, callback) } @@ -156,12 +159,23 @@ private[kafka010] abstract class KafkaRowWriter( throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " + s"attribute unsupported type ${t.catalogString}") } + val partitionExpression = + inputSchema.find(_.name == KafkaWriter.PARTITION_ATTRIBUTE_NAME) + .getOrElse(Literal(null, IntegerType)) + partitionExpression.dataType match { + case IntegerType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t. 
${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " + + s"must be a ${IntegerType.catalogString}") + } UnsafeProjection.create( Seq( topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType), - headersExpression + headersExpression, + partitionExpression ), inputSchema ) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index bbb060356f73..9b0d11f137ce 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.types.{BinaryType, MapType, StringType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, MapType, StringType} import org.apache.spark.util.Utils /** @@ -41,6 +41,7 @@ private[kafka010] object KafkaWriter extends Logging { val KEY_ATTRIBUTE_NAME: String = "key" val VALUE_ATTRIBUTE_NAME: String = "value" val HEADERS_ATTRIBUTE_NAME: String = "headers" + val PARTITION_ATTRIBUTE_NAME: String = "partition" override def toString: String = "KafkaWriter" @@ -86,6 +87,14 @@ private[kafka010] object KafkaWriter extends Logging { throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " + s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}") } + schema.find(_.name == PARTITION_ATTRIBUTE_NAME).getOrElse( + Literal(null, IntegerType) + ).dataType match { + case IntegerType => // good + case _ => + throw new AnalysisException(s"$PARTITION_ATTRIBUTE_NAME attribute type " + + s"must be an ${IntegerType.catalogString}") + } } def write( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 65adbd6b9887..cbf4952406c0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -286,6 +286,15 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { } assert(ex3.getMessage.toLowerCase(Locale.ROOT).contains( "key attribute type must be a string or binary")) + + val ex4 = intercept[AnalysisException] { + /* partition field wrong type */ + createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as partition", "value" + ) + } + assert(ex4.getMessage.toLowerCase(Locale.ROOT).contains( + "partition attribute type must be an int")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index d77b9a3b6a9e..aacb10f5197b 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -22,6 +22,8 @@ import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.clients.producer.ProducerConfig +import 
org.apache.kafka.clients.producer.internals.DefaultPartitioner +import org.apache.kafka.common.Cluster import org.apache.kafka.common.serialization.ByteArraySerializer import org.scalatest.time.SpanSugar._ @@ -33,7 +35,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType, StructField, StructType} abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest { protected var testUtils: KafkaTestUtils = _ @@ -293,6 +295,21 @@ class KafkaSinkStreamingSuite extends KafkaSinkSuiteBase with StreamTest { } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "key attribute type must be a string or binary")) + + try { + ex = intercept[StreamingQueryException] { + /* partition field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value", "value as partition" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "partition attribute type must be an int")) } test("streaming - write to non-existing topic") { @@ -418,6 +435,65 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { ) } + def writeToKafka(df: DataFrame, topic: String, options: Map[String, String] = Map.empty): Unit = { + df + .write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .options(options) + .mode("append") + .save() + } + + def partitionsInTopic(topic: String): Set[Int] = { + createKafkaReader(topic) + .select("partition") + .map(_.getInt(0)) + .collect() + .toSet + } + + test("batch - partition column and partitioner priorities") { + val nrPartitions = 4 + val topic1 = newTopic() + val topic2 = newTopic() + val topic3 = newTopic() + val topic4 = newTopic() + testUtils.createTopic(topic1, nrPartitions) + testUtils.createTopic(topic2, nrPartitions) + testUtils.createTopic(topic3, nrPartitions) + testUtils.createTopic(topic4, nrPartitions) + val customKafkaPartitionerConf = Map( + "kafka.partitioner.class" -> "org.apache.spark.sql.kafka010.TestKafkaPartitioner" + ) + + val df = (0 until 5).map(n => (topic1, s"$n", s"$n")).toDF("topic", "key", "value") + + // default kafka partitioner + writeToKafka(df, topic1) + val partitionsInTopic1 = partitionsInTopic(topic1) + assert(partitionsInTopic1.size > 1) + + // custom partitioner (always returns 0) overrides default partitioner + writeToKafka(df, topic2, customKafkaPartitionerConf) + val partitionsInTopic2 = partitionsInTopic(topic2) + assert(partitionsInTopic2.size == 1) + assert(partitionsInTopic2.head == 0) + + // partition column overrides custom partitioner + val dfWithCustomPartition = df.withColumn("partition", lit(2)) + writeToKafka(dfWithCustomPartition, topic3, customKafkaPartitionerConf) + val partitionsInTopic3 = partitionsInTopic(topic3) + assert(partitionsInTopic3.size == 1) + assert(partitionsInTopic3.head == 2) + + // when the partition column value is null, it is ignored + val dfWithNullPartitions = df.withColumn("partition", lit(null).cast(IntegerType)) + writeToKafka(dfWithNullPartitions, topic4) + assert(partitionsInTopic(topic4) == partitionsInTopic1) + } + test("batch - null topic field 
value, and no topic option") { val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") val ex = intercept[SparkException] { @@ -515,3 +591,13 @@ class KafkaSinkBatchSuiteV2 extends KafkaSinkBatchSuiteBase { } } } + +class TestKafkaPartitioner extends DefaultPartitioner { + override def partition( + topic: String, + key: Any, + keyBytes: Array[Byte], + value: Any, + valueBytes: Array[Byte], + cluster: Cluster): Int = 0 +} From 0cf4f07c66b44770efa4b97db8d47d5fc394aeab Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 25 Oct 2019 22:19:19 +0900 Subject: [PATCH 47/58] [SPARK-29545][SQL] Add support for bit_xor aggregate function ### What changes were proposed in this pull request? bit_xor(expr) - Returns the bitwise XOR of all non-null input values, or null if none ### Why are the changes needed? As we support `bit_and`, `bit_or` now, we'd better support the related aggregate function **bit_xor** ahead of postgreSQL, because many other popular databases support it. http://infocenter.sybase.com/help/index.jsp?topic=/com.sybase.help.sqlanywhere.12.0.1/dbreference/bit-xor-function.html https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html#function_bit-or https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/SQLReferenceManual/Functions/Aggregate/BIT_XOR.htm?TocPath=SQL%20Reference%20Manual%7CSQL%20Functions%7CAggregate%20Functions%7C_____10 ### Does this PR introduce any user-facing change? add a new bit agg ### How was this patch tested? UTs added Closes #26205 from yaooqinn/SPARK-29545. Authored-by: Kent Yao Signed-off-by: Takeshi Yamamuro --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../aggregate/bitwiseAggregates.scala | 92 ++++++++++--------- .../resources/sql-tests/inputs/bitwise.sql | 31 +++++++ .../sql-tests/results/bitwise.sql.out | 71 +++++++++++++- 4 files changed, 151 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 04e8963944fd..52e05b820366 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -527,6 +527,7 @@ object FunctionRegistry { expression[BitwiseCount]("bit_count"), expression[BitAndAgg]("bit_and"), expression[BitOrAgg]("bit_or"), + expression[BitXorAgg]("bit_xor"), // json expression[StructsToJson]("to_json"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala index 131fa2eb5055..b77c3bd9cbde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala @@ -17,20 +17,14 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryArithmetic, BitwiseAnd, BitwiseOr, BitwiseXor, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal} import org.apache.spark.sql.types.{AbstractDataType, DataType, 
IntegralType} -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.", - examples = """ - Examples: - > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col); - 1 - """, - since = "3.0.0") -case class BitAndAgg(child: Expression) extends DeclarativeAggregate with ExpectsInputTypes { +abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes { - override def nodeName: String = "bit_and" + val child: Expression + + def bitOperator(left: Expression, right: Expression): BinaryArithmetic override def children: Seq[Expression] = child :: Nil @@ -40,23 +34,40 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) - private lazy val bitAnd = AttributeReference("bit_and", child.dataType)() - - override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAnd :: Nil + private lazy val bitAgg = AttributeReference(nodeName, child.dataType)() override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAgg :: Nil + + override lazy val evaluateExpression: AttributeReference = bitAgg + override lazy val updateExpressions: Seq[Expression] = - If(IsNull(bitAnd), + If(IsNull(bitAgg), child, - If(IsNull(child), bitAnd, BitwiseAnd(bitAnd, child))) :: Nil + If(IsNull(child), bitAgg, bitOperator(bitAgg, child))) :: Nil override lazy val mergeExpressions: Seq[Expression] = - If(IsNull(bitAnd.left), - bitAnd.right, - If(IsNull(bitAnd.right), bitAnd.left, BitwiseAnd(bitAnd.left, bitAnd.right))) :: Nil + If(IsNull(bitAgg.left), + bitAgg.right, + If(IsNull(bitAgg.right), bitAgg.left, bitOperator(bitAgg.left, bitAgg.right))) :: Nil +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col); + 1 + """, + since = "3.0.0") +case class BitAndAgg(child: Expression) extends BitAggregate { - override lazy val evaluateExpression: AttributeReference = bitAnd + override def nodeName: String = "bit_and" + + override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { + BitwiseAnd(left, right) + } } @ExpressionDescription( @@ -67,33 +78,28 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect 7 """, since = "3.0.0") -case class BitOrAgg(child: Expression) extends DeclarativeAggregate with ExpectsInputTypes { +case class BitOrAgg(child: Expression) extends BitAggregate { override def nodeName: String = "bit_or" - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType - - override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) - - private lazy val bitOr = AttributeReference("bit_or", child.dataType)() - - override lazy val aggBufferAttributes: Seq[AttributeReference] = bitOr :: Nil - - override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil + override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { + BitwiseOr(left, right) + } +} - override lazy val updateExpressions: Seq[Expression] = - If(IsNull(bitOr), - child, - If(IsNull(child), bitOr, BitwiseOr(bitOr, child))) :: Nil +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the bitwise XOR of all non-null input values, or null if 
none.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col); + 6 + """, + since = "3.0.0") +case class BitXorAgg(child: Expression) extends BitAggregate { - override lazy val mergeExpressions: Seq[Expression] = - If(IsNull(bitOr.left), - bitOr.right, - If(IsNull(bitOr.right), bitOr.left, BitwiseOr(bitOr.left, bitOr.right))) :: Nil + override def nodeName: String = "bit_xor" - override lazy val evaluateExpression: AttributeReference = bitOr + override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { + BitwiseXor(left, right) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql index 993eecf0f89b..5e665e4c0c38 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql @@ -37,3 +37,34 @@ select bit_count(-9223372036854775808L); -- other illegal arguments select bit_count("bit count"); select bit_count('a'); + +-- test for bit_xor +-- +CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES + (1, 1, 1, 1L), + (2, 3, 4, null), + (7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4); + +-- empty case +SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0; + +-- null case +SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null; + +-- the suffix numbers show the expected answer +SELECT + BIT_XOR(cast(b1 as tinyint)) AS a4, + BIT_XOR(cast(b2 as smallint)) AS b5, + BIT_XOR(b3) AS c2, + BIT_XOR(b4) AS d2, + BIT_XOR(distinct b4) AS e2 +FROM bitwise_test; + +-- group by +SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1; + +--having +SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7; + +-- window +SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test; diff --git a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out index 7cbd26e87bd2..42c22a317eb4 100644 --- a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 27 -- !query 0 @@ -162,3 +162,72 @@ struct<> -- !query 19 output org.apache.spark.sql.AnalysisException cannot resolve 'bit_count('a')' due to data type mismatch: argument 1 requires (integral or boolean) type, however, ''a'' is of string type.; line 1 pos 7 + + +-- !query 20 +CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES + (1, 1, 1, 1L), + (2, 3, 4, null), + (7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4) +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0 +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null +-- !query 22 schema +struct +-- !query 22 output +NULL + + +-- !query 23 +SELECT + BIT_XOR(cast(b1 as tinyint)) AS a4, + BIT_XOR(cast(b2 as smallint)) AS b5, + BIT_XOR(b3) AS c2, + BIT_XOR(b4) AS d2, + BIT_XOR(distinct b4) AS e2 +FROM bitwise_test +-- !query 23 schema +struct +-- !query 23 output +4 5 2 2 2 + + +-- !query 24 +SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1 +-- !query 24 schema +struct +-- !query 24 output +4 +6 + + +-- !query 25 +SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7 +-- !query 25 schema +struct +-- !query 25 output 
+1 1 +2 3 + + +-- !query 26 +SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test +-- !query 26 schema +struct +-- !query 26 output +1 1 1 +2 3 3 +7 7 7 From 68dca9a0953e4a9472235acf78aecbb95c07acb6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Oct 2019 23:09:08 +0800 Subject: [PATCH 48/58] [SPARK-29527][SQL] SHOW CREATE TABLE should look up catalog/table like v2 commands ### What changes were proposed in this pull request? Add ShowCreateTableStatement and make SHOW CREATE TABLE go through the same catalog/table resolution framework of v2 commands. ### Why are the changes needed? It's important to make all the commands have the same table resolution behavior, to avoid confusing end-users. e.g. ``` USE my_catalog DESC t // success and describe the table t from my_catalog SHOW CREATE TABLE t // report table not found as there is no table t in the session catalog ``` ### Does this PR introduce any user-facing change? yes. When running SHOW CREATE TABLE, Spark fails the command if the current catalog is set to a v2 catalog, or the table name specified a v2 catalog. ### How was this patch tested? Unit tests. Closes #26184 from viirya/SPARK-29527. Lead-authored-by: Liang-Chi Hsieh Co-authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 7 +++++++ .../spark/sql/catalyst/plans/logical/statements.scala | 5 +++++ .../apache/spark/sql/catalyst/parser/DDLParserSuite.scala | 6 ++++++ .../sql/catalyst/analysis/ResolveSessionCatalog.scala | 6 +++++- .../org/apache/spark/sql/execution/SparkSqlParser.scala | 8 -------- .../apache/spark/sql/connector/DataSourceV2SQLSuite.scala | 8 ++++++++ 7 files changed, 32 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1e89507411ad..c97eb3c935be 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -201,7 +201,7 @@ statement | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions | SHOW identifier? FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions - | SHOW CREATE TABLE tableIdentifier #showCreateTable + | SHOW CREATE TABLE multipartIdentifier #showCreateTable | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction | (DESC | DESCRIBE) database EXTENDED? db=errorCapturingIdentifier #describeDatabase | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b030227b4881..4fa479f083e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2797,6 +2797,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging RepairTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier())) } + /** + * Creates a [[ShowCreateTableStatement]] + */ + override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) { + ShowCreateTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier())) + } + /** * Create a [[CacheTableStatement]]. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index ef8c92269434..655e87fce4e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -330,6 +330,11 @@ case class AnalyzeColumnStatement( */ case class RepairTableStatement(tableName: Seq[String]) extends ParsedStatement +/** + * A SHOW CREATE TABLE statement, as parsed from SQL. + */ +case class ShowCreateTableStatement(tableName: Seq[String]) extends ParsedStatement + /** * A CACHE TABLE statement, as parsed from SQL */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index f4375956f0af..da01c612b350 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -1045,6 +1045,12 @@ class DDLParserSuite extends AnalysisTest { RepairTableStatement(Seq("a", "b", "c"))) } + test("SHOW CREATE table") { + comparePlans( + parsePlan("SHOW CREATE TABLE a.b.c"), + ShowCreateTableStatement(Seq("a", "b", "c"))) + } + test("CACHE TABLE") { comparePlans( parsePlan("CACHE TABLE a.b.c"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index f91686cb544c..e7e34b1ef312 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CacheTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand, UncacheTableCommand} +import 
org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CacheTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowCreateTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand, UncacheTableCommand} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, RefreshTable} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf @@ -299,6 +299,10 @@ class ResolveSessionCatalog( v1TableName.asTableIdentifier, "MSCK REPAIR TABLE") + case ShowCreateTableStatement(tableName) => + val v1TableName = parseV1Table(tableName, "SHOW CREATE TABLE") + ShowCreateTableCommand(v1TableName.asTableIdentifier) + case CacheTableStatement(tableName, plan, isLazy, options) => val v1TableName = parseV1Table(tableName, "CACHE TABLE") CacheTableCommand(v1TableName.asTableIdentifier, plan, isLazy, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index aef0a2d2e595..20894b39ce5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -135,14 +135,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ShowColumnsCommand(Option(ctx.db).map(_.getText), visitTableIdentifier(ctx.tableIdentifier)) } - /** - * Creates a [[ShowCreateTableCommand]] - */ - override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) { - val table = visitTableIdentifier(ctx.tableIdentifier()) - ShowCreateTableCommand(table) - } - /** * Create a [[RefreshResource]] logical plan. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 4f2c1af8f7b5..b8a8acbba57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1268,6 +1268,14 @@ class DataSourceV2SQLSuite } } + test("SHOW CREATE TABLE") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") + testV1Command("SHOW CREATE TABLE", t) + } + } + test("CACHE TABLE") { val t = "testcat.ns1.ns2.tbl" withTable(t) { From ae5b60da329ac63935d180d20a62f1bb181f5514 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Oct 2019 11:13:06 -0700 Subject: [PATCH 49/58] [SPARK-29182][CORE][FOLLOWUP] Cache preferred locations of checkpointed RDD ### What changes were proposed in this pull request? This is a followup to #25856. This fixes the document about the config value of spark.rdd.checkpoint.cachePreferredLocsExpireTime. ### Why are the changes needed? The document is not correct. spark.rdd.checkpoint.cachePreferredLocsExpireTime can not be 0. ### Does this PR introduce any user-facing change? No ### How was this patch tested? This is document only change. Closes #26251 from viirya/SPARK-29182-followup. 
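Purely for illustration (this followup only touches documentation; the snippet below is an assumed usage, not part of the change), a valid setting supplies a positive duration, for example:

```scala
// Hypothetical example: cache preferred locations of checkpointed RDDs for 10 minutes.
// Non-positive values are rejected by the config's checkValue guard shown below.
val conf = new org.apache.spark.SparkConf()
  .set("spark.rdd.checkpoint.cachePreferredLocsExpireTime", "10m")
```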
Authored-by: Liang-Chi Hsieh Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/internal/config/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 36211dc2ed4f..444a1544777a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -804,7 +804,7 @@ package object config { .doc("Expire time in minutes for caching preferred locations of checkpointed RDD." + "Caching preferred locations can relieve query loading to DFS and save the query " + "time. The drawback is that the cached locations can be possibly outdated and " + - "lose data locality. If this config is not specified or is 0, it will not cache.") + "lose data locality. If this config is not specified, it will not cache.") .timeConf(TimeUnit.MINUTES) .checkValue(_ > 0, "The expire time for caching preferred locations cannot be non-positive.") .createOptional From 2baf7a1d8ff0d7018d6b70876c1e65b549ae30b0 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 25 Oct 2019 13:57:26 -0700 Subject: [PATCH 50/58] [SPARK-29608][BUILD] Add `hadoop-3.2` profile to release build ### What changes were proposed in this pull request? This PR aims to add `hadoop-3.2` profile to pre-built binary package releases. ### Why are the changes needed? Since Apache Spark 3.0.0, we provides Hadoop 3.2 pre-built binary. ### Does this PR introduce any user-facing change? No. (Although the artifacts are available, this change is for release managers). ### How was this patch tested? Manual. Please note that `DRY_RUN=0` disables these combination. ``` $ dev/create-release/release-build.sh package ... Packages to build: without-hadoop hadoop3.2 hadoop2.7 make_binary_release without-hadoop -Pscala-2.12 -Phadoop-provided 2.12 make_binary_release hadoop3.2 -Pscala-2.12 -Phadoop-3.2 -Phive -Phive-thriftserver 2.12 make_binary_release hadoop2.7 -Pscala-2.12 -Phadoop-2.7 -Phive -Phive-thriftserver withpip,withr 2.12 ``` Closes #26260 from dongjoon-hyun/SPARK-29608. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/create-release/release-build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 61951e73f4ba..1f6fdb2a55ff 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -280,6 +280,8 @@ if [[ "$1" == "package" ]]; then BINARY_PKGS_ARGS["without-hadoop"]="-Phadoop-provided" if [[ $SPARK_VERSION < "3.0." ]]; then BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES" + else + BINARY_PKGS_ARGS["hadoop3.2"]="-Phadoop-3.2 $HIVE_PROFILES" fi fi From 25493919f82415c329c81c0c529eff576b491cd9 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Fri, 25 Oct 2019 14:11:35 -0700 Subject: [PATCH 51/58] [SPARK-29580][TESTS] Add kerberos debug messages for Kafka secure tests ### What changes were proposed in this pull request? `org.apache.spark.sql.kafka010.KafkaDelegationTokenSuite` failed lately. 
After had a look at the logs it just shows the following fact without any details: ``` Caused by: sbt.ForkMain$ForkError: sun.security.krb5.KrbException: Server not found in Kerberos database (7) - Server not found in Kerberos database ``` Since the issue is intermittent and not able to reproduce it we should add more debug information and wait for reproduction with the extended logs. ### Why are the changes needed? Failing test doesn't give enough debug information. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? I've started the test manually and checked that such additional debug messages show up: ``` >>> KrbApReq: APOptions are 00000000 00000000 00000000 00000000 >>> EType: sun.security.krb5.internal.crypto.Aes128CtsHmacSha1EType Looking for keys for: kafka/localhostEXAMPLE.COM Added key: 17version: 0 Added key: 23version: 0 Added key: 16version: 0 Found unsupported keytype (3) for kafka/localhostEXAMPLE.COM >>> EType: sun.security.krb5.internal.crypto.Aes128CtsHmacSha1EType Using builtin default etypes for permitted_enctypes default etypes for permitted_enctypes: 17 16 23. >>> EType: sun.security.krb5.internal.crypto.Aes128CtsHmacSha1EType MemoryCache: add 1571936500/174770/16C565221B70AAB2BEFE31A83D13A2F4/client/localhostEXAMPLE.COM to client/localhostEXAMPLE.COM|kafka/localhostEXAMPLE.COM MemoryCache: Existing AuthList: #3: 1571936493/200803/8CD70D280B0862C5DA1FF901ECAD39FE/client/localhostEXAMPLE.COM #2: 1571936499/985009/BAD33290D079DD4E3579A8686EC326B7/client/localhostEXAMPLE.COM #1: 1571936499/995208/B76B9D78A9BE283AC78340157107FD40/client/localhostEXAMPLE.COM ``` Closes #26252 from gaborgsomogyi/SPARK-29580. Authored-by: Gabor Somogyi Signed-off-by: Dongjoon Hyun --- .../spark/sql/kafka010/KafkaTestUtils.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index bbb72bf9973e..6c745987b4c2 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -67,6 +67,8 @@ class KafkaTestUtils( secure: Boolean = false) extends Logging { private val JAVA_AUTH_CONFIG = "java.security.auth.login.config" + private val IBM_KRB_DEBUG_CONFIG = "com.ibm.security.krb5.Krb5Debug" + private val SUN_KRB_DEBUG_CONFIG = "sun.security.krb5.debug" private val localCanonicalHostName = InetAddress.getLoopbackAddress().getCanonicalHostName() logInfo(s"Local host name is $localCanonicalHostName") @@ -133,6 +135,7 @@ class KafkaTestUtils( private def setUpMiniKdc(): Unit = { val kdcDir = Utils.createTempDir() val kdcConf = MiniKdc.createConf() + kdcConf.setProperty(MiniKdc.DEBUG, "true") kdc = new MiniKdc(kdcConf, kdcDir) kdc.start() kdcReady = true @@ -238,6 +241,7 @@ class KafkaTestUtils( } if (secure) { + setupKrbDebug() setUpMiniKdc() val jaasConfigFile = createKeytabsAndJaasConfigFile() System.setProperty(JAVA_AUTH_CONFIG, jaasConfigFile) @@ -252,6 +256,14 @@ class KafkaTestUtils( } } + private def setupKrbDebug(): Unit = { + if (System.getProperty("java.vendor").contains("IBM")) { + System.setProperty(IBM_KRB_DEBUG_CONFIG, "all") + } else { + System.setProperty(SUN_KRB_DEBUG_CONFIG, "true") + } + } + /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { if (leakDetector != null) { 
@@ -303,6 +315,15 @@ class KafkaTestUtils( kdc.stop() } UserGroupInformation.reset() + teardownKrbDebug() + } + + private def teardownKrbDebug(): Unit = { + if (System.getProperty("java.vendor").contains("IBM")) { + System.clearProperty(IBM_KRB_DEBUG_CONFIG) + } else { + System.clearProperty(SUN_KRB_DEBUG_CONFIG) + } } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ From 5bdc58bf8a951df9b1be5a0298335b3668749358 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 25 Oct 2019 21:17:17 -0700 Subject: [PATCH 52/58] [SPARK-27653][SQL][FOLLOWUP] Fix `since` version of `min_by/max_by` ### What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/24557 to fix `since` version. ### Why are the changes needed? This is found during 3.0.0-preview preparation. The version will be exposed to our SQL document like the following. We had better fix this. - https://spark.apache.org/docs/latest/api/sql/#array_min ### Does this PR introduce any user-facing change? Yes. It's exposed at `DESC FUNCTION EXTENDED` SQL command and SQL doc, but this is new at 3.0.0. ### How was this patch tested? Manual. ``` spark-sql> DESC FUNCTION EXTENDED min_by; Function: min_by Class: org.apache.spark.sql.catalyst.expressions.aggregate.MinBy Usage: min_by(x, y) - Returns the value of `x` associated with the minimum value of `y`. Extended Usage: Examples: > SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); a Since: 3.0.0 ``` Closes #26264 from dongjoon-hyun/SPARK-27653. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index c7fdb15130c4..b69b341b0ee3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -98,7 +98,7 @@ abstract class MaxMinBy extends DeclarativeAggregate { > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); b """, - since = "3.0") + since = "3.0.0") case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { override protected def funcName: String = "max_by" @@ -116,7 +116,7 @@ case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); a """, - since = "3.0") + since = "3.0.0") case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { override protected def funcName: String = "min_by" From 9a4670279177353519a0d12d9d37f7207f72488e Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 25 Oct 2019 23:02:11 -0700 Subject: [PATCH 53/58] [SPARK-29554][SQL] Add `version` SQL function ### What changes were proposed in this pull request? ``` hive> select version(); OK 3.1.1 rf4e0529634b6231a0072295da48af466cf2f10b7 Time taken: 2.113 seconds, Fetched: 1 row(s) ``` ### Why are the changes needed? From hive behavior and I guess it is useful for debugging and developing etc. ### Does this PR introduce any user-facing change? add a misc func ### How was this patch tested? 
add ut Closes #26209 from yaooqinn/SPARK-29554. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 17 +++++++++++++++-- .../apache/spark/sql/MiscFunctionsSuite.scala | 7 +++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 52e05b820366..019e1a08779e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -486,6 +486,7 @@ object FunctionRegistry { expression[CurrentDatabase]("current_database"), expression[CallMethodViaReflection]("reflect"), expression[CallMethodViaReflection]("java_method"), + expression[Version]("version"), // grouping sets expression[Cube]("cube"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 2af2b13ad77f..b8c23a1f0891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.UUID - +import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -164,3 +163,17 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta override def freshCopy(): Uuid = Uuid(randomSeed) } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_() - Returns the Spark version. 
The string contains 2 fields, the first being a release version and the second being a git revision.""", + since = "3.0.0") +// scalastyle:on line.size.limit +case class Version() extends LeafExpression with CodegenFallback { + override def nullable: Boolean = false + override def foldable: Boolean = true + override def dataType: DataType = StringType + override def eval(input: InternalRow): Any = { + UTF8String.fromString(SPARK_VERSION_SHORT + " " + SPARK_REVISION) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala index cad0821dbf5a..5ab06b1ebebf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT} import org.apache.spark.sql.test.SharedSparkSession class MiscFunctionsSuite extends QueryTest with SharedSparkSession { @@ -31,6 +32,12 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession { s"java_method('$className', 'method1', a, b)"), Row("m1one", "m1one")) } + + test("version") { + checkAnswer( + Seq("").toDF("a").selectExpr("version()"), + Row(SPARK_VERSION_SHORT + " " + SPARK_REVISION)) + } } object ReflectClass { From 2115bf61465b504bc21e37465cb34878039b5cb8 Mon Sep 17 00:00:00 2001 From: rongma1997 Date: Fri, 25 Oct 2019 23:11:01 -0700 Subject: [PATCH 54/58] [SPARK-29490][SQL] Reset 'WritableColumnVector' in 'RowToColumnarExec' ### What changes were proposed in this pull request? Reset the `WritableColumnVector` when getting "next" ColumnarBatch in `RowToColumnarExec` ### Why are the changes needed? When converting `Iterator[InternalRow]` to `Iterator[ColumnarBatch]`, the vectors used to create a new `ColumnarBatch` should be reset in the iterator's "next()" method. ### Does this PR introduce any user-facing change? No ### How was this patch tested? N/A Closes #26137 from rongma1997/reset-WritableColumnVector. 
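As an illustrative aside (not part of the patch), a minimal sketch of the reuse-plus-reset pattern described above, assuming Spark's `OnHeapColumnVector`/`ColumnarBatch` APIs; the schema, capacity, and the `nextBatch` helper are made up for the example:
```
import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

// Allocate one set of writable vectors and wrap them in a single reusable batch.
val schema = StructType(Seq(StructField("vals", LongType)))
val capacity = 4
val vectors: Seq[WritableColumnVector] = OnHeapColumnVector.allocateColumns(capacity, schema)
val batch = new ColumnarBatch(vectors.toArray)

def nextBatch(rows: Seq[Long]): ColumnarBatch = {
  batch.setNumRows(0)
  // Without this reset, values appended for the previous batch stay buffered in the
  // vectors, so the new rows would land after the stale ones.
  vectors.foreach(_.reset())
  rows.foreach(v => vectors.head.appendLong(v))
  batch.setNumRows(rows.length)
  batch
}
```
Calling `nextBatch` repeatedly then always yields a batch containing only the rows passed in for that call, which is the behavior the fix restores for `RowToColumnarExec`.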
Authored-by: rongma1997 Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/execution/Columnar.scala | 1 + .../sql/SparkSessionExtensionSuite.scala | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala index 9d1636ccf271..b41a4ff76667 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala @@ -454,6 +454,7 @@ case class RowToColumnarExec(child: SparkPlan) extends UnaryExecNode { override def next(): ColumnarBatch = { cb.setNumRows(0) + vectors.foreach(_.reset()) var rowCount = 0 while (rowCount < numRows && rowIterator.hasNext) { val row = rowIterator.next() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index b8df6f2bebf5..2a4c15233fe3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS import org.apache.spark.sql.types.{DataType, Decimal, IntegerType, LongType, Metadata, StructType} import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch, ColumnarMap, ColumnVector} @@ -171,6 +172,30 @@ class SparkSessionExtensionSuite extends SparkFunSuite { } } + test("reset column vectors") { + val session = SparkSession.builder() + .master("local[1]") + .config(COLUMN_BATCH_SIZE.key, 2) + .withExtensions { extensions => + extensions.injectColumnar(session => + MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) } + .getOrCreate() + + try { + assert(session.sessionState.columnarRules.contains( + MyColumarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))) + import session.sqlContext.implicits._ + + val input = Seq((100L), (200L), (300L)) + val data = input.toDF("vals").repartition(1) + val df = data.selectExpr("vals + 1") + val result = df.collect() + assert(result sameElements input.map(x => Row(x + 2))) + } finally { + stop(session) + } + } + test("use custom class for extensions") { val session = SparkSession.builder() .master("local[1]") From 077fb99a26a9e92104503fade25c0a095fec5e5d Mon Sep 17 00:00:00 2001 From: shahid Date: Sat, 26 Oct 2019 15:46:24 -0500 Subject: [PATCH 55/58] [SPARK-29589][WEBUI] Support pagination for sqlstats session table in JDBC/ODBC Session page ### What changes were proposed in this pull request? In the PR https://github.com/apache/spark/pull/26215, we supported pagination for sqlstats table in JDBC/ODBC server page. In this PR, we are extending the support of pagination to sqlstats session table by making use of existing pagination classes in https://github.com/apache/spark/pull/26215. ### Why are the changes needed? Support pagination for sqlsessionstats table in JDBC/ODBC server page in the WEBUI. 
It will easier for user to analyse the table and it may fix the potential issues like oom while loading the page, that may occur similar to the SQL page (refer #22645) ### Does this PR introduce any user-facing change? There will be no change in the sqlsessionstats table in JDBC/ODBC server page execpt pagination support. ### How was this patch tested? Manually verified. Before: ![Screenshot 2019-10-24 at 11 32 27 PM](https://user-images.githubusercontent.com/23054875/67512507-96715000-f6b6-11e9-9f1f-ab1877eb24e6.png) After: ![Screenshot 2019-10-24 at 10 58 53 PM](https://user-images.githubusercontent.com/23054875/67512314-295dba80-f6b6-11e9-9e3e-dd50c6e62fe9.png) Closes #26246 from shahidki31/SPARK_29589. Authored-by: shahid Signed-off-by: Sean Owen --- .../thriftserver/ui/ThriftServerPage.scala | 16 +-- .../ui/ThriftServerSessionPage.scala | 127 +++++++----------- 2 files changed, 50 insertions(+), 93 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index e472aaad5bdc..7258978e3bad 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -28,7 +28,7 @@ import scala.xml.{Node, Unparsed} import org.apache.commons.text.StringEscapeUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState, SessionInfo} +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, SessionInfo} import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._ import org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ @@ -181,14 +181,6 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" private def formatDurationOption(msOption: Option[Long]): String = { msOption.map(formatDurationVerbose).getOrElse(emptyCell) } - - /** Generate HTML table from string data */ - private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { - def generateDataRow(data: Seq[String]): Seq[Node] = { - {data.map(d => {d})} - } - UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) - } } private[ui] class SqlStatsPagedTable( @@ -328,11 +320,12 @@ private[ui] class SqlStatsPagedTable( {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)} + - {UIUtils.formatDuration(executionTime)} + {formatDurationVerbose(executionTime)} - {UIUtils.formatDuration(duration)} + {formatDurationVerbose(duration)} {info.statement} @@ -400,7 +393,6 @@ private[ui] class SqlStatsPagedTable( override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = { val r = data.slice(from, to) - r.map(x => x) _slicedStartTime = r.map(_.executionInfo.startTimestamp).toSet r } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 0aa0a2b8335d..8b275f8f7be0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.hive.thriftserver.ui import 
java.util.Calendar import javax.servlet.http.HttpServletRequest +import scala.collection.JavaConverters._ import scala.xml.Node -import org.apache.commons.text.StringEscapeUtils - import org.apache.spark.internal.Logging -import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState} -import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._ import org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ +import org.apache.spark.util.Utils /** Page for Spark Web UI that shows statistics of jobs running in the thrift server */ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) @@ -36,7 +34,6 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) private val listener = parent.listener private val startTime = Calendar.getInstance().getTime() - private val emptyCell = "-" /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { @@ -80,45 +77,52 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) .filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { - val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time", - "Execution Time", "Duration", "Statement", "State", "Detail") - val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME), - Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION), - Some(THRIFT_SERVER_DURATION), None, None, None) - assert(headerRow.length == tooltips.length) - val dataRows = executionList.sortBy(_.startTimestamp).reverse - - def generateDataRow(info: ExecutionInfo): Seq[Node] = { - val jobLink = info.jobId.map { id: String => - - [{id}] - + + val sqlTableTag = "sqlsessionstat" + + val parameterOtherTable = request.getParameterMap().asScala + .filterNot(_._1.startsWith(sqlTableTag)) + .map { case (name, vals) => + name + "=" + vals(0) } - val detail = Option(info.detail).filter(!_.isEmpty).getOrElse(info.executePlan) - - {info.userName} - - {jobLink} - - {info.groupId} - {formatDate(info.startTimestamp)} - {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} - {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)} - - {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} - - - {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} - - {info.statement} - {info.state} - {errorMessageCell(detail)} - - } - Some(UIUtils.listingTable(headerRow, generateDataRow, - dataRows, false, None, Seq(null), false, tooltipHeaders = tooltips)) + val parameterSqlTablePage = request.getParameter(s"$sqlTableTag.page") + val parameterSqlTableSortColumn = request.getParameter(s"$sqlTableTag.sort") + val parameterSqlTableSortDesc = request.getParameter(s"$sqlTableTag.desc") + val parameterSqlPageSize = request.getParameter(s"$sqlTableTag.pageSize") + + val sqlTablePage = Option(parameterSqlTablePage).map(_.toInt).getOrElse(1) + val sqlTableSortColumn = Option(parameterSqlTableSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse("Start Time") + val sqlTableSortDesc = Option(parameterSqlTableSortDesc).map(_.toBoolean).getOrElse( + // New executions should be shown above old executions by default. 
+ sqlTableSortColumn == "Start Time" + ) + val sqlTablePageSize = Option(parameterSqlPageSize).map(_.toInt).getOrElse(100) + + try { + Some(new SqlStatsPagedTable( + request, + parent, + executionList, + "sqlserver/session", + UIUtils.prependBaseUri(request, parent.basePath), + parameterOtherTable, + sqlTableTag, + pageSize = sqlTablePageSize, + sortColumn = sqlTableSortColumn, + desc = sqlTableSortDesc + ).table(sqlTablePage)) + } catch { + case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) => + Some(
<div class="alert alert-error">
+            <p>Error while rendering job table:</p>
+            <pre>
+              {Utils.exceptionString(e)}
+            </pre>
+          </div>
    ) + } } else { None } @@ -133,43 +137,4 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) content } - - private def errorMessageCell(errorMessage: String): Seq[Node] = { - val isMultiline = errorMessage.indexOf('\n') >= 0 - val errorSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - errorMessage.substring(0, errorMessage.indexOf('\n')) - } else { - errorMessage - }) - val details = if (isMultiline) { - // scalastyle:off - - + details - ++ - - // scalastyle:on - } else { - "" - } - {errorSummary}{details} - } - - /** - * Returns a human-readable string representing a duration such as "5 second 35 ms" - */ - private def formatDurationOption(msOption: Option[Long]): String = { - msOption.map(formatDurationVerbose).getOrElse(emptyCell) - } - - /** Generate HTML table from string data */ - private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { - def generateDataRow(data: Seq[String]): Seq[Node] = { - {data.map(d => {d})} - } - UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) - } } From 74514b46e53231e3567570d183be04dfa9d4af0a Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 27 Oct 2019 13:48:43 -0700 Subject: [PATCH 56/58] [SPARK-29614][SQL][TEST] Fix failures of DateTimeUtilsSuite and TimestampFormatterSuite ### What changes were proposed in this pull request? The `DateTimeUtilsSuite` and `TimestampFormatterSuite` assume constant time difference between `timestamp'yesterday'`, `timestamp'today'` and `timestamp'tomorrow'` which is wrong on daylight switching day - day length can be 23 or 25 hours. In the PR, I propose to use Java 8 time API to calculate instances of `yesterday` and `tomorrow` timestamps. ### Why are the changes needed? The changes fix test failures and make the tests tolerant to daylight time switching. ### Does this PR introduce any user-facing change? No ### How was this patch tested? By existing test suites `DateTimeUtilsSuite` and `TimestampFormatterSuite`. Closes #26273 from MaxGekk/midnight-tolerant. 
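As an illustrative aside (not part of the patch), the sketch below shows why `plusDays(1)` on a zoned midnight differs from adding a fixed 24 hours on a daylight saving switch day; the zone and date are picked only for the example:
```
import java.time.{Duration, LocalDateTime, ZoneId}

// Europe/Berlin left daylight saving time on 2019-10-27, so that day had 25 hours.
val zoneId = ZoneId.of("Europe/Berlin")
val todayMidnight = LocalDateTime.of(2019, 10, 27, 0, 0).atZone(zoneId)

val tomorrowByCalendar = todayMidnight.plusDays(1)                // next local midnight
val tomorrowByFixedDay = todayMidnight.plus(Duration.ofHours(24)) // assumes a 24-hour day

println(Duration.between(todayMidnight.toInstant, tomorrowByCalendar.toInstant)) // PT25H
println(tomorrowByFixedDay.toInstant == tomorrowByCalendar.toInstant)            // false
```
The one-hour gap is exactly what made the old `today + MICROS_PER_DAY` expectations flaky around a daylight saving switch.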
Authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 11 +++++++---- .../spark/sql/util/TimestampFormatterSuite.scala | 13 ++++++++----- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 10642b3ca8a4..0eaf53823128 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -586,12 +586,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers { val now = instantToMicros(LocalDateTime.now(zoneId).atZone(zoneId).toInstant) toTimestamp("NOW", zoneId).get should be (now +- tolerance) assert(toTimestamp("now UTC", zoneId) === None) - val today = instantToMicros(LocalDateTime.now(zoneId) + val localToday = LocalDateTime.now(zoneId) .`with`(LocalTime.MIDNIGHT) - .atZone(zoneId).toInstant) - toTimestamp(" Yesterday", zoneId).get should be (today - MICROS_PER_DAY +- tolerance) + .atZone(zoneId) + val yesterday = instantToMicros(localToday.minusDays(1).toInstant) + toTimestamp(" Yesterday", zoneId).get should be (yesterday +- tolerance) + val today = instantToMicros(localToday.toInstant) toTimestamp("Today ", zoneId).get should be (today +- tolerance) - toTimestamp(" tomorrow CET ", zoneId).get should be (today + MICROS_PER_DAY +- tolerance) + val tomorrow = instantToMicros(localToday.plusDays(1).toInstant) + toTimestamp(" tomorrow CET ", zoneId).get should be (tomorrow +- tolerance) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index 170daa6277c4..84581c0badd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, TimestampFormatter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, instantToMicros, MICROS_PER_DAY} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, instantToMicros} import org.apache.spark.sql.internal.SQLConf class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers { @@ -146,12 +146,15 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers assert(formatter.parse("EPOCH") === 0) val now = instantToMicros(LocalDateTime.now(zoneId).atZone(zoneId).toInstant) formatter.parse("now") should be (now +- tolerance) - val today = instantToMicros(LocalDateTime.now(zoneId) + val localToday = LocalDateTime.now(zoneId) .`with`(LocalTime.MIDNIGHT) - .atZone(zoneId).toInstant) - formatter.parse("yesterday CET") should be (today - MICROS_PER_DAY +- tolerance) + .atZone(zoneId) + val yesterday = instantToMicros(localToday.minusDays(1).toInstant) + formatter.parse("yesterday CET") should be (yesterday +- tolerance) + val today = instantToMicros(localToday.toInstant) formatter.parse(" TODAY ") should be (today +- tolerance) - formatter.parse("Tomorrow ") should be (today + MICROS_PER_DAY +- tolerance) + val tomorrow = 
instantToMicros(localToday.plusDays(1).toInstant) + formatter.parse("Tomorrow ") should be (tomorrow +- tolerance) } } } From a43b966f00cf5622aa88c98c1636924d0e24d626 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 27 Oct 2019 16:15:54 -0700 Subject: [PATCH 57/58] [SPARK-29613][BUILD][SS] Upgrade to Kafka 2.3.1 ### What changes were proposed in this pull request? This PR aims to upgrade to Kafka 2.3.1 client library for client fixes like KAFKA-8950, KAFKA-8570, and KAFKA-8635. The following is the full release note. - https://archive.apache.org/dist/kafka/2.3.1/RELEASE_NOTES.html ### Why are the changes needed? - [KAFKA-8950 KafkaConsumer stops fetching](https://issues.apache.org/jira/browse/KAFKA-8950) - [KAFKA-8570 Downconversion could fail when log contains out of order message formats](https://issues.apache.org/jira/browse/KAFKA-8570) - [KAFKA-8635 Unnecessary wait when looking up coordinator before transactional request](https://issues.apache.org/jira/browse/KAFKA-8635) ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Pass the Jenkins with the existing tests. Closes #26271 from dongjoon-hyun/SPARK-29613. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index c42ef5c6626d..f1a7cb3d106f 100644 --- a/pom.xml +++ b/pom.xml @@ -136,7 +136,7 @@ 1.2.1 - 2.3.0 + 2.3.1 10.12.1.1 1.10.1 1.5.6 From b19fd487dfe307542d65391fd7b8410fa4992698 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 28 Oct 2019 11:36:10 +0800 Subject: [PATCH 58/58] [SPARK-29093][PYTHON][ML] Remove automatically generated param setters in _shared_params_code_gen.py ### What changes were proposed in this pull request? Remove automatically generated param setters in _shared_params_code_gen.py ### Why are the changes needed? To keep parity between scala and python ### Does this PR introduce any user-facing change? Yes Add some setters in Python ML XXXModels ### How was this patch tested? unit tests Closes #26232 from huaxingao/spark-29093. Authored-by: Huaxin Gao Signed-off-by: zhengruifeng --- .../ml/feature/QuantileDiscretizer.scala | 5 +- python/pyspark/ml/base.py | 12 + python/pyspark/ml/classification.py | 324 +++++- python/pyspark/ml/clustering.py | 254 ++++- python/pyspark/ml/evaluation.py | 152 ++- python/pyspark/ml/feature.py | 1016 ++++++++++++++++- python/pyspark/ml/fpm.py | 13 + .../ml/param/_shared_params_code_gen.py | 6 - python/pyspark/ml/param/shared.py | 186 --- python/pyspark/ml/recommendation.py | 48 +- python/pyspark/ml/regression.py | 366 +++++- python/pyspark/ml/tests/test_param.py | 12 +- python/pyspark/ml/tuning.py | 36 + 13 files changed, 2141 insertions(+), 289 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index aa4ab5903f71..eb78d8224fc3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType * Params for [[QuantileDiscretizer]]. */ private[feature] trait QuantileDiscretizerBase extends Params - with HasHandleInvalid with HasInputCol with HasOutputCol { + with HasHandleInvalid with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols { /** * Number of buckets (quantiles, or categories) into which data points are grouped. 
Must @@ -129,8 +129,7 @@ private[feature] trait QuantileDiscretizerBase extends Params */ @Since("1.6.0") final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable - with HasInputCols with HasOutputCols { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("quantileDiscretizer")) diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index 82ff81c58d3c..542cb25172ea 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -194,6 +194,18 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer): .. versionadded:: 2.3.0 """ + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @abstractmethod def createTransformFunc(self): """ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d0c821329471..c5cdf35729dd 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -177,7 +177,19 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable >>> df = sc.parallelize([ ... Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), ... Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() - >>> svm = LinearSVC(maxIter=5, regParam=0.01) + >>> svm = LinearSVC() + >>> svm.getMaxIter() + 100 + >>> svm.setMaxIter(5) + LinearSVC... + >>> svm.getMaxIter() + 5 + >>> svm.getRegParam() + 0.0 + >>> svm.setRegParam(0.01) + LinearSVC... + >>> svm.getRegParam() + 0.01 >>> model = svm.fit(df) >>> model.setPredictionCol("newPrediction") LinearSVC... @@ -257,6 +269,62 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearSVCModel(java_model) + @since("2.2.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("2.2.0") + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + return self._set(regParam=value) + + @since("2.2.0") + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + + @since("2.2.0") + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + return self._set(fitIntercept=value) + + @since("2.2.0") + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + return self._set(standardization=value) + + @since("2.2.0") + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + return self._set(threshold=value) + + @since("2.2.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + + @since("2.2.0") + def setAggregationDepth(self, value): + """ + Sets the value of :py:attr:`aggregationDepth`. + """ + return self._set(aggregationDepth=value) + class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable): """ @@ -265,6 +333,13 @@ class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable, .. 
versionadded:: 2.2.0 """ + @since("3.0.0") + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + return self._set(threshold=value) + @property @since("2.2.0") def coefficients(self): @@ -454,7 +529,18 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams, ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)), ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)), ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF() - >>> blor = LogisticRegression(regParam=0.01, weightCol="weight") + >>> blor = LogisticRegression(weightCol="weight") + >>> blor.getRegParam() + 0.0 + >>> blor.setRegParam(0.01) + LogisticRegression... + >>> blor.getRegParam() + 0.01 + >>> blor.setMaxIter(10) + LogisticRegression... + >>> blor.getMaxIter() + 10 + >>> blor.clear(blor.maxIter) >>> blorModel = blor.fit(bdf) >>> blorModel.setFeaturesCol("features") LogisticRegressionModel... @@ -603,6 +689,54 @@ def setUpperBoundsOnIntercepts(self, value): """ return self._set(upperBoundsOnIntercepts=value) + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + return self._set(regParam=value) + + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + return self._set(elasticNetParam=value) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + return self._set(fitIntercept=value) + + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + return self._set(standardization=value) + + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + + def setAggregationDepth(self, value): + """ + Sets the value of :py:attr:`aggregationDepth`. + """ + return self._set(aggregationDepth=value) + class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticRegressionParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary): @@ -1148,6 +1282,27 @@ def setImpurity(self, value): """ return self._set(impurity=value) + @since("1.4.0") + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + return self._set(checkpointInterval=value) + + @since("1.4.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("3.0.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + @inherit_doc class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel, @@ -1366,6 +1521,18 @@ def setFeatureSubsetStrategy(self, value): """ return self._set(featureSubsetStrategy=value) + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. 
+ """ + return self._set(checkpointInterval=value) + class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel, _RandomForestClassifierParams, JavaMLWritable, @@ -1451,6 +1618,10 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams, >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42, ... leafCol="leafId") + >>> gbt.setMaxIter(5) + GBTClassifier... + >>> gbt.getMaxIter() + 5 >>> gbt.getFeatureSubsetStrategy() 'all' >>> model = gbt.fit(td) @@ -1630,6 +1801,34 @@ def setValidationIndicatorCol(self, value): """ return self._set(validationIndicatorCol=value) + @since("1.4.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("1.4.0") + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + return self._set(checkpointInterval=value) + + @since("1.4.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("1.4.0") + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + return self._set(stepSize=value) + class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel, _GBTClassifierParams, JavaMLWritable, JavaMLReadable): @@ -1723,10 +1922,6 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, >>> model = nb.fit(df) >>> model.setFeaturesCol("features") NaiveBayes_... - >>> model.setLabelCol("newLabel") - NaiveBayes_... - >>> model.getLabelCol() - 'newLabel' >>> model.getSmoothing() 1.0 >>> model.pi @@ -1814,6 +2009,12 @@ def setModelType(self, value): """ return self._set(modelType=value) + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + class NaiveBayesModel(JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable, JavaMLReadable): @@ -1906,7 +2107,11 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer ... (1.0, Vectors.dense([0.0, 1.0])), ... (1.0, Vectors.dense([1.0, 0.0])), ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) - >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 2, 2], blockSize=1, seed=123) + >>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], blockSize=1, seed=123) + >>> mlp.setMaxIter(100) + MultilayerPerceptronClassifier... + >>> mlp.getMaxIter() + 100 >>> model = mlp.fit(df) >>> model.setFeaturesCol("features") MultilayerPerceptronClassifier... @@ -2000,6 +2205,31 @@ def setBlockSize(self, value): """ return self._set(blockSize=value) + @since("2.0.0") + def setInitialWeights(self, value): + """ + Sets the value of :py:attr:`initialWeights`. + """ + return self._set(initialWeights=value) + + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + @since("2.0.0") def setStepSize(self, value): """ @@ -2007,12 +2237,11 @@ def setStepSize(self, value): """ return self._set(stepSize=value) - @since("2.0.0") - def setInitialWeights(self, value): + def setSolver(self, value): """ - Sets the value of :py:attr:`initialWeights`. + Sets the value of :py:attr:`solver`. 
""" - return self._set(initialWeights=value) + return self._set(solver=value) class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel, JavaMLWritable, @@ -2134,6 +2363,42 @@ def setClassifier(self, value): """ return self._set(classifier=value) + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + def setRawPredictionCol(self, value): + """ + Sets the value of :py:attr:`rawPredictionCol`. + """ + return self._set(rawPredictionCol=value) + + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + + def setParallelism(self, value): + """ + Sets the value of :py:attr:`parallelism`. + """ + return self._set(parallelism=value) + def _fit(self, dataset): labelCol = self.getLabelCol() featuresCol = self.getFeaturesCol() @@ -2287,6 +2552,43 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable): .. versionadded:: 2.0.0 """ + @since("2.0.0") + def setClassifier(self, value): + """ + Sets the value of :py:attr:`classifier`. + """ + return self._set(classifier=value) + + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + def setRawPredictionCol(self, value): + """ + Sets the value of :py:attr:`rawPredictionCol`. + """ + return self._set(rawPredictionCol=value) + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + def __init__(self, models): super(OneVsRestModel, self).__init__() self.models = models diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index cbbbd36955dc..bb73dc78c4ab 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -123,6 +123,27 @@ class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, Ja .. versionadded:: 2.0.0 """ + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("3.0.0") + def setProbabilityCol(self, value): + """ + Sets the value of :py:attr:`probabilityCol`. + """ + return self._set(probabilityCol=value) + @property @since("2.0.0") def weights(self): @@ -200,8 +221,13 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav ... (Vectors.dense([-0.83, -0.68]),), ... (Vectors.dense([-0.91, -0.76]),)] >>> df = spark.createDataFrame(data, ["features"]) - >>> gm = GaussianMixture(k=3, tol=0.0001, - ... maxIter=10, seed=10) + >>> gm = GaussianMixture(k=3, tol=0.0001, seed=10) + >>> gm.getMaxIter() + 100 + >>> gm.setMaxIter(10) + GaussianMixture... 
+ >>> gm.getMaxIter() + 10 >>> model = gm.fit(df) >>> model.getFeaturesCol() 'features' @@ -290,6 +316,48 @@ def setK(self, value): """ return self._set(k=value) + @since("2.0.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("2.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("2.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("2.0.0") + def setProbabilityCol(self, value): + """ + Sets the value of :py:attr:`probabilityCol`. + """ + return self._set(probabilityCol=value) + + @since("2.0.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("2.0.0") + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + class GaussianMixtureSummary(ClusteringSummary): """ @@ -389,6 +457,20 @@ class KMeansModel(JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadabl .. versionadded:: 1.5.0 """ + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + @since("1.5.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" @@ -425,7 +507,14 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable): >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] >>> df = spark.createDataFrame(data, ["features"]) - >>> kmeans = KMeans(k=2, seed=1) + >>> kmeans = KMeans(k=2) + >>> kmeans.setSeed(1) + KMeans... + >>> kmeans.setMaxIter(10) + KMeans... + >>> kmeans.getMaxIter() + 10 + >>> kmeans.clear(kmeans.maxIter) >>> model = kmeans.fit(df) >>> model.getDistanceMeasure() 'euclidean' @@ -531,6 +620,41 @@ def setDistanceMeasure(self, value): """ return self._set(distanceMeasure=value) + @since("1.5.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("1.5.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("1.5.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("1.5.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("1.5.0") + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + @inherit_doc class _BisectingKMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, @@ -571,6 +695,20 @@ class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, Ja .. versionadded:: 2.0.0 """ + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. 
+ """ + return self._set(predictionCol=value) + @since("2.0.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" @@ -629,6 +767,16 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] >>> df = spark.createDataFrame(data, ["features"]) >>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0) + >>> bkm.setMaxIter(10) + BisectingKMeans... + >>> bkm.getMaxIter() + 10 + >>> bkm.clear(bkm.maxIter) + >>> bkm.setSeed(1) + BisectingKMeans... + >>> bkm.getSeed() + 1 + >>> bkm.clear(bkm.seed) >>> model = bkm.fit(df) >>> model.getMaxIter() 20 @@ -723,6 +871,34 @@ def setDistanceMeasure(self, value): """ return self._set(distanceMeasure=value) + @since("2.0.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("2.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("2.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("2.0.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + def _create_model(self, java_model): return BisectingKMeansModel(java_model) @@ -873,6 +1049,31 @@ class LDAModel(JavaModel, _LDAParams): .. versionadded:: 2.0.0 """ + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("3.0.0") + def setTopicDistributionCol(self, value): + """ + Sets the value of :py:attr:`topicDistributionCol`. + + >>> algo = LDA().setTopicDistributionCol("topicDistributionCol") + >>> algo.getTopicDistributionCol() + 'topicDistributionCol' + """ + return self._set(topicDistributionCol=value) + @since("2.0.0") def isDistributed(self): """ @@ -1045,6 +1246,11 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable): >>> df = spark.createDataFrame([[1, Vectors.dense([0.0, 1.0])], ... [2, SparseVector(2, {0: 1.0})],], ["id", "features"]) >>> lda = LDA(k=2, seed=1, optimizer="em") + >>> lda.setMaxIter(10) + LDA... + >>> lda.getMaxIter() + 10 + >>> lda.clear(lda.maxIter) >>> model = lda.fit(df) >>> model.getTopicDistributionCol() 'topicDistribution' @@ -1125,6 +1331,20 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.0.0") + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + return self._set(checkpointInterval=value) + + @since("2.0.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + @since("2.0.0") def setK(self, value): """ @@ -1236,6 +1456,20 @@ def setKeepLastCheckpoint(self, value): """ return self._set(keepLastCheckpoint=value) + @since("2.0.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("2.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. 
+ """ + return self._set(featuresCol=value) + @inherit_doc class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol): @@ -1392,6 +1626,20 @@ def setDstCol(self, value): """ return self._set(dstCol=value) + @since("2.4.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("2.4.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + @since("2.4.0") def assignClusters(self, dataset): """ diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index cdd9be7bf11b..6539e2abaed1 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -119,7 +119,9 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction ... [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)]) >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"]) ... - >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw") + >>> evaluator = BinaryClassificationEvaluator() + >>> evaluator.setRawPredictionCol("raw") + BinaryClassificationEvaluator... >>> evaluator.evaluate(dataset) 0.70... >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) @@ -196,6 +198,25 @@ def getNumBins(self): """ return self.getOrDefault(self.numBins) + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + def setRawPredictionCol(self, value): + """ + Sets the value of :py:attr:`rawPredictionCol`. + """ + return self._set(rawPredictionCol=value) + + @since("3.0.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + @keyword_only @since("1.4.0") def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", @@ -220,7 +241,9 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh ... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)] >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"]) ... - >>> evaluator = RegressionEvaluator(predictionCol="raw") + >>> evaluator = RegressionEvaluator() + >>> evaluator.setPredictionCol("raw") + RegressionEvaluator... >>> evaluator.evaluate(dataset) 2.842... >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) @@ -299,6 +322,25 @@ def getThroughOrigin(self): """ return self.getOrDefault(self.throughOrigin) + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("3.0.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + @keyword_only @since("1.4.0") def setParams(self, predictionCol="prediction", labelCol="label", @@ -322,7 +364,9 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"]) - >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") + >>> evaluator = MulticlassClassificationEvaluator() + >>> evaluator.setPredictionCol("prediction") + MulticlassClassificationEvaluator... 
>>> evaluator.evaluate(dataset) 0.66... >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"}) @@ -453,6 +497,32 @@ def getEps(self): """ return self.getOrDefault(self.eps) + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("3.0.0") + def setProbabilityCol(self, value): + """ + Sets the value of :py:attr:`probabilityCol`. + """ + return self._set(probabilityCol=value) + + @since("3.0.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + @keyword_only @since("1.5.0") def setParams(self, predictionCol="prediction", labelCol="label", @@ -482,7 +552,9 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])] >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"]) ... - >>> evaluator = MultilabelClassificationEvaluator(predictionCol="prediction") + >>> evaluator = MultilabelClassificationEvaluator() + >>> evaluator.setPredictionCol("prediction") + MultilabelClassificationEvaluator... >>> evaluator.evaluate(dataset) 0.63... >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"}) @@ -549,6 +621,20 @@ def getMetricLabel(self): """ return self.getOrDefault(self.metricLabel) + @since("3.0.0") + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + @keyword_only @since("3.0.0") def setParams(self, predictionCol="prediction", labelCol="label", @@ -581,7 +667,9 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, ... ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)]) >>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) ... - >>> evaluator = ClusteringEvaluator(predictionCol="prediction") + >>> evaluator = ClusteringEvaluator() + >>> evaluator.setPredictionCol("prediction") + ClusteringEvaluator... >>> evaluator.evaluate(dataset) 0.9079... >>> ce_path = temp_path + "/ce" @@ -613,6 +701,18 @@ def __init__(self, predictionCol="prediction", featuresCol="features", kwargs = self._input_kwargs self._set(**kwargs) + @keyword_only + @since("2.3.0") + def setParams(self, predictionCol="prediction", featuresCol="features", + metricName="silhouette", distanceMeasure="squaredEuclidean"): + """ + setParams(self, predictionCol="prediction", featuresCol="features", \ + metricName="silhouette", distanceMeasure="squaredEuclidean") + Sets params for clustering evaluator. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + @since("2.3.0") def setMetricName(self, value): """ @@ -627,18 +727,6 @@ def getMetricName(self): """ return self.getOrDefault(self.metricName) - @keyword_only - @since("2.3.0") - def setParams(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette", distanceMeasure="squaredEuclidean"): - """ - setParams(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette", distanceMeasure="squaredEuclidean") - Sets params for clustering evaluator. 
- """ - kwargs = self._input_kwargs - return self._set(**kwargs) - @since("2.4.0") def setDistanceMeasure(self, value): """ @@ -653,6 +741,18 @@ def getDistanceMeasure(self): """ return self.getOrDefault(self.distanceMeasure) + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + @inherit_doc class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, @@ -669,7 +769,9 @@ class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, ... ([1.0, 2.0, 3.0, 4.0, 5.0], [])] >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"]) ... - >>> evaluator = RankingEvaluator(predictionCol="prediction") + >>> evaluator = RankingEvaluator() + >>> evaluator.setPredictionCol("prediction") + RankingEvaluator... >>> evaluator.evaluate(dataset) 0.35... >>> evaluator.evaluate(dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2}) @@ -734,6 +836,20 @@ def getK(self): """ return self.getOrDefault(self.k) + @since("3.0.0") + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + @keyword_only @since("3.0.0") def setParams(self, predictionCol="prediction", labelCol="label", diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a0883f1d54fe..11bb7941b5d9 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -76,6 +76,12 @@ class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOu >>> df = spark.createDataFrame([(0.5,)], ["values"]) >>> binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features") + >>> binarizer.setThreshold(1.0) + Binarizer... + >>> binarizer.setInputCol("values") + Binarizer... + >>> binarizer.setOutputCol("features") + Binarizer... >>> binarizer.transform(df).head().features 0.0 >>> binarizer.setParams(outputCol="freqs").transform(df).head().freqs @@ -154,6 +160,32 @@ def setThresholds(self, value): """ return self._set(thresholds=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + class _LSHParams(HasInputCol, HasOutputCol): """ @@ -183,12 +215,36 @@ def setNumHashTables(self, value): """ return self._set(numHashTables=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + class _LSHModel(JavaModel, _LSHParams): """ Mixin for Locality Sensitive Hashing (LSH) models. """ + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. 
+ """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def approxNearestNeighbors(self, dataset, key, numNearestNeighbors, distCol="distCol"): """ Given a large dataset and an item, approximately find at most k items which have the @@ -269,8 +325,15 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams, ... (2, Vectors.dense([1.0, -1.0 ]),), ... (3, Vectors.dense([1.0, 1.0]),)] >>> df = spark.createDataFrame(data, ["id", "features"]) - >>> brp = BucketedRandomProjectionLSH(inputCol="features", outputCol="hashes", - ... seed=12345, bucketLength=1.0) + >>> brp = BucketedRandomProjectionLSH() + >>> brp.setInputCol("features") + BucketedRandomProjectionLSH... + >>> brp.setOutputCol("hashes") + BucketedRandomProjectionLSH... + >>> brp.setSeed(12345) + BucketedRandomProjectionLSH... + >>> brp.setBucketLength(1.0) + BucketedRandomProjectionLSH... >>> model = brp.fit(df) >>> model.getBucketLength() 1.0 @@ -350,6 +413,24 @@ def setBucketLength(self, value): """ return self._set(bucketLength=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + def _create_model(self, java_model): return BucketedRandomProjectionLSHModel(java_model) @@ -366,6 +447,20 @@ class BucketedRandomProjectionLSHModel(_LSHModel, _BucketedRandomProjectionLSHPa .. versionadded:: 2.2.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @inherit_doc class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, @@ -380,8 +475,13 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")), ... (float("nan"), 1.0), (float("nan"), 0.0)] >>> df = spark.createDataFrame(values, ["values1", "values2"]) - >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], - ... inputCol="values1", outputCol="buckets") + >>> bucketizer = Bucketizer() + >>> bucketizer.setSplits([-float("inf"), 0.5, 1.4, float("inf")]) + Bucketizer... + >>> bucketizer.setInputCol("values1") + Bucketizer... + >>> bucketizer.setOutputCol("buckets") + Bucketizer... >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect() >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1")) >>> bucketed.show(truncate=False) @@ -510,6 +610,38 @@ def getSplitsArray(self): """ return self.getOrDefault(self.splitsArray) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. 
+ """ + return self._set(outputCols=value) + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): """ @@ -595,7 +727,11 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav >>> df = spark.createDataFrame( ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], ... ["label", "raw"]) - >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") + >>> cv = CountVectorizer() + >>> cv.setInputCol("raw") + CountVectorizer... + >>> cv.setOutputCol("vectors") + CountVectorizer... >>> model = cv.fit(df) >>> model.transform(df).show(truncate=False) +-----+---------------+-------------------------+ @@ -695,6 +831,18 @@ def setBinary(self, value): """ return self._set(binary=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def _create_model(self, java_model): return CountVectorizerModel(java_model) @@ -707,6 +855,34 @@ class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, Ja .. versionadded:: 1.6.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @since("3.0.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + return self._set(minTF=value) + + @since("3.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + return self._set(binary=value) + @classmethod @since("2.4.0") def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): @@ -766,7 +942,13 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit >>> from pyspark.ml.linalg import Vectors >>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) - >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec") + >>> dct = DCT( ) + >>> dct.setInverse(False) + DCT... + >>> dct.setInputCol("vec") + DCT... + >>> dct.setOutputCol("resultVec") + DCT... >>> df2 = dct.transform(df1) >>> df2.head().resultVec DenseVector([10.969..., -0.707..., -2.041...]) @@ -820,6 +1002,18 @@ def getInverse(self): """ return self.getOrDefault(self.inverse) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @inherit_doc class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, @@ -831,8 +1025,13 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"]) - >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]), - ... inputCol="values", outputCol="eprod") + >>> ep = ElementwiseProduct() + >>> ep.setScalingVec(Vectors.dense([1.0, 2.0, 3.0])) + ElementwiseProduct... + >>> ep.setInputCol("values") + ElementwiseProduct... + >>> ep.setOutputCol("eprod") + ElementwiseProduct... 
>>> ep.transform(df).head().eprod DenseVector([2.0, 2.0, 9.0]) >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod @@ -884,6 +1083,18 @@ def getScalingVec(self): """ return self.getOrDefault(self.scalingVec) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @inherit_doc class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable, @@ -923,7 +1134,11 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> data = [(2.0, True, "1", "foo"), (3.0, False, "2", "bar")] >>> cols = ["real", "bool", "stringNum", "string"] >>> df = spark.createDataFrame(data, cols) - >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") + >>> hasher = FeatureHasher() + >>> hasher.setInputCols(cols) + FeatureHasher... + >>> hasher.setOutputCol("features") + FeatureHasher... >>> hasher.transform(df).head().features SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasher.setCategoricalCols(["real"]).transform(df).head().features @@ -978,6 +1193,24 @@ def getCategoricalCols(self): """ return self.getOrDefault(self.categoricalCols) + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setNumFeatures(self, value): + """ + Sets the value of :py:attr:`numFeatures`. + """ + return self._set(numFeatures=value) + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, @@ -991,7 +1224,9 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java otherwise the features will not be mapped evenly to the columns. >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"]) - >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> hashingTF = HashingTF(inputCol="words", outputCol="features") + >>> hashingTF.setNumFeatures(10) + HashingTF... >>> hashingTF.transform(df).head().features SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0}) >>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs @@ -1050,6 +1285,24 @@ def getBinary(self): """ return self.getOrDefault(self.binary) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setNumFeatures(self, value): + """ + Sets the value of :py:attr:`numFeatures`. + """ + return self._set(numFeatures=value) + @since("3.0.0") def indexOf(self, term): """ @@ -1086,7 +1339,11 @@ class IDF(JavaEstimator, _IDFParams, JavaMLReadable, JavaMLWritable): >>> from pyspark.ml.linalg import DenseVector >>> df = spark.createDataFrame([(DenseVector([1.0, 2.0]),), ... (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"]) - >>> idf = IDF(minDocFreq=3, inputCol="tf", outputCol="idf") + >>> idf = IDF(minDocFreq=3) + >>> idf.setInputCol("tf") + IDF... + >>> idf.setOutputCol("idf") + IDF... 
>>> model = idf.fit(df) >>> model.getMinDocFreq() 3 @@ -1145,6 +1402,18 @@ def setMinDocFreq(self, value): """ return self._set(minDocFreq=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def _create_model(self, java_model): return IDFModel(java_model) @@ -1156,6 +1425,20 @@ class IDFModel(JavaModel, _IDFParams, JavaMLReadable, JavaMLWritable): .. versionadded:: 1.4.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @property @since("2.0.0") def idf(self): @@ -1228,7 +1511,11 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable): >>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0), ... (4.0, 4.0), (5.0, 5.0)], ["a", "b"]) - >>> imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]) + >>> imputer = Imputer() + >>> imputer.setInputCols(["a", "b"]) + Imputer... + >>> imputer.setOutputCols(["out_a", "out_b"]) + Imputer... >>> model = imputer.fit(df) >>> model.getStrategy() 'mean' @@ -1308,6 +1595,20 @@ def setMissingValue(self, value): """ return self._set(missingValue=value) + @since("2.2.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + @since("2.2.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + def _create_model(self, java_model): return ImputerModel(java_model) @@ -1319,6 +1620,20 @@ class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable, JavaMLWritable): .. versionadded:: 2.2.0 """ + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + @property @since("2.2.0") def surrogateDF(self): @@ -1342,7 +1657,11 @@ class Interaction(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, J with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. >>> df = spark.createDataFrame([(0.0, 1.0), (2.0, 3.0)], ["a", "b"]) - >>> interaction = Interaction(inputCols=["a", "b"], outputCol="ab") + >>> interaction = Interaction() + >>> interaction.setInputCols(["a", "b"]) + Interaction... + >>> interaction.setOutputCol("ab") + Interaction... >>> interaction.transform(df).show() +---+---+-----+ | a| b| ab| @@ -1381,6 +1700,20 @@ def setParams(self, inputCols=None, outputCol=None): kwargs = self._input_kwargs return self._set(**kwargs) + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. 
+ """ + return self._set(outputCol=value) + class _MaxAbsScalerParams(HasInputCol, HasOutputCol): """ @@ -1400,7 +1733,9 @@ class MaxAbsScaler(JavaEstimator, _MaxAbsScalerParams, JavaMLReadable, JavaMLWri >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"]) - >>> maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled") + >>> maScaler = MaxAbsScaler(outputCol="scaled") + >>> maScaler.setInputCol("a") + MaxAbsScaler... >>> model = maScaler.fit(df) >>> model.setOutputCol("scaledOutput") MaxAbsScaler... @@ -1449,6 +1784,18 @@ def setParams(self, inputCol=None, outputCol=None): kwargs = self._input_kwargs return self._set(**kwargs) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def _create_model(self, java_model): return MaxAbsScalerModel(java_model) @@ -1460,6 +1807,20 @@ class MaxAbsScalerModel(JavaModel, _MaxAbsScalerParams, JavaMLReadable, JavaMLWr .. versionadded:: 2.0.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @property @since("2.0.0") def maxAbs(self): @@ -1487,7 +1848,13 @@ class MinHashLSH(_LSH, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, JavaM ... (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0]),), ... (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]),)] >>> df = spark.createDataFrame(data, ["id", "features"]) - >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345) + >>> mh = MinHashLSH() + >>> mh.setInputCol("features") + MinHashLSH... + >>> mh.setOutputCol("hashes") + MinHashLSH... + >>> mh.setSeed(12345) + MinHashLSH... >>> model = mh.fit(df) >>> model.transform(df).head() Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668... @@ -1544,6 +1911,12 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1): kwargs = self._input_kwargs return self._set(**kwargs) + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + def _create_model(self, java_model): return MinHashLSHModel(java_model) @@ -1606,7 +1979,9 @@ class MinMaxScaler(JavaEstimator, _MinMaxScalerParams, JavaMLReadable, JavaMLWri >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) - >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") + >>> mmScaler = MinMaxScaler(outputCol="scaled") + >>> mmScaler.setInputCol("a") + MinMaxScaler... >>> model = mmScaler.fit(df) >>> model.setOutputCol("scaledOutput") MinMaxScaler... @@ -1675,6 +2050,18 @@ def setMax(self, value): """ return self._set(max=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def _create_model(self, java_model): return MinMaxScalerModel(java_model) @@ -1686,6 +2073,34 @@ class MinMaxScalerModel(JavaModel, _MinMaxScalerParams, JavaMLReadable, JavaMLWr .. 
versionadded:: 1.6.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @since("3.0.0") + def setMin(self, value): + """ + Sets the value of :py:attr:`min`. + """ + return self._set(min=value) + + @since("3.0.0") + def setMax(self, value): + """ + Sets the value of :py:attr:`max`. + """ + return self._set(max=value) + @property @since("2.0.0") def originalMin(self): @@ -1716,7 +2131,11 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr returned. >>> df = spark.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) - >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") + >>> ngram = NGram(n=2) + >>> ngram.setInputCol("inputTokens") + NGram... + >>> ngram.setOutputCol("nGrams") + NGram... >>> ngram.transform(df).head() Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) >>> # Change n-gram length @@ -1779,6 +2198,18 @@ def getN(self): """ return self.getOrDefault(self.n) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): @@ -1788,7 +2219,11 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav >>> from pyspark.ml.linalg import Vectors >>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0}) >>> df = spark.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"]) - >>> normalizer = Normalizer(p=2.0, inputCol="dense", outputCol="features") + >>> normalizer = Normalizer(p=2.0) + >>> normalizer.setInputCol("dense") + Normalizer... + >>> normalizer.setOutputCol("features") + Normalizer... >>> normalizer.transform(df).head().features DenseVector([0.6, -0.8]) >>> normalizer.setParams(inputCol="sparse", outputCol="freqs").transform(df).head().freqs @@ -1843,6 +2278,18 @@ def getP(self): """ return self.getOrDefault(self.p) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + class _OneHotEncoderParams(HasInputCols, HasOutputCols, HasHandleInvalid): """ @@ -1895,7 +2342,11 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) - >>> ohe = OneHotEncoder(inputCols=["input"], outputCols=["output"]) + >>> ohe = OneHotEncoder() + >>> ohe.setInputCols(["input"]) + OneHotEncoder... + >>> ohe.setOutputCols(["output"]) + OneHotEncoder... >>> model = ohe.fit(df) >>> model.getHandleInvalid() 'error' @@ -1944,6 +2395,27 @@ def setDropLast(self, value): """ return self._set(dropLast=value) + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. 
+ """ + return self._set(outputCols=value) + + @since("3.0.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + def _create_model(self, java_model): return OneHotEncoderModel(java_model) @@ -1955,6 +2427,34 @@ class OneHotEncoderModel(JavaModel, _OneHotEncoderParams, JavaMLReadable, JavaML .. versionadded:: 2.3.0 """ + @since("3.0.0") + def setDropLast(self, value): + """ + Sets the value of :py:attr:`dropLast`. + """ + return self._set(dropLast=value) + + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + @since("3.0.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + @property @since("2.3.0") def categorySizes(self): @@ -1977,7 +2477,11 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"]) - >>> px = PolynomialExpansion(degree=2, inputCol="dense", outputCol="expanded") + >>> px = PolynomialExpansion(degree=2) + >>> px.setInputCol("dense") + PolynomialExpansion... + >>> px.setOutputCol("expanded") + PolynomialExpansion... >>> px.transform(df).head().expanded DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) >>> px.setParams(outputCol="test").transform(df).head().test @@ -2030,6 +2534,18 @@ def getDegree(self): """ return self.getOrDefault(self.degree) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @inherit_doc class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, @@ -2060,8 +2576,13 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)] >>> df1 = spark.createDataFrame(values, ["values"]) - >>> qds1 = QuantileDiscretizer(numBuckets=2, - ... inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error") + >>> qds1 = QuantileDiscretizer(inputCol="values", outputCol="buckets") + >>> qds1.setNumBuckets(2) + QuantileDiscretizer... + >>> qds1.setRelativeError(0.01) + QuantileDiscretizer... + >>> qds1.setHandleInvalid("error") + QuantileDiscretizer... >>> qds1.getRelativeError() 0.01 >>> bucketizer = qds1.fit(df1) @@ -2213,6 +2734,38 @@ def getRelativeError(self): """ return self.getOrDefault(self.relativeError) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. 
+ """ + return self._set(handleInvalid=value) + def _create_model(self, java_model): """ Private method to convert the java_model to a Python model. @@ -2292,7 +2845,11 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri ... (3, Vectors.dense([3.0, -3.0]),), ... (4, Vectors.dense([4.0, -4.0]),),] >>> df = spark.createDataFrame(data, ["id", "features"]) - >>> scaler = RobustScaler(inputCol="features", outputCol="scaled") + >>> scaler = RobustScaler() + >>> scaler.setInputCol("features") + RobustScaler... + >>> scaler.setOutputCol("scaled") + RobustScaler... >>> model = scaler.fit(df) >>> model.setOutputCol("output") RobustScaler... @@ -2373,6 +2930,20 @@ def setWithScaling(self, value): """ return self._set(withScaling=value) + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def _create_model(self, java_model): return RobustScalerModel(java_model) @@ -2384,6 +2955,20 @@ class RobustScalerModel(JavaModel, _RobustScalerParams, JavaMLReadable, JavaMLWr .. versionadded:: 3.0.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @property @since("3.0.0") def median(self): @@ -2413,7 +2998,11 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, It returns an array of strings that can be empty. >>> df = spark.createDataFrame([("A B c",)], ["text"]) - >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words") + >>> reTokenizer = RegexTokenizer() + >>> reTokenizer.setInputCol("text") + RegexTokenizer... + >>> reTokenizer.setOutputCol("words") + RegexTokenizer... >>> reTokenizer.transform(df).head() Row(text=u'A B c', words=[u'a', u'b', u'c']) >>> # Change a parameter. @@ -2530,6 +3119,18 @@ def getToLowercase(self): """ return self.getOrDefault(self.toLowercase) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @inherit_doc class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): @@ -2629,7 +3230,11 @@ class StandardScaler(JavaEstimator, _StandardScalerParams, JavaMLReadable, JavaM >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) - >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled") + >>> standardScaler = StandardScaler() + >>> standardScaler.setInputCol("a") + StandardScaler... + >>> standardScaler.setOutputCol("scaled") + StandardScaler... >>> model = standardScaler.fit(df) >>> model.getInputCol() 'a' @@ -2694,6 +3299,18 @@ def setWithStd(self, value): """ return self._set(withStd=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. 
+ """ + return self._set(outputCol=value) + def _create_model(self, java_model): return StandardScalerModel(java_model) @@ -2705,6 +3322,18 @@ class StandardScalerModel(JavaModel, _StandardScalerParams, JavaMLReadable, Java .. versionadded:: 1.4.0 """ + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @property @since("2.0.0") def std(self): @@ -2765,8 +3394,10 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW so the most frequent label gets index 0. The ordering behavior is controlled by setting :py:attr:`stringOrderType`. Its default value is 'frequencyDesc'. - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error", + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", ... stringOrderType="frequencyDesc") + >>> stringIndexer.setHandleInvalid("error") + StringIndexer... >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), @@ -2866,6 +3497,38 @@ def setStringOrderType(self, value): """ return self._set(stringOrderType=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ @@ -2874,6 +3537,39 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML .. versionadded:: 1.4.0 """ + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + @since("3.0.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + @since("2.4.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + @classmethod @since("2.4.0") def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None): @@ -2921,13 +3617,6 @@ def labels(self): """ return self._call_java("labels") - @since("2.4.0") - def setHandleInvalid(self, value): - """ - Sets the value of :py:attr:`handleInvalid`. - """ - return self._set(handleInvalid=value) - @inherit_doc class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): @@ -2981,6 +3670,18 @@ def getLabels(self): """ return self.getOrDefault(self.labels) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. 
+ """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ @@ -2989,7 +3690,11 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl .. note:: null values from input array are preserved unless adding null to stopWords explicitly. >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"]) - >>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"]) + >>> remover = StopWordsRemover(stopWords=["b"]) + >>> remover.setInputCol("text") + StopWordsRemover... + >>> remover.setOutputCol("words") + StopWordsRemover... >>> remover.transform(df).head().words == ['a', 'c'] True >>> stopWordsRemoverPath = temp_path + "/stopwords-remover" @@ -3079,6 +3784,18 @@ def getLocale(self): """ return self.getOrDefault(self.locale) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @staticmethod @since("2.0.0") def loadDefaultStopWords(language): @@ -3099,7 +3816,9 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java splits it by white spaces. >>> df = spark.createDataFrame([("a b c",)], ["text"]) - >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") + >>> tokenizer = Tokenizer(outputCol="words") + >>> tokenizer.setInputCol("text") + Tokenizer... >>> tokenizer.transform(df).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) >>> # Change a parameter. @@ -3144,6 +3863,18 @@ def setParams(self, inputCol=None, outputCol=None): kwargs = self._input_kwargs return self._set(**kwargs) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @inherit_doc class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, @@ -3152,7 +3883,9 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInva A feature transformer that merges multiple columns into a vector column. >>> df = spark.createDataFrame([(1, 0, 3)], ["a", "b", "c"]) - >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features") + >>> vecAssembler = VectorAssembler(outputCol="features") + >>> vecAssembler.setInputCols(["a", "b", "c"]) + VectorAssembler... >>> vecAssembler.transform(df).head().features DenseVector([1.0, 0.0, 3.0]) >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs @@ -3220,6 +3953,24 @@ def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"): kwargs = self._input_kwargs return self._set(**kwargs) + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. 
+ """ + return self._set(handleInvalid=value) + class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid): """ @@ -3288,7 +4039,9 @@ class VectorIndexer(JavaEstimator, _VectorIndexerParams, JavaMLReadable, JavaMLW >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),), ... (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"]) - >>> indexer = VectorIndexer(maxCategories=2, inputCol="a", outputCol="indexed") + >>> indexer = VectorIndexer(maxCategories=2, inputCol="a") + >>> indexer.setOutputCol("indexed") + VectorIndexer... >>> model = indexer.fit(df) >>> indexer.getHandleInvalid() 'error' @@ -3359,6 +4112,24 @@ def setMaxCategories(self, value): """ return self._set(maxCategories=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + def _create_model(self, java_model): return VectorIndexerModel(java_model) @@ -3380,6 +4151,20 @@ class VectorIndexerModel(JavaModel, _VectorIndexerParams, JavaMLReadable, JavaML .. versionadded:: 1.4.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @property @since("1.4.0") def numFeatures(self): @@ -3417,7 +4202,9 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J ... (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),), ... (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),), ... (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"]) - >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4]) + >>> vs = VectorSlicer(outputCol="sliced", indices=[1, 4]) + >>> vs.setInputCol("features") + VectorSlicer... >>> vs.transform(df).head().sliced DenseVector([2.3, 1.0]) >>> vectorSlicerPath = temp_path + "/vector-slicer" @@ -3488,6 +4275,18 @@ def getNames(self): """ return self.getOrDefault(self.names) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): """ @@ -3560,6 +4359,11 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable): >>> sent = ("a b " * 100 + "a c " * 10).split(" ") >>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"]) >>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model") + >>> word2Vec.setMaxIter(10) + Word2Vec... + >>> word2Vec.getMaxIter() + 10 + >>> word2Vec.clear(word2Vec.maxIter) >>> model = word2Vec.fit(doc) >>> model.getMinCount() 5 @@ -3666,12 +4470,36 @@ def setMaxSentenceLength(self, value): """ return self._set(maxSentenceLength=value) - @since("2.0.0") - def getMaxSentenceLength(self): + def setMaxIter(self, value): """ - Gets the value of maxSentenceLength or its default value. + Sets the value of :py:attr:`maxIter`. 
""" - return self.getOrDefault(self.maxSentenceLength) + return self._set(maxIter=value) + + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("1.4.0") + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + return self._set(stepSize=value) def _create_model(self, java_model): return Word2VecModel(java_model) @@ -3692,6 +4520,18 @@ def getVectors(self): """ return self._call_java("getVectors") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @since("1.5.0") def findSynonyms(self, word, num): """ @@ -3747,7 +4587,9 @@ class PCA(JavaEstimator, _PCAParams, JavaMLReadable, JavaMLWritable): ... (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), ... (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] >>> df = spark.createDataFrame(data,["features"]) - >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features") + >>> pca = PCA(k=2, inputCol="features") + >>> pca.setOutputCol("pca_features") + PCA... >>> model = pca.fit(df) >>> model.getK() 2 @@ -3800,6 +4642,18 @@ def setK(self, value): """ return self._set(k=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def _create_model(self, java_model): return PCAModel(java_model) @@ -3811,6 +4665,20 @@ class PCAModel(JavaModel, _PCAParams, JavaMLReadable, JavaMLWritable): .. versionadded:: 1.5.0 """ + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @property @since("2.0.0") def pc(self): @@ -4001,6 +4869,24 @@ def setStringIndexerOrderType(self, value): """ return self._set(stringIndexerOrderType=value) + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + def _create_model(self, java_model): return RFormulaModel(java_model) @@ -4228,6 +5114,24 @@ def setFwe(self, value): """ return self._set(fwe=value) + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + def _create_model(self, java_model): return ChiSqSelectorModel(java_model) @@ -4239,6 +5143,20 @@ class ChiSqSelectorModel(JavaModel, _ChiSqSelectorParams, JavaMLReadable, JavaML .. 
versionadded:: 2.0.0 """ + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + @property @since("2.0.0") def selectedFeatures(self): @@ -4323,6 +5241,18 @@ def setSize(self, value): """ Sets size param, the size of vectors in `inputCol`.""" return self._set(size=value) + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 652acbb34a90..5b34d555484d 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -102,6 +102,13 @@ def setMinConfidence(self, value): """ return self._set(minConfidence=value) + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + @property @since("2.2.0") def freqItemsets(self): @@ -239,6 +246,12 @@ def setMinConfidence(self, value): """ return self._set(minConfidence=value) + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + def _create_model(self, java_model): return FPGrowthModel(java_model) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index c99ec3f467ac..8ea94e476000 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -81,12 +81,6 @@ def _gen_param_code(name, doc, defaultValueStr): """ # TODO: How to correctly inherit instance attributes? template = ''' - def set$Name(self, value): - """ - Sets the value of :py:attr:`$name`. - """ - return self._set($name=value) - def get$Name(self): """ Gets the value of $name or its default value. diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 771b4bcd9ba0..26d74fab6975 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -30,12 +30,6 @@ class HasMaxIter(Params): def __init__(self): super(HasMaxIter, self).__init__() - def setMaxIter(self, value): - """ - Sets the value of :py:attr:`maxIter`. - """ - return self._set(maxIter=value) - def getMaxIter(self): """ Gets the value of maxIter or its default value. @@ -53,12 +47,6 @@ class HasRegParam(Params): def __init__(self): super(HasRegParam, self).__init__() - def setRegParam(self, value): - """ - Sets the value of :py:attr:`regParam`. - """ - return self._set(regParam=value) - def getRegParam(self): """ Gets the value of regParam or its default value. @@ -77,12 +65,6 @@ def __init__(self): super(HasFeaturesCol, self).__init__() self._setDefault(featuresCol='features') - def setFeaturesCol(self, value): - """ - Sets the value of :py:attr:`featuresCol`. - """ - return self._set(featuresCol=value) - def getFeaturesCol(self): """ Gets the value of featuresCol or its default value. @@ -101,12 +83,6 @@ def __init__(self): super(HasLabelCol, self).__init__() self._setDefault(labelCol='label') - def setLabelCol(self, value): - """ - Sets the value of :py:attr:`labelCol`. 
- """ - return self._set(labelCol=value) - def getLabelCol(self): """ Gets the value of labelCol or its default value. @@ -125,12 +101,6 @@ def __init__(self): super(HasPredictionCol, self).__init__() self._setDefault(predictionCol='prediction') - def setPredictionCol(self, value): - """ - Sets the value of :py:attr:`predictionCol`. - """ - return self._set(predictionCol=value) - def getPredictionCol(self): """ Gets the value of predictionCol or its default value. @@ -149,12 +119,6 @@ def __init__(self): super(HasProbabilityCol, self).__init__() self._setDefault(probabilityCol='probability') - def setProbabilityCol(self, value): - """ - Sets the value of :py:attr:`probabilityCol`. - """ - return self._set(probabilityCol=value) - def getProbabilityCol(self): """ Gets the value of probabilityCol or its default value. @@ -173,12 +137,6 @@ def __init__(self): super(HasRawPredictionCol, self).__init__() self._setDefault(rawPredictionCol='rawPrediction') - def setRawPredictionCol(self, value): - """ - Sets the value of :py:attr:`rawPredictionCol`. - """ - return self._set(rawPredictionCol=value) - def getRawPredictionCol(self): """ Gets the value of rawPredictionCol or its default value. @@ -196,12 +154,6 @@ class HasInputCol(Params): def __init__(self): super(HasInputCol, self).__init__() - def setInputCol(self, value): - """ - Sets the value of :py:attr:`inputCol`. - """ - return self._set(inputCol=value) - def getInputCol(self): """ Gets the value of inputCol or its default value. @@ -219,12 +171,6 @@ class HasInputCols(Params): def __init__(self): super(HasInputCols, self).__init__() - def setInputCols(self, value): - """ - Sets the value of :py:attr:`inputCols`. - """ - return self._set(inputCols=value) - def getInputCols(self): """ Gets the value of inputCols or its default value. @@ -243,12 +189,6 @@ def __init__(self): super(HasOutputCol, self).__init__() self._setDefault(outputCol=self.uid + '__output') - def setOutputCol(self, value): - """ - Sets the value of :py:attr:`outputCol`. - """ - return self._set(outputCol=value) - def getOutputCol(self): """ Gets the value of outputCol or its default value. @@ -266,12 +206,6 @@ class HasOutputCols(Params): def __init__(self): super(HasOutputCols, self).__init__() - def setOutputCols(self, value): - """ - Sets the value of :py:attr:`outputCols`. - """ - return self._set(outputCols=value) - def getOutputCols(self): """ Gets the value of outputCols or its default value. @@ -290,12 +224,6 @@ def __init__(self): super(HasNumFeatures, self).__init__() self._setDefault(numFeatures=262144) - def setNumFeatures(self, value): - """ - Sets the value of :py:attr:`numFeatures`. - """ - return self._set(numFeatures=value) - def getNumFeatures(self): """ Gets the value of numFeatures or its default value. @@ -313,12 +241,6 @@ class HasCheckpointInterval(Params): def __init__(self): super(HasCheckpointInterval, self).__init__() - def setCheckpointInterval(self, value): - """ - Sets the value of :py:attr:`checkpointInterval`. - """ - return self._set(checkpointInterval=value) - def getCheckpointInterval(self): """ Gets the value of checkpointInterval or its default value. @@ -337,12 +259,6 @@ def __init__(self): super(HasSeed, self).__init__() self._setDefault(seed=hash(type(self).__name__)) - def setSeed(self, value): - """ - Sets the value of :py:attr:`seed`. - """ - return self._set(seed=value) - def getSeed(self): """ Gets the value of seed or its default value. 
@@ -360,12 +276,6 @@ class HasTol(Params): def __init__(self): super(HasTol, self).__init__() - def setTol(self, value): - """ - Sets the value of :py:attr:`tol`. - """ - return self._set(tol=value) - def getTol(self): """ Gets the value of tol or its default value. @@ -383,12 +293,6 @@ class HasStepSize(Params): def __init__(self): super(HasStepSize, self).__init__() - def setStepSize(self, value): - """ - Sets the value of :py:attr:`stepSize`. - """ - return self._set(stepSize=value) - def getStepSize(self): """ Gets the value of stepSize or its default value. @@ -406,12 +310,6 @@ class HasHandleInvalid(Params): def __init__(self): super(HasHandleInvalid, self).__init__() - def setHandleInvalid(self, value): - """ - Sets the value of :py:attr:`handleInvalid`. - """ - return self._set(handleInvalid=value) - def getHandleInvalid(self): """ Gets the value of handleInvalid or its default value. @@ -430,12 +328,6 @@ def __init__(self): super(HasElasticNetParam, self).__init__() self._setDefault(elasticNetParam=0.0) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - return self._set(elasticNetParam=value) - def getElasticNetParam(self): """ Gets the value of elasticNetParam or its default value. @@ -454,12 +346,6 @@ def __init__(self): super(HasFitIntercept, self).__init__() self._setDefault(fitIntercept=True) - def setFitIntercept(self, value): - """ - Sets the value of :py:attr:`fitIntercept`. - """ - return self._set(fitIntercept=value) - def getFitIntercept(self): """ Gets the value of fitIntercept or its default value. @@ -478,12 +364,6 @@ def __init__(self): super(HasStandardization, self).__init__() self._setDefault(standardization=True) - def setStandardization(self, value): - """ - Sets the value of :py:attr:`standardization`. - """ - return self._set(standardization=value) - def getStandardization(self): """ Gets the value of standardization or its default value. @@ -501,12 +381,6 @@ class HasThresholds(Params): def __init__(self): super(HasThresholds, self).__init__() - def setThresholds(self, value): - """ - Sets the value of :py:attr:`thresholds`. - """ - return self._set(thresholds=value) - def getThresholds(self): """ Gets the value of thresholds or its default value. @@ -525,12 +399,6 @@ def __init__(self): super(HasThreshold, self).__init__() self._setDefault(threshold=0.5) - def setThreshold(self, value): - """ - Sets the value of :py:attr:`threshold`. - """ - return self._set(threshold=value) - def getThreshold(self): """ Gets the value of threshold or its default value. @@ -548,12 +416,6 @@ class HasWeightCol(Params): def __init__(self): super(HasWeightCol, self).__init__() - def setWeightCol(self, value): - """ - Sets the value of :py:attr:`weightCol`. - """ - return self._set(weightCol=value) - def getWeightCol(self): """ Gets the value of weightCol or its default value. @@ -572,12 +434,6 @@ def __init__(self): super(HasSolver, self).__init__() self._setDefault(solver='auto') - def setSolver(self, value): - """ - Sets the value of :py:attr:`solver`. - """ - return self._set(solver=value) - def getSolver(self): """ Gets the value of solver or its default value. @@ -595,12 +451,6 @@ class HasVarianceCol(Params): def __init__(self): super(HasVarianceCol, self).__init__() - def setVarianceCol(self, value): - """ - Sets the value of :py:attr:`varianceCol`. - """ - return self._set(varianceCol=value) - def getVarianceCol(self): """ Gets the value of varianceCol or its default value. 
@@ -619,12 +469,6 @@ def __init__(self): super(HasAggregationDepth, self).__init__() self._setDefault(aggregationDepth=2) - def setAggregationDepth(self, value): - """ - Sets the value of :py:attr:`aggregationDepth`. - """ - return self._set(aggregationDepth=value) - def getAggregationDepth(self): """ Gets the value of aggregationDepth or its default value. @@ -643,12 +487,6 @@ def __init__(self): super(HasParallelism, self).__init__() self._setDefault(parallelism=1) - def setParallelism(self, value): - """ - Sets the value of :py:attr:`parallelism`. - """ - return self._set(parallelism=value) - def getParallelism(self): """ Gets the value of parallelism or its default value. @@ -667,12 +505,6 @@ def __init__(self): super(HasCollectSubModels, self).__init__() self._setDefault(collectSubModels=False) - def setCollectSubModels(self, value): - """ - Sets the value of :py:attr:`collectSubModels`. - """ - return self._set(collectSubModels=value) - def getCollectSubModels(self): """ Gets the value of collectSubModels or its default value. @@ -690,12 +522,6 @@ class HasLoss(Params): def __init__(self): super(HasLoss, self).__init__() - def setLoss(self, value): - """ - Sets the value of :py:attr:`loss`. - """ - return self._set(loss=value) - def getLoss(self): """ Gets the value of loss or its default value. @@ -714,12 +540,6 @@ def __init__(self): super(HasDistanceMeasure, self).__init__() self._setDefault(distanceMeasure='euclidean') - def setDistanceMeasure(self, value): - """ - Sets the value of :py:attr:`distanceMeasure`. - """ - return self._set(distanceMeasure=value) - def getDistanceMeasure(self): """ Gets the value of distanceMeasure or its default value. @@ -737,12 +557,6 @@ class HasValidationIndicatorCol(Params): def __init__(self): super(HasValidationIndicatorCol, self).__init__() - def setValidationIndicatorCol(self, value): - """ - Sets the value of :py:attr:`validationIndicatorCol`. - """ - return self._set(validationIndicatorCol=value) - def getValidationIndicatorCol(self): """ Gets the value of validationIndicatorCol or its default value. diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index df9c765457ec..3ebd0ac2765f 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -212,7 +212,16 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable): >>> df = spark.createDataFrame( ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], ... ["user", "item", "rating"]) - >>> als = ALS(rank=10, maxIter=5, seed=0) + >>> als = ALS(rank=10, seed=0) + >>> als.setMaxIter(5) + ALS... + >>> als.getMaxIter() + 5 + >>> als.setRegParam(0.1) + ALS... + >>> als.getRegParam() + 0.1 + >>> als.clear(als.regParam) >>> model = als.fit(df) >>> model.getUserCol() 'user' @@ -402,6 +411,36 @@ def setColdStartStrategy(self, value): """ return self._set(coldStartStrategy=value) + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + return self._set(regParam=value) + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + return self._set(checkpointInterval=value) + + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. 
+ """ + return self._set(seed=value) + class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable): """ @@ -431,6 +470,13 @@ def setColdStartStrategy(self, value): """ return self._set(coldStartStrategy=value) + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + @property @since("1.4.0") def rank(self): diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 147ebed1d633..08e68d8bc304 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -92,7 +92,17 @@ class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, J >>> df = spark.createDataFrame([ ... (1.0, 2.0, Vectors.dense(1.0)), ... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") + >>> lr = LinearRegression(regParam=0.0, solver="normal", weightCol="weight") + >>> lr.setMaxIter(5) + LinearRegression... + >>> lr.getMaxIter() + 5 + >>> lr.setRegParam(0.1) + LinearRegression... + >>> lr.getRegParam() + 0.1 + >>> lr.setRegParam(0.0) + LinearRegression... >>> model = lr.fit(df) >>> model.setFeaturesCol("features") LinearRegression... @@ -179,6 +189,66 @@ def setEpsilon(self, value): """ return self._set(epsilon=value) + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + return self._set(regParam=value) + + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + return self._set(elasticNetParam=value) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + return self._set(fitIntercept=value) + + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + return self._set(standardization=value) + + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + + def setSolver(self, value): + """ + Sets the value of :py:attr:`solver`. + """ + return self._set(solver=value) + + def setAggregationDepth(self, value): + """ + Sets the value of :py:attr:`aggregationDepth`. + """ + return self._set(aggregationDepth=value) + + def setLoss(self, value): + """ + Sets the value of :py:attr:`loss`. + """ + return self._set(lossType=value) + class LinearRegressionModel(JavaPredictionModel, _LinearRegressionParams, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary): @@ -522,10 +592,6 @@ class IsotonicRegression(JavaEstimator, _IsotonicRegressionParams, HasWeightCol, >>> model = ir.fit(df) >>> model.setFeaturesCol("features") IsotonicRegression... - >>> model.setLabelCol("newLabel") - IsotonicRegression... - >>> model.getLabelCol() - 'newLabel' >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -586,6 +652,34 @@ def setFeatureIndex(self, value): """ return self._set(featureIndex=value) + @since("1.6.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("1.6.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. 
+ """ + return self._set(predictionCol=value) + + @since("1.6.0") + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + @since("1.6.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritable, JavaMLReadable): @@ -595,6 +689,26 @@ class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritab .. versionadded:: 1.6.0 """ + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + def setFeatureIndex(self, value): + """ + Sets the value of :py:attr:`featureIndex`. + """ + return self._set(featureIndex=value) + @property @since("1.6.0") def boundaries(self): @@ -635,7 +749,9 @@ class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLW >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance") + >>> dt = DecisionTreeRegressor(maxDepth=2) + >>> dt.setVarianceCol("variance") + DecisionTreeRegressor... >>> model = dt.fit(df) >>> model.getVarianceCol() 'variance' @@ -732,18 +848,21 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) + @since("1.4.0") def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ return self._set(maxDepth=value) + @since("1.4.0") def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ return self._set(maxBins=value) + @since("1.4.0") def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. @@ -757,18 +876,21 @@ def setMinWeightFractionPerNode(self, value): """ return self._set(minWeightFractionPerNode=value) + @since("1.4.0") def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. """ return self._set(minInfoGain=value) + @since("1.4.0") def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ return self._set(maxMemoryInMB=value) + @since("1.4.0") def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. @@ -782,6 +904,34 @@ def setImpurity(self, value): """ return self._set(impurity=value) + @since("1.4.0") + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + return self._set(checkpointInterval=value) + + @since("1.4.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("3.0.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + + @since("2.0.0") + def setVarianceCol(self, value): + """ + Sets the value of :py:attr:`varianceCol`. + """ + return self._set(varianceCol=value) + @inherit_doc class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams, @@ -792,6 +942,13 @@ class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorPara .. 
versionadded:: 1.4.0 """ + @since("3.0.0") + def setVarianceCol(self, value): + """ + Sets the value of :py:attr:`varianceCol`. + """ + return self._set(varianceCol=value) + @property @since("2.0.0") def featureImportances(self): @@ -836,7 +993,9 @@ class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLW >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42) + >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2) + >>> rf.setSeed(42) + RandomForestRegressor... >>> model = rf.fit(df) >>> model.getSeed() 42 @@ -987,6 +1146,18 @@ def setFeatureSubsetStrategy(self, value): """ return self._set(featureSubsetStrategy=value) + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + return self._set(checkpointInterval=value) + + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams, JavaMLWritable, JavaMLReadable): @@ -1052,7 +1223,11 @@ class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLRea >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42, leafCol="leafId") + >>> gbt = GBTRegressor(maxDepth=2, seed=42, leafCol="leafId") + >>> gbt.setMaxIter(5) + GBTRegressor... + >>> gbt.getMaxIter() + 5 >>> print(gbt.getImpurity()) variance >>> print(gbt.getFeatureSubsetStrategy()) @@ -1152,36 +1327,42 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTRegressionModel(java_model) + @since("1.4.0") def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ return self._set(maxDepth=value) + @since("1.4.0") def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ return self._set(maxBins=value) + @since("1.4.0") def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. """ return self._set(minInstancesPerNode=value) + @since("1.4.0") def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. """ return self._set(minInfoGain=value) + @since("1.4.0") def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ return self._set(maxMemoryInMB=value) + @since("1.4.0") def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. @@ -1223,6 +1404,34 @@ def setValidationIndicatorCol(self, value): """ return self._set(validationIndicatorCol=value) + @since("1.4.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("1.4.0") + def setCheckpointInterval(self, value): + """ + Sets the value of :py:attr:`checkpointInterval`. + """ + return self._set(checkpointInterval=value) + + @since("1.4.0") + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + @since("1.4.0") + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + return self._set(stepSize=value) + class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable): """ @@ -1330,6 +1539,11 @@ class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams, ... 
(1.0, Vectors.dense(1.0), 1.0), ... (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"]) >>> aftsr = AFTSurvivalRegression() + >>> aftsr.setMaxIter(10) + AFTSurvivalRegression... + >>> aftsr.getMaxIter() + 10 + >>> aftsr.clear(aftsr.maxIter) >>> model = aftsr.fit(df) >>> model.setFeaturesCol("features") AFTSurvivalRegression... @@ -1422,6 +1636,55 @@ def setQuantilesCol(self, value): """ return self._set(quantilesCol=value) + @since("1.6.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("1.6.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("1.6.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("1.6.0") + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + + @since("1.6.0") + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + + @since("1.6.0") + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + return self._set(fitIntercept=value) + + @since("2.1.0") + def setAggregationDepth(self, value): + """ + Sets the value of :py:attr:`aggregationDepth`. + """ + return self._set(aggregationDepth=value) + class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams, JavaMLWritable, JavaMLReadable): @@ -1431,6 +1694,34 @@ class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams, .. versionadded:: 1.6.0 """ + @since("3.0.0") + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + return self._set(featuresCol=value) + + @since("3.0.0") + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + return self._set(predictionCol=value) + + @since("3.0.0") + def setQuantileProbabilities(self, value): + """ + Sets the value of :py:attr:`quantileProbabilities`. + """ + return self._set(quantileProbabilities=value) + + @since("3.0.0") + def setQuantilesCol(self, value): + """ + Sets the value of :py:attr:`quantilesCol`. + """ + return self._set(quantilesCol=value) + @property @since("2.0.0") def coefficients(self): @@ -1577,6 +1868,16 @@ class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionPar ... (2.0, Vectors.dense(0.0, 0.0)), ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"]) >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p") + >>> glr.setRegParam(0.1) + GeneralizedLinearRegression... + >>> glr.getRegParam() + 0.1 + >>> glr.clear(glr.regParam) + >>> glr.setMaxIter(10) + GeneralizedLinearRegression... + >>> glr.getMaxIter() + 10 + >>> glr.clear(glr.maxIter) >>> model = glr.fit(df) >>> model.setFeaturesCol("features") GeneralizedLinearRegression... @@ -1690,6 +1991,48 @@ def setOffsetCol(self, value): """ return self._set(offsetCol=value) + @since("2.0.0") + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + @since("2.0.0") + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + return self._set(regParam=value) + + @since("2.0.0") + def setTol(self, value): + """ + Sets the value of :py:attr:`tol`. 
+ """ + return self._set(tol=value) + + @since("2.2.0") + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + return self._set(fitIntercept=value) + + @since("2.0.0") + def setWeightCol(self, value): + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + + @since("2.0.0") + def setSolver(self, value): + """ + Sets the value of :py:attr:`solver`. + """ + return self._set(solver=value) + class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRegressionParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary): @@ -1699,6 +2042,13 @@ class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRe .. versionadded:: 2.0.0 """ + @since("3.0.0") + def setLinkPredictionCol(self, value): + """ + Sets the value of :py:attr:`linkPredictionCol`. + """ + return self._set(linkPredictionCol=value) + @property @since("2.0.0") def coefficients(self): diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index 4c7f01484dc2..75cd903b5d6d 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -221,13 +221,6 @@ def test_params(self): self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) self.assertEqual(testParams.getMaxIter(), 10) - testParams.setMaxIter(100) - self.assertTrue(testParams.isSet(maxIter)) - self.assertEqual(testParams.getMaxIter(), 100) - testParams.clear(maxIter) - self.assertFalse(testParams.isSet(maxIter)) - self.assertEqual(testParams.getMaxIter(), 10) - testParams.setMaxIter(100) self.assertTrue(testParams.hasParam(inputCol.name)) self.assertFalse(testParams.hasDefault(inputCol)) @@ -244,13 +237,12 @@ def test_params(self): # Since the default is normally random, set it to a known number for debug str testParams._setDefault(seed=41) - testParams.setSeed(43) self.assertEqual( testParams.explainParams(), "\n".join(["inputCol: input column name. (undefined)", - "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", - "seed: random seed. (default: 41, current: 43)"])) + "maxIter: max number of iterations (>= 0). (default: 10)", + "seed: random seed. (default: 41)"])) def test_clear_param(self): df = self.spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"]) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 8052163acd00..16c376296c20 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -299,6 +299,24 @@ def setNumFolds(self, value): """ return self._set(numFolds=value) + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + def setParallelism(self, value): + """ + Sets the value of :py:attr:`parallelism`. + """ + return self._set(parallelism=value) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. + """ + return self._set(collectSubModels=value) + def _fit(self, dataset): est = self.getOrDefault(self.estimator) epm = self.getOrDefault(self.estimatorParamMaps) @@ -643,6 +661,24 @@ def setTrainRatio(self, value): """ return self._set(trainRatio=value) + def setSeed(self, value): + """ + Sets the value of :py:attr:`seed`. + """ + return self._set(seed=value) + + def setParallelism(self, value): + """ + Sets the value of :py:attr:`parallelism`. 
+ """ + return self._set(parallelism=value) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. + """ + return self._set(collectSubModels=value) + def _fit(self, dataset): est = self.getOrDefault(self.estimator) epm = self.getOrDefault(self.estimatorParamMaps)