diff --git a/.travis.yml b/.travis.yml index d7e9f8c0290e..05b94adeeb93 100644 --- a/.travis.yml +++ b/.travis.yml @@ -43,7 +43,7 @@ notifications: # 5. Run maven install before running lint-java. install: - export MAVEN_SKIP_RC=1 - - build/mvn -T 4 -q -DskipTests -Pmesos -Pyarn -Pkinesis-asl -Phive -Phive-thriftserver install + - build/mvn -T 4 -q -DskipTests -Pkubernetes -Pmesos -Pyarn -Pkinesis-asl -Phive -Phive-thriftserver install # 6. Run lint-java. script: diff --git a/NOTICE b/NOTICE index f4b64b5c3f47..6ec240efbf12 100644 --- a/NOTICE +++ b/NOTICE @@ -448,6 +448,12 @@ Copyright (C) 2011 Google Inc. Apache Commons Pool Copyright 1999-2009 The Apache Software Foundation +This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client) +Copyright (C) 2015 Red Hat, Inc. + +This product includes/uses OkHttp (https://github.com/square/okhttp) +Copyright (C) 2012 The Android Open Source Project + ========================================================================= == NOTICE file corresponding to section 4(d) of the Apache License, == == Version 2.0, in this case for the DataNucleus distribution. == diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackendUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackendUtils.scala new file mode 100644 index 000000000000..c166d030f2c8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackendUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster + +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES} +import org.apache.spark.util.Utils + +private[spark] object SchedulerBackendUtils { + val DEFAULT_NUMBER_EXECUTORS = 2 + + /** + * Getting the initial target number of executors depends on whether dynamic allocation is + * enabled. + * If not using dynamic allocation it gets the number of executors requested by the user. 
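+   * If dynamic allocation is enabled, the initial executor count is taken from the dynamic
+   * allocation settings and is required to lie between the configured minimum and maximum
+   * number of executors; otherwise spark.executor.instances (or the supplied default) is used.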
+ */ + def getInitialTargetExecutorNumber( + conf: SparkConf, + numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { + if (Utils.isDynamicAllocationEnabled(conf)) { + val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS) + val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf) + val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) + require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, + s"initial executor number $initialNumExecutors must between min executor number " + + s"$minNumExecutors and max executor number $maxNumExecutors") + + initialNumExecutors + } else { + conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors) + } + } +} diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dacc89f5f27d..44f990ec8a5a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -532,6 +532,14 @@ def __hash__(self): sbt_test_goals=["mesos/test"] ) +kubernetes = Module( + name="kubernetes", + dependencies=[], + source_file_regexes=["resource-managers/kubernetes/core"], + build_profile_flags=["-Pkubernetes"], + sbt_test_goals=["kubernetes/test"] +) + # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. root = Module( diff --git a/docs/configuration.md b/docs/configuration.md index 9b9583d9165e..e42f866c4056 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1438,10 +1438,10 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.minRegisteredResourcesRatio - 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode + 0.8 for KUBERNETES mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode, CPU cores in standalone mode and Mesos coarsed-grained + (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarsed-grained mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5f9821378b27..983770d50683 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1716,6 +1716,8 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. + - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. + - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. 
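+
+ For example (a minimal sketch, assuming an active `SparkSession` named `spark`):
+
+ ```python
+ import pandas as pd
+ from datetime import datetime
+
+ pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)]})
+ df = spark.createDataFrame(pdf)
+
+ # Timestamps in toPandas()/createDataFrame now follow spark.sql.session.timeZone.
+ spark.conf.set("spark.sql.session.timeZone", "America/New_York")
+ df.toPandas()
+
+ # Restore the old behavior (interpret timestamps in the local timezone of the Python process).
+ spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
+ df.toPandas()
+ ```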
## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/pom.xml b/pom.xml index 3b2c629f8ec3..7bc66e7d1954 100644 --- a/pom.xml +++ b/pom.xml @@ -2664,6 +2664,13 @@ + + kubernetes + + resource-managers/kubernetes/core + + + hive-thriftserver diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c726ec25478a..75703380cdb4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -53,11 +53,11 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, + val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, streamingFlumeSink, streamingFlume, streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", + Seq("kubernetes", "mesos", "yarn", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) @@ -671,9 +671,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b95de2c80439..37e7cf3fa662 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -206,11 +206,12 @@ def __repr__(self): return "ArrowSerializer" -def _create_batch(series): +def _create_batch(series, timezone): """ Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :param timezone: A timezone to respect when handling timestamp values :return: Arrow RecordBatch """ @@ -227,7 +228,7 @@ def _create_batch(series): def cast_series(s, t): if type(t) == pa.TimestampType: # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 - return _check_series_convert_timestamps_internal(s.fillna(0))\ + return _check_series_convert_timestamps_internal(s.fillna(0), timezone)\ .values.astype('datetime64[us]', copy=False) # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1 elif t is not None and t == pa.date32(): @@ -253,6 +254,10 @@ class ArrowStreamPandasSerializer(Serializer): Serializes Pandas.Series as Arrow data with Arrow streaming format. """ + def __init__(self, timezone): + super(ArrowStreamPandasSerializer, self).__init__() + self._timezone = timezone + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. 
Input is a single series or @@ -262,7 +267,7 @@ def dump_stream(self, iterator, stream): writer = None try: for series in iterator: - batch = _create_batch(series) + batch = _create_batch(series, self._timezone) if writer is None: write_int(SpecialLengths.START_ARROW_STREAM, stream) writer = pa.RecordBatchStreamWriter(stream, batch.schema) @@ -280,7 +285,7 @@ def load_stream(self, stream): reader = pa.open_stream(stream) for batch in reader: # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 - pdf = _check_dataframe_localize_timestamps(batch.to_pandas()) + pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone) yield [c for _, c in pdf.iteritems()] def __repr__(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 406686e6df72..9864dc98c1f3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -39,6 +39,7 @@ from pyspark.sql.streaming import DataStreamWriter from pyspark.sql.types import IntegralType from pyspark.sql.types import * +from pyspark.util import _exception_message __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -1881,6 +1882,13 @@ def toPandas(self): 1 5 Bob """ import pandas as pd + + if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ + == "true": + timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + else: + timezone = None + if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_localize_timestamps @@ -1889,13 +1897,13 @@ def toPandas(self): if tables: table = pyarrow.concat_tables(tables) pdf = table.to_pandas() - return _check_dataframe_localize_timestamps(pdf) + return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: msg = "note: pyarrow must be installed and available on calling Python process " \ "if using spark.sql.execution.arrow.enabled=true" - raise ImportError("%s\n%s" % (e.message, msg)) + raise ImportError("%s\n%s" % (_exception_message(e), msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) @@ -1913,7 +1921,17 @@ def toPandas(self): for f, t in dtype.items(): pdf[f] = pdf[f].astype(t, copy=False) - return pdf + + if timezone is None: + return pdf + else: + from pyspark.sql.types import _check_series_convert_timestamps_local_tz + for field in self.schema: + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a75bdf8078dd..1ad974e9aa4c 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -828,8 +828,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No set, it uses the default value, ``,``. :param quote: sets the single character used for escaping quoted values where the separator can be part of the value. If None is set, it uses the default - value, ``"``. If you would like to turn off quotations, you need to set an - empty string. + value, ``"``. If an empty string is set, it uses ``u0000`` (null character). :param escape: sets the single character used for escaping quotes inside an already quoted value. 
If None is set, it uses the default value, ``\`` :param escapeQuotes: a flag indicating whether values containing quotes should always diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 47c58bb28221..e2435e09af23 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -34,8 +34,9 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \ - _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string +from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \ + _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \ + _parse_datatype_string from pyspark.sql.utils import install_exception_handler __all__ = ["SparkSession"] @@ -444,11 +445,34 @@ def _get_numpy_record_dtype(self, rec): record_type_list.append((str(col_names[i]), curr_type)) return np.dtype(record_type_list) if has_rec_fix else None - def _convert_from_pandas(self, pdf): + def _convert_from_pandas(self, pdf, schema, timezone): """ Convert a pandas.DataFrame to list of records that can be used to make a DataFrame :return list of records """ + if timezone is not None: + from pyspark.sql.types import _check_series_convert_timestamps_tz_local + copied = False + if isinstance(schema, StructType): + for field in schema: + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) + if not copied and s is not pdf[field.name]: + # Copy once if the series is modified to prevent the original Pandas + # DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s + else: + for column, series in pdf.iteritems(): + s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) + if not copied and s is not pdf[column]: + # Copy once if the series is modified to prevent the original Pandas + # DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) @@ -462,15 +486,19 @@ def _convert_from_pandas(self, pdf): # Convert list of numpy records to python lists return [r.tolist() for r in np_records] - def _create_from_pandas_with_arrow(self, pdf, schema): + def _create_from_pandas_with_arrow(self, pdf, schema, timezone): """ Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. 
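+        :param timezone: A timezone to respect when handling timestamp values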
""" from pyspark.serializers import ArrowSerializer, _create_batch - from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + from pyspark.sql.types import from_arrow_schema, to_arrow_type, \ + _old_pandas_exception_message, TimestampType + try: + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): @@ -488,7 +516,8 @@ def _create_from_pandas_with_arrow(self, pdf, schema): pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) # Create Arrow record batches - batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]) + batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)], + timezone) for pdf_slice in pdf_slices] # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) @@ -606,6 +635,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): + if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ + == "true": + timezone = self.conf.get("spark.sql.session.timeZone") + else: + timezone = None # If no schema supplied by user then get the names of columns only if schema is None: @@ -614,11 +648,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: try: - return self._create_from_pandas_with_arrow(data, schema) + return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) # Fallback to create DataFrame without arrow if raise some exception - data = self._convert_from_pandas(data) + data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 762afe0d730f..b4d32d8de8a2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -49,9 +49,14 @@ import unittest _have_pandas = False +_have_old_pandas = False try: import pandas - _have_pandas = True + try: + import pandas.api + _have_pandas = True + except: + _have_old_pandas = True except: # No Pandas, but that's okay, we'll skip those tests pass @@ -2565,21 +2570,38 @@ def count_bucketed_cols(names, table="pyspark_bucket"): .mode("overwrite").saveAsTable("pyspark_bucket")) self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - @unittest.skipIf(not _have_pandas, "Pandas not installed") - def test_to_pandas(self): + def _to_pandas(self): + from datetime import datetime, date import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ - .add("c", BooleanType()).add("d", FloatType()) + .add("c", BooleanType()).add("d", FloatType())\ + .add("dt", DateType()).add("ts", TimestampType()) data = [ - (1, "foo", True, 3.0), (2, "foo", True, 5.0), - (3, "bar", False, -1.0), (4, "bar", False, 6.0), + (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (2, "foo", True, 5.0, None, None), + (3, 
"bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)), + (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)), ] df = self.spark.createDataFrame(data, schema) - types = df.toPandas().dtypes + return df.toPandas() + + @unittest.skipIf(not _have_pandas, "Pandas not installed") + def test_to_pandas(self): + import numpy as np + pdf = self._to_pandas() + types = pdf.dtypes self.assertEquals(types[0], np.int32) self.assertEquals(types[1], np.object) self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) + self.assertEquals(types[4], 'datetime64[ns]') + self.assertEquals(types[5], 'datetime64[ns]') + + @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") + def test_to_pandas_old(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + self._to_pandas() @unittest.skipIf(not _have_pandas, "Pandas not installed") def test_to_pandas_avoid_astype(self): @@ -2614,6 +2636,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self): self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") + def test_create_dataframe_from_old_pandas(self): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) + with QuietTest(self.sc): + with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + self.spark.createDataFrame(pdf) + class HiveSparkSubmitTests(SparkSubmitTests): @@ -3103,7 +3135,7 @@ def __init__(self, **kwargs): _make_type_verifier(data_type, nullable=False)(obj) -@unittest.skipIf(not _have_arrow, "Arrow not installed") +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class ArrowTests(ReusedSQLTestCase): @classmethod @@ -3169,16 +3201,47 @@ def test_null_conversion(self): null_counts = pdf.isnull().sum().tolist() self.assertTrue(all([c == 1 for c in null_counts])) - def test_toPandas_arrow_toggle(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) + def _toPandas_arrow_toggle(self, df): self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") try: pdf = df.toPandas() finally: self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() + return pdf, pdf_arrow + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) self.assertFramesEqual(pdf_arrow, pdf) + def test_toPandas_respect_session_timezone(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + orig_tz = self.spark.conf.get("spark.sql.session.timeZone") + try: + timezone = "America/New_York" + self.spark.conf.set("spark.sql.session.timeZone", timezone) + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") + try: + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertFramesEqual(pdf_arrow_la, pdf_la) + finally: + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) + self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + + self.assertFalse(pdf_ny.equals(pdf_la)) + + from pyspark.sql.types import _check_series_convert_timestamps_local_tz + pdf_la_corrected = pdf_la.copy() + for field in self.schema: + if 
isinstance(field.dataType, TimestampType): + pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( + pdf_la_corrected[field.name], timezone) + self.assertFramesEqual(pdf_ny, pdf_la_corrected) + finally: + self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3192,16 +3255,50 @@ def test_filtered_frame(self): self.assertEqual(pdf.columns[0], "i") self.assertTrue(pdf.empty) - def test_createDataFrame_toggle(self): - pdf = self.create_pandas_data_frame() + def _createDataFrame_toggle(self, pdf, schema=None): self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") try: - df_no_arrow = self.spark.createDataFrame(pdf) + df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) finally: self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") - df_arrow = self.spark.createDataFrame(pdf) + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow + + def test_createDataFrame_toggle(self): + pdf = self.create_pandas_data_frame() + df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema) self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) + def test_createDataFrame_respect_session_timezone(self): + from datetime import timedelta + pdf = self.create_pandas_data_frame() + orig_tz = self.spark.conf.get("spark.sql.session.timeZone") + try: + timezone = "America/New_York" + self.spark.conf.set("spark.sql.session.timeZone", timezone) + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") + try: + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + finally: + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) + result_ny = df_no_arrow_ny.collect() + result_arrow_ny = df_arrow_ny.collect() + self.assertEqual(result_ny, result_arrow_ny) + + self.assertNotEqual(result_ny, result_la) + + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v + for k, v in row.asDict().items()}) + for row in result_la] + self.assertEqual(result_ny, result_la_corrected) + finally: + self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(pdf, schema=self.schema) @@ -3385,6 +3482,27 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = 
self.spark.range(10).select( @@ -3621,29 +3739,37 @@ def test_vectorized_udf_timestamps(self): data = [(0, datetime(1969, 1, 1, 1, 1, 1)), (1, datetime(2012, 2, 2, 2, 2, 2)), (2, None), - (3, datetime(2100, 4, 4, 4, 4, 4))] + (3, datetime(2100, 3, 3, 3, 3, 3))] + df = self.spark.createDataFrame(data, schema=schema) # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) - @pandas_udf(returnType=BooleanType()) + @pandas_udf(returnType=StringType()) def check_data(idx, timestamp, timestamp_copy): + import pandas as pd + msgs = [] is_equal = timestamp.isnull() # use this array to check values are equal for i in range(len(idx)): # Check that timestamps are as expected in the UDF - is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \ - timestamp[i].to_pydatetime() == data[idx[i]][1] - return is_equal - - result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"), - col("timestamp_copy"))).collect() + if (is_equal[i] and data[idx[i]][1] is None) or \ + timestamp[i].to_pydatetime() == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')" + % (timestamp[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"), + col("timestamp_copy"))).collect() # Check that collection values are correct self.assertEquals(len(data), len(result)) for i in range(len(result)): self.assertEquals(data[i][1], result[i][1]) # "timestamp" col - self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected + self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): from pyspark.sql.functions import pandas_udf, col @@ -3683,6 +3809,48 @@ def check_records_per_batch(x): else: self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) + def test_vectorized_udf_timestamps_respect_session_timezone(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import datetime + import pandas as pd + schema = StructType([ + StructField("idx", LongType(), True), + StructField("timestamp", TimestampType(), True)]) + data = [(1, datetime(1969, 1, 1, 1, 1, 1)), + (2, datetime(2012, 2, 2, 2, 2, 2)), + (3, None), + (4, datetime(2100, 3, 3, 3, 3, 3))] + df = self.spark.createDataFrame(data, schema=schema) + + f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType()) + internal_value = pandas_udf( + lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) + + orig_tz = self.spark.conf.get("spark.sql.session.timeZone") + try: + timezone = "America/New_York" + self.spark.conf.set("spark.sql.session.timeZone", timezone) + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") + try: + df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + finally: + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + + df_ny = df.withColumn("tscopy", 
f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() + + self.assertNotEqual(result_ny, result_la) + self.assertEqual(result_ny, result_la_corrected) + finally: + self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fe62f60dd6d0..78abc32a35a1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -35,6 +35,7 @@ from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer +from pyspark.util import _exception_message __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", @@ -1678,37 +1679,105 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _check_dataframe_localize_timestamps(pdf): +def _old_pandas_exception_message(e): + """ Create an error message for importing old Pandas. """ - Convert timezone aware timestamps to timezone-naive in local time + msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process" + return "%s\n%s" % (_exception_message(e), msg) + + +def _check_dataframe_localize_timestamps(pdf, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone :param pdf: pandas.DataFrame - :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive """ - from pandas.api.types import is_datetime64tz_dtype + try: + from pandas.api.types import is_datetime64tz_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) + tz = timezone or 'tzlocal()' for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None) + pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) return pdf -def _check_series_convert_timestamps_internal(s): +def _check_series_convert_timestamps_internal(s, timezone): """ - Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage + Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for + Spark internal storage + :param s: a pandas.Series + :param timezone: the timezone to convert. if None then use local timezone :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone """ - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + try: + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) # TODO: handle nested timestamps, such as ArrayType(TimestampType())? 
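+    # tz-naive timestamps are localized to the given timezone (or the local timezone when
+    # `timezone` is None) before being normalized to UTC; tz-aware timestamps are converted
+    # to UTC directly.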
if is_datetime64_dtype(s.dtype): - return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC') + tz = timezone or 'tzlocal()' + return s.dt.tz_localize(tz).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') else: return s +def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param from_timezone: the timezone to convert from. if None then use local timezone + :param to_timezone: the timezone to convert to. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + try: + import pandas as pd + from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) + from_tz = from_timezone or 'tzlocal()' + to_tz = to_timezone or 'tzlocal()' + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(to_tz).dt.tz_localize(None) + elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: + # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. + return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) + else: + return s + + +def _check_series_convert_timestamps_local_tz(s, timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param timezone: the timezone to convert to. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + return _check_series_convert_timestamps_localize(s, None, timezone) + + +def _check_series_convert_timestamps_tz_local(s, timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param timezone: the timezone to convert from. 
if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + return _check_series_convert_timestamps_localize(s, timezone, None) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 939643071943..e6737ae1c128 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -150,7 +150,8 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - ser = ArrowStreamPandasSerializer() + timezone = utf8_deserializer.loads(infile) + ser = ArrowStreamPandasSerializer(timezone) else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/python/setup.py b/python/setup.py index 02612ff8a724..310670e697a8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.13.0'] + 'sql': ['pandas>=0.19.2'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml new file mode 100644 index 000000000000..7d35aea8a414 --- /dev/null +++ b/resource-managers/kubernetes/core/pom.xml @@ -0,0 +1,100 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes_2.11 + jar + Spark Project Kubernetes + + kubernetes + 3.0.0 + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + + io.fabric8 + kubernetes-client + ${kubernetes.client.version} + + + com.fasterxml.jackson.core + * + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + + + + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${fasterxml.jackson.version} + + + + + com.google.guava + guava + + + + + org.mockito + mockito-core + test + + + + com.squareup.okhttp3 + okhttp + 3.8.1 + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala new file mode 100644 index 000000000000..f0742b91987b --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.network.util.ByteUnit + +private[spark] object Config extends Logging { + + val KUBERNETES_NAMESPACE = + ConfigBuilder("spark.kubernetes.namespace") + .doc("The namespace that will be used for running the driver and executor pods. When using " + + "spark-submit in cluster mode, this can also be passed to spark-submit via the " + + "--kubernetes-namespace command line argument.") + .stringConf + .createWithDefault("default") + + val EXECUTOR_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.executor.docker.image") + .doc("Docker image to use for the executors. Specify this using the standard Docker tag " + + "format.") + .stringConf + .createOptional + + val DOCKER_IMAGE_PULL_POLICY = + ConfigBuilder("spark.kubernetes.docker.image.pullPolicy") + .doc("Kubernetes image pull policy. Valid values are Always, Never, and IfNotPresent.") + .stringConf + .checkValues(Set("Always", "Never", "IfNotPresent")) + .createWithDefault("IfNotPresent") + + val APISERVER_AUTH_DRIVER_CONF_PREFIX = + "spark.kubernetes.authenticate.driver" + val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX = + "spark.kubernetes.authenticate.driver.mounted" + val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" + val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" + val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" + val CLIENT_CERT_FILE_CONF_SUFFIX = "clientCertFile" + val CA_CERT_FILE_CONF_SUFFIX = "caCertFile" + + val KUBERNETES_SERVICE_ACCOUNT_NAME = + ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") + .doc("Service account that is used when running the driver pod. The driver pod uses " + + "this service account when requesting executor pods from the API server. If specific " + + "credentials are given for the driver pod to use, the driver will favor " + + "using those credentials instead.") + .stringConf + .createOptional + + // Note that while we set a default for this when we start up the + // scheduler, the specific default value is dynamically determined + // based on the executor memory. + val KUBERNETES_EXECUTOR_MEMORY_OVERHEAD = + ConfigBuilder("spark.kubernetes.executor.memoryOverhead") + .doc("The amount of off-heap memory (in megabytes) to be allocated per executor. This " + + "is memory that accounts for things like VM overheads, interned strings, other native " + + "overheads, etc. This tends to grow with the executor size. (typically 6-10%).") + .bytesConf(ByteUnit.MiB) + .createOptional + + val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." + val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." 
+ + val KUBERNETES_DRIVER_POD_NAME = + ConfigBuilder("spark.kubernetes.driver.pod.name") + .doc("Name of the driver pod.") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_POD_NAME_PREFIX = + ConfigBuilder("spark.kubernetes.executor.podNamePrefix") + .doc("Prefix to use in front of the executor pod names.") + .internal() + .stringConf + .createWithDefault("spark") + + val KUBERNETES_ALLOCATION_BATCH_SIZE = + ConfigBuilder("spark.kubernetes.allocation.batch.size") + .doc("Number of pods to launch at once in each round of executor allocation.") + .intConf + .checkValue(value => value > 0, "Allocation batch size should be a positive integer") + .createWithDefault(5) + + val KUBERNETES_ALLOCATION_BATCH_DELAY = + ConfigBuilder("spark.kubernetes.allocation.batch.delay") + .doc("Number of seconds to wait between each round of executor allocation.") + .longConf + .checkValue(value => value > 0, "Allocation batch delay should be a positive integer") + .createWithDefault(1) + + val KUBERNETES_EXECUTOR_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.executor.limit.cores") + .doc("Specify the hard cpu limit for a single executor pod") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS = + ConfigBuilder("spark.kubernetes.executor.lostCheck.maxAttempts") + .doc("Maximum number of attempts allowed for checking the reason of an executor loss " + + "before it is assumed that the executor failed.") + .intConf + .checkValue(value => value > 0, "Maximum attempts of checks of executor lost reason " + + "must be a positive integer") + .createWithDefault(10) + + val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector." +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala new file mode 100644 index 000000000000..01717479fddd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import org.apache.spark.SparkConf + +private[spark] object ConfigurationUtils { + + /** + * Extract and parse Spark configuration properties with a given name prefix and + * return the result as a Map. Keys must not have more than one value. 
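+   * For example (illustrative), with the prefix "spark.kubernetes.executor.label." the entry
+   * "spark.kubernetes.executor.label.team=infra" is returned as Map("team" -> "infra").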
+ * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing the configuration property keys and values + */ + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String): Map[String, String] = { + sparkConf.getAllWithPrefix(prefix).toMap + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala new file mode 100644 index 000000000000..4ddeefb15a89 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] object Constants { + + // Labels + val SPARK_APP_ID_LABEL = "spark-app-selector" + val SPARK_EXECUTOR_ID_LABEL = "spark-exec-id" + val SPARK_ROLE_LABEL = "spark-role" + val SPARK_POD_DRIVER_ROLE = "driver" + val SPARK_POD_EXECUTOR_ROLE = "executor" + + // Default and fixed ports + val DEFAULT_DRIVER_PORT = 7078 + val DEFAULT_BLOCKMANAGER_PORT = 7079 + val BLOCK_MANAGER_PORT_NAME = "blockmanager" + val EXECUTOR_PORT_NAME = "executor" + + // Environment Variables + val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" + val ENV_DRIVER_URL = "SPARK_DRIVER_URL" + val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" + val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" + val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" + val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" + val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" + val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" + val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" + val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" + + // Miscellaneous + val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" + val MEMORY_OVERHEAD_FACTOR = 0.10 + val MEMORY_OVERHEAD_MIN_MIB = 384L +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala new file mode 100644 index 000000000000..1e3f055e0576 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.utils.HttpClientUtils +import okhttp3.Dispatcher + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.util.ThreadUtils + +/** + * Spark-opinionated builder for Kubernetes clients. It uses a prefix plus common suffixes to + * parse configuration keys, similar to the manner in which Spark's SecurityManager parses SSL + * options for different components. + */ +private[spark] object SparkKubernetesClientFactory { + + def createKubernetesClient( + master: String, + namespace: Option[String], + kubernetesAuthConfPrefix: String, + sparkConf: SparkConf, + defaultServiceAccountToken: Option[File], + defaultServiceAccountCaCert: Option[File]): KubernetesClient = { + val oauthTokenFileConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX" + val oauthTokenConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX" + val oauthTokenFile = sparkConf.getOption(oauthTokenFileConf) + .map(new File(_)) + .orElse(defaultServiceAccountToken) + val oauthTokenValue = sparkConf.getOption(oauthTokenConf) + ConfigurationUtils.requireNandDefined( + oauthTokenFile, + oauthTokenValue, + s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a " + + s"value $oauthTokenConf.") + + val caCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CA_CERT_FILE_CONF_SUFFIX") + .orElse(defaultServiceAccountCaCert.map(_.getAbsolutePath)) + val clientKeyFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX") + val clientCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX") + val dispatcher = new Dispatcher( + ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher")) + val config = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(master) + .withWebsocketPingInterval(0) + .withOption(oauthTokenValue) { + (token, configBuilder) => configBuilder.withOauthToken(token) + }.withOption(oauthTokenFile) { + (file, configBuilder) => + configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8)) + }.withOption(caCertFile) { + (file, configBuilder) => configBuilder.withCaCertFile(file) + }.withOption(clientKeyFile) { + (file, configBuilder) => configBuilder.withClientKeyFile(file) + }.withOption(clientCertFile) { + (file, configBuilder) => configBuilder.withClientCertFile(file) + }.withOption(namespace) { + (ns, configBuilder) => configBuilder.withNamespace(ns) + }.build() + val baseHttpClient = HttpClientUtils.createHttpClient(config) + val httpClientWithCustomDispatcher = baseHttpClient.newBuilder() + .dispatcher(dispatcher) + .build() + new DefaultKubernetesClient(httpClientWithCustomDispatcher, config) + } + 
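+  // Enriches fabric8's ConfigBuilder with a `withOption` combinator so that optional settings
+  // (OAuth token, client certificates, namespace, ...) are applied only when a value is present.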
+ private implicit class OptionConfigurableConfigBuilder(val configBuilder: ConfigBuilder) + extends AnyVal { + + def withOption[T] + (option: Option[T]) + (configurator: ((T, ConfigBuilder) => ConfigBuilder)): ConfigBuilder = { + option.map { opt => + configurator(opt, configBuilder) + }.getOrElse(configBuilder) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala new file mode 100644 index 000000000000..f79155b117b6 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.ConfigurationUtils +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Utils + +/** + * A factory class for configuring and creating executor pods. + */ +private[spark] trait ExecutorPodFactory { + + /** + * Configure and construct an executor pod with the given parameters. 
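+   * The returned pod carries the labels, annotations, resource requests and limits, and
+   * environment variables derived from the Spark configuration, and references the given
+   * driver pod as its owner so that executors are cleaned up when the driver pod is deleted.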
+ */ + def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod +} + +private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) + extends ExecutorPodFactory { + + private val executorExtraClasspath = + sparkConf.get(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH) + + private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + + private val executorAnnotations = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + private val nodeSelector = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_NODE_SELECTOR_PREFIX) + + private val executorDockerImage = sparkConf + .get(EXECUTOR_DOCKER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor Docker image")) + private val dockerImagePullPolicy = sparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val blockManagerPort = sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) + + private val executorMemoryMiB = sparkConf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY) + private val executorMemoryString = sparkConf.get( + org.apache.spark.internal.config.EXECUTOR_MEMORY.key, + org.apache.spark.internal.config.EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = sparkConf + .get(KUBERNETES_EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod = { + val name = s"$executorPodNamePrefix-exec-$executorId" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. 
This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val resolvedExecutorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> applicationId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorLabels + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryMiB}Mi") + .build() + val executorMemoryLimitQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCores.toString) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = sparkConf + .get(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + // Executor backend expects integral value for executor cores, so round it up to an int. + (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, applicationId), + (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder() + .withName("executor") + .withImage(executorDockerImage) + .withImagePullPolicy(dockerImagePullPolicy) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryLimitQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .build() + + val executorPod = new PodBuilder() + .withNewMetadata() + .withName(name) + .withLabels(resolvedExecutorLabels.asJava) + .withAnnotations(executorAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(nodeSelector.asJava) + .endSpec() + .build() + + val containerWithExecutorLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + + new PodBuilder(executorPod) + .editSpec() + .addToContainers(containerWithExecutorLimitCores) + .endSpec() + .build() + } +} diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala new file mode 100644 index 000000000000..68ca6a762217 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.File + +import io.fabric8.kubernetes.client.Config + +import org.apache.spark.SparkContext +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.util.ThreadUtils + +private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { + + override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") + + override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { + new TaskSchedulerImpl(sc) + } + + override def createSchedulerBackend( + sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend = { + val sparkConf = sc.getConf + + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( + KUBERNETES_MASTER_INTERNAL_URL, + Some(sparkConf.get(KUBERNETES_NAMESPACE)), + APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + sparkConf, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + + val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val allocatorExecutor = ThreadUtils + .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") + val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( + "kubernetes-executor-requests") + new KubernetesClusterSchedulerBackend( + scheduler.asInstanceOf[TaskSchedulerImpl], + sc.env.rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) + } + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala new file mode 100644 index 000000000000..e79c987852db --- 
/dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable +import java.net.InetAddress +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} +import javax.annotation.concurrent.GuardedBy + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} +import org.apache.spark.util.Utils + +private[spark] class KubernetesClusterSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + executorPodFactory: ExecutorPodFactory, + kubernetesClient: KubernetesClient, + allocatorExecutor: ScheduledExecutorService, + requestExecutorsService: ExecutorService) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + import KubernetesClusterSchedulerBackend._ + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + private val RUNNING_EXECUTOR_PODS_LOCK = new Object + @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") + private val runningExecutorsToPods = new mutable.HashMap[String, Pod] + private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() + private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() + private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() + + private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( + requestExecutorsService) + + private val driverPod = kubernetesClient.pods() + .inNamespace(kubernetesNamespace) + .withName(kubernetesDriverPodName) + .get() + + protected override val minRegisteredRatio = + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + 0.8 + } else { + 
super.minRegisteredRatio + } + + private val executorWatchResource = new AtomicReference[Closeable] + private val totalExpectedExecutors = new AtomicInteger(0) + + private val driverUrl = RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + + private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) + + private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val executorLostReasonCheckMaxAttempts = conf.get( + KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) + + private val allocatorRunnable = new Runnable { + + // Maintains a map of executor id to count of checks performed to learn the loss reason + // for an executor. + private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] + + override def run(): Unit = { + handleDisconnectedExecutors() + + val executorsToAllocate = mutable.Map[String, Pod]() + val currentTotalRegisteredExecutors = totalRegisteredExecutors.get + val currentTotalExpectedExecutors = totalExpectedExecutors.get + val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { + logDebug("Waiting for pending executors before scaling") + } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { + logDebug("Maximum allowed executor limit reached. Not scaling up further.") + } else { + for (_ <- 0 until math.min( + currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { + val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString + val executorPod = executorPodFactory.createExecutorPod( + executorId, + applicationId(), + driverUrl, + conf.getExecutorEnv, + driverPod, + currentNodeToLocalTaskCount) + executorsToAllocate(executorId) = executorPod + logInfo( + s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") + } + } + } + + val allocatedExecutors = executorsToAllocate.mapValues { pod => + Utils.tryLog { + kubernetesClient.pods().create(pod) + } + } + + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + allocatedExecutors.map { + case (executorId, attemptedAllocatedExecutor) => + attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => + runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) + } + } + } + } + + def handleDisconnectedExecutors(): Unit = { + // For each disconnected executor, synchronize with the loss reasons that may have been found + // by the executor pod watcher. If the loss reason was discovered by the watcher, + // inform the parent class with removeExecutor. + disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { + case (executorId, executorPod) => + val knownExitReason = Option(podsWithKnownExitReasons.remove( + executorPod.getMetadata.getName)) + knownExitReason.fold { + removeExecutorOrIncrementLossReasonCheckCount(executorId) + } { executorExited => + logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) + removeExecutor(executorId, executorExited) + // We don't delete the pod running the executor that has an exit condition caused by + // the application from the Kubernetes API server. This allows users to debug later on + // through commands such as "kubectl logs " and + // "kubectl describe pod ". 
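+ // For example: "kubectl logs <executor pod name>" or "kubectl describe pod <executor pod name>".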
Note that exited containers have terminated and + // therefore won't take CPU and memory resources. + // Otherwise, the executor pod is marked to be deleted from the API server. + if (executorExited.exitCausedByApp) { + logInfo(s"Executor $executorId exited because of the application.") + deleteExecutorFromDataStructures(executorId) + } else { + logInfo(s"Executor $executorId failed because of a framework error.") + deleteExecutorFromClusterAndDataStructures(executorId) + } + } + } + } + + def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { + val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) + if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { + removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) + deleteExecutorFromClusterAndDataStructures(executorId) + } else { + executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) + } + } + + def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { + deleteExecutorFromDataStructures(executorId).foreach { pod => + kubernetesClient.pods().delete(pod) + } + } + + def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { + disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) + executorReasonCheckAttemptCounts -= executorId + podsWithKnownExitReasons.remove(executorId) + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.remove(executorId).orElse { + logWarning(s"Unable to remove pod for unknown executor $executorId") + None + } + } + } + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + } + + override def start(): Unit = { + super.start() + executorWatchResource.set( + kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .watch(new ExecutorPodsWatcher())) + + allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS) + + if (!Utils.isDynamicAllocationEnabled(conf)) { + doRequestTotalExecutors(initialExecutors) + } + } + + override def stop(): Unit = { + // stop allocation of new resources and caches. + allocatorExecutor.shutdown() + allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) + + // send stop message to executors so they shut down cleanly + super.stop() + + try { + val resource = executorWatchResource.getAndSet(null) + if (resource != null) { + resource.close() + } + } catch { + case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + } + + // then delete the executor pods + Utils.tryLogNonFatalError { + deleteExecutorPodsOnStop() + executorPodsByIPs.clear() + } + Utils.tryLogNonFatalError { + logInfo("Closing kubernetes client") + kubernetesClient.close() + } + } + + /** + * @return A map of K8s cluster nodes to the number of tasks that could benefit from data + * locality if an executor launches on the cluster node. + */ + private def getNodesWithLocalTaskCounts() : Map[String, Int] = { + val nodeToLocalTaskCount = synchronized { + mutable.Map[String, Int]() ++ hostToLocalTaskCount + } + + for (pod <- executorPodsByIPs.values().asScala) { + // Remove cluster nodes that are running our executors already. + // TODO: This prefers spreading out executors across nodes. In case users want + // consolidating executors on fewer nodes, introduce a flag. 
See the spark.deploy.spreadOut + // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html + nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || + nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || + nodeToLocalTaskCount.remove( + InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + } + nodeToLocalTaskCount.toMap[String, Int] + } + + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { + totalExpectedExecutors.set(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { + executorIds.flatMap { executorId => + runningExecutorsToPods.remove(executorId) match { + case Some(pod) => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + Some(pod) + + case None => + logWarning(s"Unable to remove pod for unknown executor $executorId") + None + } + } + } + + kubernetesClient.pods().delete(podsToDelete: _*) + true + } + + private def deleteExecutorPodsOnStop(): Unit = { + val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { + val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) + runningExecutorsToPods.clear() + runningExecutorPodsCopy + } + kubernetesClient.pods().delete(executorPodsToDelete: _*) + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + + private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 + + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + val podIP = pod.getStatus.getPodIP + + action match { + case Action.MODIFIED if (pod.getStatus.getPhase == "Running" + && pod.getMetadata.getDeletionTimestamp == null) => + val clusterNodeName = pod.getSpec.getNodeName + logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") + executorPodsByIPs.put(podIP, pod) + + case Action.DELETED | Action.ERROR => + val executorId = getExecutorId(pod) + logDebug(s"Executor pod $podName at IP $podIP was at $action.") + if (podIP != null) { + executorPodsByIPs.remove(podIP) + } + + val executorExitReason = if (action == Action.ERROR) { + logWarning(s"Received error event of executor pod $podName. Reason: " + + pod.getStatus.getReason) + executorExitReasonOnError(pod) + } else if (action == Action.DELETED) { + logWarning(s"Received delete event of executor pod $podName. Reason: " + + pod.getStatus.getReason) + executorExitReasonOnDelete(pod) + } else { + throw new IllegalStateException( + s"Unknown action that should only be DELETED or ERROR: $action") + } + podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) + + if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { + log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + + s"watch received an event of type $action for this executor. 
The executor may " + + "have failed to start in the first place and never registered with the driver.") + } + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + + case _ => logDebug(s"Received event of executor pod $podName: " + action) + } + } + + override def onClose(cause: KubernetesClientException): Unit = { + logDebug("Executor pod watch closed.", cause) + } + + private def getExecutorExitStatus(pod: Pod): Int = { + val containerStatuses = pod.getStatus.getContainerStatuses + if (!containerStatuses.isEmpty) { + // we assume the first container represents the pod status. This assumption may not hold + // true in the future. Revisit this if side-car containers start running inside executor + // pods. + getExecutorExitStatus(containerStatuses.get(0)) + } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS + } + + private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { + Option(containerStatus.getState).map { containerState => + Option(containerState.getTerminated).map { containerStateTerminated => + containerStateTerminated.getExitCode.intValue() + }.getOrElse(UNKNOWN_EXIT_CODE) + }.getOrElse(UNKNOWN_EXIT_CODE) + } + + private def isPodAlreadyReleased(pod: Pod): Boolean = { + val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + !runningExecutorsToPods.contains(executorId) + } + } + + private def executorExitReasonOnError(pod: Pod): ExecutorExited = { + val containerExitStatus = getExecutorExitStatus(pod) + // container was probably actively killed by the driver. + if (isPodAlreadyReleased(pod)) { + ExecutorExited(containerExitStatus, exitCausedByApp = false, + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + + "request.") + } else { + val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + + s"exited with exit status code $containerExitStatus." + ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) + } + } + + private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { + val exitMessage = if (isPodAlreadyReleased(pod)) { + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." + } else { + s"Pod ${pod.getMetadata.getName} deleted or lost." 
+ } + ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) + } + + private def getExecutorId(pod: Pod): String = { + val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) + require(executorId != null, "Unexpected pod metadata; expected all executor pods " + + s"to have label $SPARK_EXECUTOR_ID_LABEL.") + executorId + } + } + + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new KubernetesDriverEndpoint(rpcEnv, properties) + } + + private class KubernetesDriverEndpoint( + rpcEnv: RpcEnv, + sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + if (disableExecutor(executorId)) { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.get(executorId).foreach { pod => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + } + } + } + } + } + } +} + +private object KubernetesClusterSchedulerBackend { + private val UNKNOWN_EXIT_CODE = -1 +} diff --git a/resource-managers/kubernetes/core/src/test/resources/log4j.properties b/resource-managers/kubernetes/core/src/test/resources/log4j.properties new file mode 100644 index 000000000000..ad95fadb7c0c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala new file mode 100644 index 000000000000..1c7717c23809 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Pod, _} +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + private val driverPodName: String = "driver-pod" + private val driverPodUid: String = "driver-uid" + private val executorPrefix: String = "base" + private val executorImage: String = "executor-image" + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .withUid(driverPodUid) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) + .set(EXECUTOR_DOCKER_IMAGE, executorImage) + } + + test("basic executor pod has reasonable defaults") { + val factory = new ExecutorPodFactoryImpl(baseConf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + // The executor pod name and default labels. + assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") + assert(executor.getMetadata.getLabels.size() === 3) + assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getImage === executorImage) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) + assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. 
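+ // (baseConf sets no node selector entries, and this factory adds no volumes.)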
+ assert(executor.getSpec.getNodeSelector.isEmpty) + assert(executor.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor, driverPodUid) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, + "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor, driverPodUid) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> "dummy", + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> "dummy", + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) + val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala new file mode 100644 index 000000000000..3febb2f47cfd --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{any, eq => mockitoEq} +import org.mockito.Mockito.{doNothing, never, times, verify, when} +import org.scalatest.BeforeAndAfter +import org.scalatest.mockito.MockitoSugar._ +import scala.collection.JavaConverters._ +import scala.concurrent.Future + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc._ +import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.ThreadUtils + +class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { + + private val APP_ID = "test-spark-app" + private val DRIVER_POD_NAME = "spark-driver-pod" + private val NAMESPACE = "test-namespace" + private val SPARK_DRIVER_HOST = "localhost" + private val SPARK_DRIVER_PORT = 7077 + private val POD_ALLOCATION_INTERVAL = 60L + private val DRIVER_URL = RpcEndpointAddress( + SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val FIRST_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod1") + .endMetadata() + .withNewSpec() + .withNodeName("node1") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private val SECOND_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod2") + .endMetadata() + .withNewSpec() + .withNodeName("node2") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.101") + .endStatus() + .build() + + private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + private type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + private type IN_NAMESPACE_PODS = NonNamespaceOperation[ + Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + + @Mock + private var sparkContext: SparkContext = _ + + @Mock + private var listenerBus: LiveListenerBus = _ + + @Mock + private var taskSchedulerImpl: TaskSchedulerImpl = _ + + @Mock + private var allocatorExecutor: ScheduledExecutorService = _ + + @Mock + private var requestExecutorsService: ExecutorService = _ + + @Mock + private var executorPodFactory: ExecutorPodFactory = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: 
PODS = _ + + @Mock + private var podsWithLabelOperations: LABELED_PODS = _ + + @Mock + private var podsInNamespace: IN_NAMESPACE_PODS = _ + + @Mock + private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + + @Mock + private var rpcEnv: RpcEnv = _ + + @Mock + private var driverEndpointRef: RpcEndpointRef = _ + + @Mock + private var executorPodsWatch: Watch = _ + + @Mock + private var successFuture: Future[Boolean] = _ + + private var sparkConf: SparkConf = _ + private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ + private var allocatorRunnable: ArgumentCaptor[Runnable] = _ + private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ + private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .addToLabels(SPARK_APP_ID_LABEL, APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .endMetadata() + .build() + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_NAMESPACE, NAMESPACE) + .set("spark.driver.host", SPARK_DRIVER_HOST) + .set("spark.driver.port", SPARK_DRIVER_PORT.toString) + .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL) + executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) + when(sparkContext.conf).thenReturn(sparkConf) + when(sparkContext.listenerBus).thenReturn(listenerBus) + when(taskSchedulerImpl.sc).thenReturn(sparkContext) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) + when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) + .thenReturn(executorPodsWatch) + when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) + when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) + when(podsWithDriverName.get()).thenReturn(driverPod) + when(allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable.capture(), + mockitoEq(0L), + mockitoEq(POD_ALLOCATION_INTERVAL), + mockitoEq(TimeUnit.SECONDS))).thenReturn(null) + // Creating Futures in Scala backed by a Java executor service resolves to running + // ExecutorService#execute (as opposed to submit) + doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) + when(rpcEnv.setupEndpoint( + mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) + .thenReturn(driverEndpointRef) + + // Used by the CoarseGrainedSchedulerBackend when making RPC calls. + when(driverEndpointRef.ask[Boolean] + (any(classOf[Any])) + (any())).thenReturn(successFuture) + when(successFuture.failed).thenReturn(Future[Throwable] { + // emulate behavior of the Future.failed method. 
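+ // (Future.failed on a future that completes successfully fails with a NoSuchElementException.)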
+ throw new NoSuchElementException() + }(ThreadUtils.sameThread)) + } + + test("Basic lifecycle expectations when starting and stopping the scheduler.") { + val scheduler = newSchedulerBackend() + scheduler.start() + assert(executorPodsWatcherArgument.getValue != null) + assert(allocatorRunnable.getValue != null) + scheduler.stop() + verify(executorPodsWatch).close() + } + + test("Static allocation should request executors upon first allocator run.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + verify(podOperations).create(firstResolvedPod) + verify(podOperations).create(secondResolvedPod) + } + + test("Killing executors deletes the executor pods") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + scheduler.doKillExecutors(Seq("2")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations).delete(secondResolvedPod) + verify(podOperations, never()).delete(firstResolvedPod) + } + + test("Executors should be requested in batches.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(firstResolvedPod) + verify(podOperations, never()).create(secondResolvedPod) + val registerFirstExecutorMessage = RegisterExecutor( + "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + allocatorRunnable.getValue.run() + verify(podOperations).create(secondResolvedPod) + } + + test("Scaled down executors should be cleaned up") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + + // The scheduler backend spins up one executor pod. 
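+ // The executor registers with the driver below; the later scale-down should then be
+ // reported as an explicit termination rather than an executor failure.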
+ requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + + // Request that there are 0 executors and trigger deletion from driver. + scheduler.doRequestTotalExecutors(0) + requestExecutorRunnable.getAllValues.asScala.last.run() + scheduler.doKillExecutors(Seq("1")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations, times(1)).delete(resolvedPod) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + + val exitedPod = exitPod(resolvedPod, 0) + executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) + allocatorRunnable.getValue.run() + + // No more deletion attempts of the executors. + // This is graceful termination and should not be detected as a failure. + verify(podOperations, times(1)).delete(resolvedPod) + verify(driverEndpointRef, times(1)).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 0, + exitCausedByApp = false, + s"Container in pod ${exitedPod.getMetadata.getName} exited from" + + s" explicit termination request."))) + } + + test("Executors that fail should not be deleted.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + executorPodsWatcherArgument.getValue.eventReceived( + Action.ERROR, exitPod(firstResolvedPod, 1)) + + // A replacement executor should be created but the error pod should persist. 
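+ // (The non-zero exit is attributed to the application, so the pod is kept around for debugging.)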
+ val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + scheduler.doRequestTotalExecutors(1) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getAllValues.asScala.last.run() + verify(podOperations, never()).delete(firstResolvedPod) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 1, + exitCausedByApp = true, + s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + + " exit status code 1."))) + } + + test("Executors disconnected due to unknown reasons are deleted and replaced.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val executorLostReasonCheckMaxAttempts = sparkConf.get( + KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) + + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + 1 to executorLostReasonCheckMaxAttempts foreach { _ => + allocatorRunnable.getValue.run() + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + + val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).delete(firstResolvedPod) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) + } + + test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(firstResolvedPod)) + .thenThrow(new RuntimeException("test")) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).create(firstResolvedPod) + val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(recreatedResolvedPod) + } + + test("Executors that are initially created but the watch notices them fail are rebuilt" + + " in the next batch.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).create(firstResolvedPod) + executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) + val recreatedResolvedPod = expectPodCreationWithId(2, 
FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(recreatedResolvedPod) + } + + private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { + new KubernetesClusterSchedulerBackend( + taskSchedulerImpl, + rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) { + + override def applicationId(): String = APP_ID + } + } + + private def exitPod(basePod: Pod, exitCode: Int): Pod = { + new PodBuilder(basePod) + .editStatus() + .addNewContainerStatus() + .withNewState() + .withNewTerminated() + .withExitCode(exitCode) + .endTerminated() + .endState() + .endContainerStatus() + .endStatus() + .build() + } + + private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { + val resolvedPod = new PodBuilder(expectedPod) + .editMetadata() + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + when(executorPodFactory.createExecutorPod( + executorId.toString, + APP_ID, + DRIVER_URL, + sparkConf.getExecutorEnv, + driverPod, + Map.empty)).thenReturn(resolvedPod) + resolvedPod + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7052fb347106..506adb363aa9 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -41,6 +41,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId +import org.apache.spark.scheduler.cluster.SchedulerBackendUtils import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} /** @@ -109,7 +110,7 @@ private[yarn] class YarnAllocator( sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) @volatile private var targetNumExecutors = - YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) private var currentNodeBlacklist = Set.empty[String] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 3d9f99f57bed..9c1472cb50e3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -133,8 +133,6 @@ object YarnSparkHadoopUtil { val ANY_HOST = "*" - val DEFAULT_NUMBER_EXECUTORS = 2 - // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = Priority.newInstance(1) @@ -279,27 +277,5 @@ object YarnSparkHadoopUtil { securityMgr.getModifyAclsGroups) ) } - - /** - * Getting the initial target number of executors depends on whether dynamic allocation is - * enabled. - * If not using dynamic allocation it gets the number of executors requested by the user. 
- */ - def getInitialTargetExecutorNumber( - conf: SparkConf, - numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { - if (Utils.isDynamicAllocationEnabled(conf)) { - val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS) - val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf) - val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) - require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, - s"initial executor number $initialNumExecutors must between min executor number " + - s"$minNumExecutors and max executor number $maxNumExecutors") - - initialNumExecutors - } else { - conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors) - } - } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d482376d14dd..b722cc401bb7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -52,7 +52,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) val args = new ClientArguments(argsArrayBuf.toArray) - totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf) + totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) client = new Client(args, conf) bindToYarn(client.submitApplication(), None) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 4f3d5ebf403e..e2d477be329c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -34,7 +34,7 @@ private[spark] class YarnClusterSchedulerBackend( val attemptId = ApplicationMaster.getAttemptId bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() - totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) + totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sc.conf) } override def getDriverLogUrls: Option[Map[String, String]] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b87bbb487467..95b6fbb0cd61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier import 
org.apache.spark.sql.types.StructType @@ -366,10 +367,17 @@ case class CatalogStatistics( * Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based * on column names. */ - def toPlanStats(planOutput: Seq[Attribute]): Statistics = { - val matched = planOutput.flatMap(a => colStats.get(a.name).map(a -> _)) - Statistics(sizeInBytes = sizeInBytes, rowCount = rowCount, - attributeStats = AttributeMap(matched)) + def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = { + if (cboEnabled && rowCount.isDefined) { + val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _))) + // Estimate size as number of rows * row size. + val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats) + Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats) + } else { + // When CBO is disabled or the table doesn't have other statistics, we apply the size-only + // estimation strategy and only propagate sizeInBytes in statistics. + Statistics(sizeInBytes = sizeInBytes) + } } /** Readable string representation for the CatalogStatistics. */ @@ -452,7 +460,7 @@ case class HiveTableRelation( ) override def computeStats(): Statistics = { - tableMeta.stats.map(_.toPlanStats(output)).getOrElse { + tableMeta.stats.map(_.toPlanStats(output, conf.cboEnabled)).getOrElse { throw new IllegalStateException("table stats must be specified.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f8644c2cd672..8d06804ce1e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -87,8 +87,7 @@ class EquivalentExpressions { def childrenToRecurse: Seq[Expression] = expr match { case _: CodegenFallback => Nil case i: If => i.predicate :: Nil - // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here. 
- case c: CaseWhenCodegen => c.children.head :: Nil + case c: CaseWhen => c.children.head :: Nil case c: Coalesce => c.children.head :: Nil case other => other.children } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index e5a1096bba71..d98f7b3d8efe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) + val codes = ctx.splitExpressions(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) + val codes = ctx.splitExpressions(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0498e61819f4..668c816b3fd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -781,15 +781,18 @@ class CodegenContext { * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it * instead, because classes have a constant pool limit of 65,536 named values. * - * @param row the variable name of row that is used by expressions + * Note that we will extract the current inputs of this context and pass them to the generated + * functions. The input is `INPUT_ROW` for normal codegen path, and `currentVars` for whole + * stage codegen path. Whole stage codegen path is not supported yet. + * * @param expressions the codes to evaluate expressions. */ - def splitExpressions(row: String, expressions: Seq[String]): String = { - if (row == null || currentVars != null) { - // Cannot split these expressions because they are not created from a row object. 
+ def splitExpressions(expressions: Seq[String]): String = { + // TODO: support whole stage codegen + if (INPUT_ROW == null || currentVars != null) { return expressions.mkString("\n") } - splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", row) :: Nil) + splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 802e8bdb1ca3..5fdbda51b4ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -91,8 +91,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) - val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) + val allProjections = ctx.splitExpressions(projectionCodes) + val allUpdates = ctx.splitExpressions(updates) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 1e4ac3f2afd5..5d35cce1a91c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -45,7 +45,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx: CodegenContext, input: String, schema: StructType): ExprCode = { - val tmp = ctx.freshName("tmp") + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. + val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions @@ -54,17 +55,21 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt) + val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) s""" - if (!$tmp.isNullAt($i)) { + if (!$tmpInput.isNullAt($i)) { ${converter.code} $values[$i] = ${converter.value}; } """ } - val allFields = ctx.splitExpressions(tmp, fieldWriters) + val allFields = ctx.splitExpressions( + expressions = fieldWriters, + funcName = "writeFields", + arguments = Seq("InternalRow" -> tmpInput) + ) val code = s""" - final InternalRow $tmp = $input; + final InternalRow $tmpInput = $input; $values = new Object[${schema.length}]; $allFields final InternalRow $output = new $rowClass($values); @@ -78,20 +83,22 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx: CodegenContext, input: String, elementType: DataType): ExprCode = { - val tmp = ctx.freshName("tmp") + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
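+ // (For example, the map converter below passes "$tmpInput.keyArray()" as the input,
+ // which should be evaluated only once.)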
+ val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeArray") val values = ctx.freshName("values") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType) + val elementConverter = convertToSafe( + ctx, ctx.getValue(tmpInput, elementType, index), elementType) val code = s""" - final ArrayData $tmp = $input; - final int $numElements = $tmp.numElements(); + final ArrayData $tmpInput = $input; + final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; for (int $index = 0; $index < $numElements; $index++) { - if (!$tmp.isNullAt($index)) { + if (!$tmpInput.isNullAt($index)) { ${elementConverter.code} $values[$index] = ${elementConverter.value}; } @@ -107,14 +114,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] input: String, keyType: DataType, valueType: DataType): ExprCode = { - val tmp = ctx.freshName("tmp") + val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeMap") val mapClass = classOf[ArrayBasedMapData].getName - val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType) - val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType) + val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) + val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) val code = s""" - final MapData $tmp = $input; + final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); @@ -152,7 +159,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) + val allExpressions = ctx.splitExpressions(expressionCodes) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4bd50aee0551..b022457865d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -36,7 +36,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case NullType => true case t: AtomicType => true case _: CalendarIntervalType => true - case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) + case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true case udt: UserDefinedType[_] => canSupport(udt.sqlType) @@ -49,25 +49,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, fieldTypes: Seq[DataType], bufferHolder: String): String = { + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
+ val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val javaType = ctx.javaType(dt) - val isNullVar = ctx.freshName("isNull") - val valueVar = ctx.freshName("value") - val defaultValue = ctx.defaultValue(dt) - val readValue = ctx.getValue(input, dt, i.toString) - val code = - s""" - boolean $isNullVar = $input.isNullAt($i); - $javaType $valueVar = $isNullVar ? $defaultValue : $readValue; - """ - ExprCode(code, isNullVar, valueVar) + ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) } s""" - if ($input instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} + final InternalRow $tmpInput = $input; + if ($tmpInput instanceof UnsafeRow) { + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} } else { - ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} } """ } @@ -167,9 +160,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } } + val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) { + // TODO: support whole stage codegen + writeFields.mkString("\n") + } else { + assert(row != null, "the input row name cannot be null when generating code to write it.") + ctx.splitExpressions( + expressions = writeFields, + funcName = "writeFields", + arguments = Seq("InternalRow" -> row)) + } + s""" $resetWriter - ${ctx.splitExpressions(row, writeFields)} + $writeFieldsCode """.trim } @@ -179,13 +183,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, elementType: DataType, bufferHolder: String): String = { + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
+ val tmpInput = ctx.freshName("tmpInput") val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, s"$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val element = ctx.freshName("element") val et = elementType match { case udt: UserDefinedType[_] => udt.sqlType @@ -201,6 +206,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val tmpCursor = ctx.freshName("tmpCursor") + val element = ctx.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" @@ -233,17 +239,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" s""" - if ($input instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} + final ArrayData $tmpInput = $input; + if ($tmpInput instanceof UnsafeArrayData) { + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} } else { - final int $numElements = $input.numElements(); + final int $numElements = $tmpInput.numElements(); $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); for (int $index = 0; $index < $numElements; $index++) { - if ($input.isNullAt($index)) { + if ($tmpInput.isNullAt($index)) { $arrayWriter.setNull$primitiveTypeName($index); } else { - final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement } } @@ -258,19 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro keyType: DataType, valueType: DataType, bufferHolder: String): String = { - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. + val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") - // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - if ($input instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} + final MapData $tmpInput = $input; + if ($tmpInput instanceof UnsafeMapData) { + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} } else { - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - // preserve 8 bytes to write the key array numBytes later. $bufferHolder.grow(8); $bufferHolder.cursor += 8; @@ -278,11 +281,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} // Write the numBytes of key array into the first 8 bytes. 
Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 2a00d57ee130..57a7f2e20773 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, + code = preprocess + ctx.splitExpressions(assigns) + postprocess, value = arrayData, isNull = "false") } @@ -216,10 +216,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { s""" final boolean ${ev.isNull} = false; $preprocessKeyData - ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)} + ${ctx.splitExpressions(assignKeys)} $postprocessKeyData $preprocessValueData - ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)} + ${ctx.splitExpressions(assignValues)} $postprocessValueData final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); """ @@ -351,24 +351,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"$values = null;") - - ev.copy(code = s""" - $values = new Object[${valExprs.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" + val valuesCode = ctx.splitExpressions( + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + ${eval.code} if (${eval.isNull}) { $values[$i] = null; } else { $values[$i] = ${eval.value}; }""" - }) + + }) + + ev.copy(code = s""" - final InternalRow ${ev.value} = new $rowClass($values); - $values = null; - """, isNull = "false") + |$values = new Object[${valExprs.size}]; + |$valuesCode + |final InternalRow ${ev.value} = new $rowClass($values); + |$values = null; + """.stripMargin, isNull = "false") } override def prettyName: String = "named_struct" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 6195be3a258c..43e643178c89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -88,14 +88,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } /** - * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen. + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * When a = true, returns b; when c = true, returns d; else returns e. 
* * @param branches seq of (branch condition, branch value) * @param elseValue optional value for the else branch */ -abstract class CaseWhenBase( +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.", + arguments = """ + Arguments: + * expr1, expr3 - the branch condition expressions should all be boolean type. + * expr2, expr4, expr5 - the branch value expressions and else value expression should all be + same type or coercible to a common type. + """, + examples = """ + Examples: + > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; + 1 + > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; + 2 + > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END; + NULL + """) +// scalastyle:on line.size.limit +case class CaseWhen( branches: Seq[(Expression, Expression)], - elseValue: Option[Expression]) + elseValue: Option[Expression] = None) extends Expression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue @@ -158,111 +178,99 @@ abstract class CaseWhenBase( val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") "CASE" + cases + elseCase + " END" } -} - - -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * When a = true, returns b; when c = true, returns d; else returns e. - * - * @param branches seq of (branch condition, branch value) - * @param elseValue optional value for the else branch - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.", - arguments = """ - Arguments: - * expr1, expr3 - the branch condition expressions should all be boolean type. - * expr2, expr4, expr5 - the branch value expressions and else value expression should all be - same type or coercible to a common type. - """, - examples = """ - Examples: - > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; - 1 - > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; - 2 - > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END; - NULL - """) -// scalastyle:on line.size.limit -case class CaseWhen( - val branches: Seq[(Expression, Expression)], - val elseValue: Option[Expression] = None) - extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable { - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - super[CodegenFallback].doGenCode(ctx, ev) - } - - def toCodegen(): CaseWhenCodegen = { - CaseWhenCodegen(branches, elseValue) - } -} - -/** - * CaseWhen expression used when code generation condition is satisfied. - * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen. - * - * @param branches seq of (branch condition, branch value) - * @param elseValue optional value for the else branch - */ -case class CaseWhenCodegen( - val branches: Seq[(Expression, Expression)], - val elseValue: Option[Expression] = None) - extends CaseWhenBase(branches, elseValue) with Serializable { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Generate code that looks like: - // - // condA = ... - // if (condA) { - // valueA - // } else { - // condB = ... 
- // if (condB) { - // valueB - // } else { - // condC = ... - // if (condC) { - // valueC - // } else { - // elseValue - // } - // } - // } + // This variable represents whether the first successful condition is met or not. + // It is initialized to `false` and it is set to `true` when the first condition which + // evaluates to `true` is met and therefore is not needed to go on anymore on the computation + // of the following conditions. + val conditionMet = ctx.freshName("caseWhenConditionMet") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) + + // these blocks are meant to be inside a + // do { + // ... + // } while (false); + // loop val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) val res = valueExpr.genCode(ctx) s""" - ${cond.code} - if (!${cond.isNull} && ${cond.value}) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - } - """ + |${cond.code} + |if (!${cond.isNull} && ${cond.value}) { + | ${res.code} + | ${ev.isNull} = ${res.isNull}; + | ${ev.value} = ${res.value}; + | $conditionMet = true; + | continue; + |} + """.stripMargin } - var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") - - elseValue.foreach { elseExpr => + val elseCode = elseValue.map { elseExpr => val res = elseExpr.genCode(ctx) - generatedCode += - s""" - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - """ + s""" + |${res.code} + |${ev.isNull} = ${res.isNull}; + |${ev.value} = ${res.value}; + """.stripMargin } - generatedCode += "}\n" * cases.size + val allConditions = cases ++ elseCode + + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + allConditions.mkString("\n") + } else { + // This generates code like: + // conditionMet = caseWhen_1(i); + // if(conditionMet) { + // continue; + // } + // conditionMet = caseWhen_2(i); + // if(conditionMet) { + // continue; + // } + // ... 
+ // and the declared methods are: + // private boolean caseWhen_1234() { + // boolean conditionMet = false; + // do { + // // here the evaluation of the conditions + // } while (false); + // return conditionMet; + // } + ctx.splitExpressions(allConditions, "caseWhen", + ("InternalRow", ctx.INPUT_ROW) :: Nil, + returnType = ctx.JAVA_BOOLEAN, + makeSplitFunction = { + func => + s""" + ${ctx.JAVA_BOOLEAN} $conditionMet = false; + do { + $func + } while (false); + return $conditionMet; + """ + }, + foldFunctions = { funcCalls => + funcCalls.map { funcCall => + s""" + $conditionMet = $funcCall; + if ($conditionMet) { + continue; + }""" + }.mkString + }) + } ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $generatedCode""") + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + ${ctx.JAVA_BOOLEAN} $conditionMet = false; + do { + $code + } while (false);""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 8618f4908607..f1aa13066926 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) - val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => + val code = ctx.splitExpressions(Seq.tabulate(numRows) { row => val fields = Seq.tabulate(numFields) { col => val index = row * numFields + col if (index < values.length) values(index) else Literal(null, dataTypes(col)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 9e0786e36791..c3289b829993 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" - val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => + val childrenHash = ctx.splitExpressions(children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) @@ -330,9 +330,9 @@ abstract class HashExpression[E] extends Expression { } else { val bytes = ctx.freshName("bytes") s""" - final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); - ${genHashBytes(bytes, result)} - """ + |final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + |${genHashBytes(bytes, result)} + """.stripMargin } } @@ -392,7 +392,10 @@ abstract class HashExpression[E] extends Expression { val hashes = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - ctx.splitExpressions(input, hashes) + ctx.splitExpressions( + expressions = hashes, + funcName = "getHash", + arguments = Seq("InternalRow" -> 
input)) } @tailrec @@ -608,12 +611,17 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" val childHash = ctx.freshName("childHash") - val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => + val childrenHash = ctx.splitExpressions(children.map { child => val childGen = child.genCode(ctx) - childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { + val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, childHash, ctx) - } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" + - s"\n$childHash = 0;" + } + s""" + |${childGen.code} + |$codeToComputeHash + |${ev.value} = (31 * ${ev.value}) + $childHash; + |$childHash = 0; + """.stripMargin }) ctx.addMutableState(ctx.javaType(dataType), ev.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 5eaf3f220277..173e171910b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -91,7 +91,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.splitExpressions(ctx.INPUT_ROW, evals)}""") + ${ctx.splitExpressions(evals)}""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 006d37f38d6c..e2bc79d98b33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression { """ } } - val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + val argCode = ctx.splitExpressions(argCodes) (argCode, argValues.mkString(", "), resultIsNull) } @@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) """ } - val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) + val childrenCode = ctx.splitExpressions(childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) val code = s""" @@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp ${javaBeanInstance}.$setterMethod(${fieldGen.value}); """ } - val initializeCode = ctx.splitExpressions(ctx.INPUT_ROW, initialize.toSeq) + val initializeCode = ctx.splitExpressions(initialize.toSeq) val code = s""" ${instanceGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1c599af2a01d..ee5cf925d3ce 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -208,7 +208,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code)) + val codes = 
ctx.splitExpressions(evals.map(_.code)) val varargCounts = ctx.splitExpressions( expressions = varargCount, funcName = "varargCountsConcatWs", @@ -1372,10 +1372,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val pattern = children.head.genCode(ctx) val argListGen = children.tail.map(x => (x.dataType, x.genCode(ctx))) - val argListCode = argListGen.map(_._2.code + "\n") - - val argListString = argListGen.foldLeft("")((s, v) => { - val nullSafeString = + val argList = ctx.freshName("argLists") + val numArgLists = argListGen.length + val argListCode = argListGen.zipWithIndex.map { case(v, index) => + val value = if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { // Java primitives get boxed in order to allow null values. s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + @@ -1383,8 +1383,19 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } else { s"(${v._2.isNull}) ? null : ${v._2.value}" } - s + "," + nullSafeString - }) + s""" + ${v._2.code} + $argList[$index] = $value; + """ + } + val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions( + expressions = argListCode, + funcName = "valueFormatString", + arguments = ("InternalRow", ctx.INPUT_ROW) :: ("Object[]", argList) :: Nil) + } else { + argListCode.mkString("\n") + } val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName @@ -1395,10 +1406,11 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC boolean ${ev.isNull} = ${pattern.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${argListCode.mkString} $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); - $form.format(${pattern.value}.toString() $argListString); + Object[] $argList = new Object[$numArgLists]; + $argListCodes + $form.format(${pattern.value}.toString(), $argList); ${ev.value} = UTF8String.fromString($sb.toString()); }""") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3a3ccd5ff5e6..0d961bf2e6e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -138,8 +138,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: - Batch("OptimizeCodegen", Once, - OptimizeCodegen) :: Batch("RewriteSubquery", Once, RewritePredicateSubquery, CollapseProject) :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 523b53b39d6b..785e815b4118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -552,21 +552,6 @@ object FoldablePropagation extends Rule[LogicalPlan] { } -/** - * Optimizes expressions by replacing according to CodeGen configuration. 
- */ -object OptimizeCodegen extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case e: CaseWhen if canCodegen(e) => e.toCodegen() - } - - private def canCodegen(e: CaseWhen): Boolean = { - val numBranches = e.branches.size + e.elseValue.size - numBranches <= SQLConf.get.maxCaseBranchesForCodegen - } -} - - /** * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index d701a956887a..5e1c4e0bd606 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.expressions.AttributeMap import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4eda9f337953..8abb4262d735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -599,12 +599,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val MAX_CASES_BRANCHES = buildConf("spark.sql.codegen.maxCaseBranches") - .internal() - .doc("The maximum number of switches supported with codegen.") - .intConf - .createWithDefault(20) - val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines") .internal() .doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.") @@ -1004,6 +998,15 @@ object SQLConf { .intConf .createWithDefault(10000) + val PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE = + buildConf("spark.sql.execution.pandas.respectSessionTimeZone") + .internal() + .doc("When true, make Pandas DataFrame with timestamp type respecting session local " + + "timezone when converting to/from Pandas DataFrame. 
This configuration will be " + + "deprecated in the future releases.") + .booleanConf + .createWithDefault(true) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1140,8 +1143,6 @@ class SQLConf extends Serializable with Logging { def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) - def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES) - def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) @@ -1324,6 +1325,8 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 6e33087b4c6c..a4198f826ced 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -77,7 +77,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-13242: case-when expression with large number of branches (or cases)") { - val cases = 50 + val cases = 500 val clauses = 20 // Generate an individual case @@ -88,13 +88,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { (condition, Literal(n)) } - val expression = CaseWhen((1 to cases).map(generateCase(_))) + val expression = CaseWhen((1 to cases).map(generateCase)) val plan = GenerateMutableProjection.generate(Seq(expression)) - val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"$clauses:$cases"))) val actual = plan(input).toSeq(Seq(expression.dataType)) - assert(actual(0) == cases) + assert(actual.head == cases) } test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index c76139475687..54cde77176e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -518,6 +518,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } + test("SPARK-22603: FormatString should not generate codes beyond 64KB") { + val N = 4500 + val args = (1 to N).map(i => Literal.create(i.toString, StringType)) + val format = "%s" * N + val expected = (1 to N).map(i => i.toString).mkString + checkEvaluation(FormatString(Literal(format) +: args: _*), expected) + } + test("INSTR") { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala deleted file mode 100644 index b1157f3e3edd..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.Literal._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ - - -class OptimizeCodegenSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Nil - } - - protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { - val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze - val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze) - comparePlans(actual, correctAnswer) - } - - test("Codegen only when the number of branches is small.") { - assertEquivalent( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen()) - - assertEquivalent( - CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2))) - } - - test("Nested CaseWhen Codegen.") { - assertEquivalent( - CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral), Literal(3))), - CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))), - CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral).toCodegen(), Literal(3))), - CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen()) - } - - test("Multiple CaseWhen in one operator.") { - val plan = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze - val correctAnswer = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze - val optimized = Optimize.execute(plan) - comparePlans(optimized, correctAnswer) - } - 
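Editorial note on the removal above: with `OptimizeCodegen` and `CaseWhenCodegen` gone, `CaseWhen` always generates code directly and relies on `splitExpressions` when the branch count gets large. The sketch below is a hypothetical helper (not the actual `doGenCode`) that only illustrates the shape of the emitted code: each branch sets the result, marks the condition as met, and `continue`s out of a `do { ... } while (false)` block, so later branches are skipped even when they live in separate generated methods.

// Hypothetical helper that renders the generated-Java shape described above.
def caseWhenCodeShape(branches: Seq[(String, String)], elseValue: String): String = {
  val cases = branches.map { case (cond, value) =>
    s"if ($cond) { result = $value; conditionMet = true; continue; }"
  }.mkString("\n  ")
  s"""boolean conditionMet = false;
     |do {
     |  $cases
     |  result = $elseValue;   // else branch, reached only if nothing matched
     |} while (false);""".stripMargin
}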
- test("Multiple CaseWhen in different operators") { - val plan = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - .where( - LessThan( - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - ).analyze - val correctAnswer = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - .where( - LessThan( - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - ).analyze - val optimized = Optimize.execute(plan) - comparePlans(optimized, correctAnswer) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e3fa2ced760e..35abeccfd514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -592,7 +592,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `sep` (default `,`): sets the single character as a separator for each
 * field and value.
  • `quote` (default `"`): sets the single character used for escaping quoted values where
- * the separator can be part of the value.
+ * the separator can be part of the value. If an empty string is set, it uses `u0000`
+ * (null character).
  • `escape` (default `\`): sets the single character used for escaping quotes inside
 * an already quoted value.
  • `escapeQuotes` (default `true`): a flag indicating whether values containing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 236995708a12..8d715f634298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -41,7 +41,7 @@ case class LogicalRelation( override def computeStats(): Statistics = { catalogTable - .flatMap(_.stats.map(_.toPlanStats(output))) + .flatMap(_.stats.map(_.toPlanStats(output, conf.cboEnabled))) .getOrElse(Statistics(sizeInBytes = relation.sizeInBytes)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index e27210117a1e..c06bc7b66ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -63,6 +63,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -81,7 +82,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 94c05b9b5e49..9a94d771a01b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -44,7 +44,8 @@ class ArrowPythonRunner( evalType: Int, argOffsets: Array[Array[Int]], schema: StructType, - timeZoneId: String) + timeZoneId: String, + respectTimeZone: Boolean) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -58,6 +59,11 @@ class ArrowPythonRunner( protected override def writeCommand(dataOut: DataOutputStream): Unit = { PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + if (respectTimeZone) { + PythonRDD.writeUTF(timeZoneId, dataOut) + } else { + dataOut.writeInt(SpecialLengths.NULL) + } } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index ee495814b825..59db66bd7adf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -78,6 
+78,7 @@ case class FlatMapGroupsInPandasExec( val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { @@ -95,7 +96,8 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala index d077836da847..e49546830286 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala @@ -90,7 +90,7 @@ class FlatMapGroupsWithState_StateManager( val deser = stateEncoder.resolveAndBind().deserializer.transformUp { case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) } - CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen() + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser) } // Converters for translating state between rows and Java objects diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index fdd25330c5e6..6ae307bce10c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -480,7 +480,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { if (tableMetadata.tableType == CatalogTableType.VIEW) { // Temp or persistent views: refresh (or invalidate) any metadata/data cached // in the plan recursively. - table.queryExecution.analyzed.foreach(_.refresh()) + table.queryExecution.analyzed.refresh() } else { // Non-temp tables: refresh the metadata cache. 
sessionCatalog.refreshTable(tableIdent) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 644e72c893ce..72a5cc98fbec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} -import org.apache.spark.sql.execution.{FilterExec, QueryExecution} +import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ @@ -2158,4 +2158,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mean = result.select("DecimalCol").where($"summary" === "mean") assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) } + + test("SPARK-22520: support code generation for large CaseWhen") { + val N = 30 + var expr1 = when($"id" === lit(0), 0) + var expr2 = when($"id" === lit(0), 10) + (1 to N).foreach { i => + expr1 = expr1.when($"id" === lit(i), -i) + expr2 = expr2.when($"id" === lit(i + 10), i) + } + val df = spark.range(1).select(expr1, expr2.otherwise(0)) + checkAnswer(df, Row(0, 10) :: Nil) + assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index f6df077ec572..65ccc1915882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, H import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.Decimal @@ -223,11 +223,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) // Check relation statistics - assert(relation.stats.sizeInBytes == 0) - assert(relation.stats.rowCount == Some(0)) - assert(relation.stats.attributeStats.size == 1) - val (attribute, colStat) = relation.stats.attributeStats.head - assert(attribute.name == "c1") - assert(colStat == emptyColStat) + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + assert(relation.stats.sizeInBytes == 1) + assert(relation.stats.rowCount == Some(0)) + assert(relation.stats.attributeStats.size == 1) + val (attribute, colStat) = relation.stats.attributeStats.head + assert(attribute.name == "c1") + assert(colStat == emptyColStat) + } + relation.invalidateStatsCache() + withSQLConf(SQLConf.CBO_ENABLED.key -> "false") { + assert(relation.stats.sizeInBytes == 0) + assert(relation.stats.rowCount.isEmpty) + 
assert(relation.stats.attributeStats.isEmpty) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 878f435c75cb..fdb9b2f51f9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -117,6 +117,21 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo } } + test("SPARK-22431: table with nested type col with special char") { + withTable("t") { + spark.sql("CREATE TABLE t(q STRUCT<`$a`:INT, col2:STRING>, i1 INT) USING PARQUET") + checkAnswer(spark.table("t"), Nil) + } + } + + test("SPARK-22431: view with nested type") { + withView("t", "v") { + spark.sql("CREATE VIEW t AS SELECT STRUCT('a' AS `$a`, 1 AS b) q") + checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil) + spark.sql("CREATE VIEW v AS SELECT STRUCT('a' AS `a`, 1 AS b) q") + checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil) + } + } } abstract class DDLSuite extends QueryTest with SQLTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index f9d75fc1788d..8b1521bacea4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -221,35 +221,6 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { .sessionState.conf.warehousePath.stripSuffix("/")) } - test("MAX_CASES_BRANCHES") { - withTable("tab1") { - spark.range(10).write.saveAsTable("tab1") - val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1" - val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1" - - withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") { - assert(!sql(sql_one_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - assert(!sql(sql_two_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - } - - withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") { - assert(sql(sql_one_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - assert(!sql(sql_two_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - } - - withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") { - assert(sql(sql_one_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - assert(sql(sql_two_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - } - } - } - test("static SQL conf comes from SparkConf") { val previousValue = sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) try { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index b5a5890d47b0..47ce6ba83866 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -488,6 +488,7 @@ private[hive] class HiveClientImpl( } override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState { + verifyColumnDataType(table.dataSchema) client.createTable(toHiveTable(table, Some(userName)), ignoreIfExists) } @@ -507,6 +508,7 @@ private[hive] class 
HiveClientImpl( // these properties are still available to the others that share the same Hive metastore. // If users explicitly alter these Hive-specific properties through ALTER TABLE DDL, we respect // these user-specified values. + verifyColumnDataType(table.dataSchema) val hiveTable = toHiveTable( table.copy(properties = table.ignoredProperties ++ table.properties), Some(userName)) // Do not use `table.qualifiedName` here because this may be a rename @@ -520,6 +522,7 @@ private[hive] class HiveClientImpl( newDataSchema: StructType, schemaProps: Map[String, String]): Unit = withHiveState { val oldTable = client.getTable(dbName, tableName) + verifyColumnDataType(newDataSchema) val hiveCols = newDataSchema.map(toHiveColumn) oldTable.setFields(hiveCols.asJava) @@ -872,15 +875,19 @@ private[hive] object HiveClientImpl { new FieldSchema(c.name, typeString, c.getComment().orNull) } - /** Builds the native StructField from Hive's FieldSchema. */ - def fromHiveColumn(hc: FieldSchema): StructField = { - val columnType = try { + /** Get the Spark SQL native DataType from Hive's FieldSchema. */ + private def getSparkSQLDataType(hc: FieldSchema): DataType = { + try { CatalystSqlParser.parseDataType(hc.getType) } catch { case e: ParseException => throw new SparkException("Cannot recognize hive type string: " + hc.getType, e) } + } + /** Builds the native StructField from Hive's FieldSchema. */ + def fromHiveColumn(hc: FieldSchema): StructField = { + val columnType = getSparkSQLDataType(hc) val metadata = if (hc.getType != columnType.catalogString) { new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() } else { @@ -895,6 +902,10 @@ private[hive] object HiveClientImpl { Option(hc.getComment).map(field.withComment).getOrElse(field) } + private def verifyColumnDataType(schema: StructType): Unit = { + schema.foreach(col => getSparkSQLDataType(toHiveColumn(col))) + } + private def toInputFormat(name: String) = Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 7427948fe138..0cdd9305c6b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -41,7 +41,35 @@ import org.apache.spark.sql.types._ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { - test("Hive serde tables should fallback to HDFS for size estimation") { + + test("size estimation for relations is based on row size * number of rows") { + val dsTbl = "rel_est_ds_table" + val hiveTbl = "rel_est_hive_table" + withTable(dsTbl, hiveTbl) { + spark.range(1000L).write.format("parquet").saveAsTable(dsTbl) + spark.range(1000L).write.format("hive").saveAsTable(hiveTbl) + + Seq(dsTbl, hiveTbl).foreach { tbl => + sql(s"ANALYZE TABLE $tbl COMPUTE STATISTICS") + val catalogStats = getCatalogStatistics(tbl) + withSQLConf(SQLConf.CBO_ENABLED.key -> "false") { + val relationStats = spark.table(tbl).queryExecution.optimizedPlan.stats + assert(relationStats.sizeInBytes == catalogStats.sizeInBytes) + assert(relationStats.rowCount.isEmpty) + } + spark.sessionState.catalog.refreshTable(TableIdentifier(tbl)) + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + val relationStats = spark.table(tbl).queryExecution.optimizedPlan.stats + // Due to compression in parquet files, in this test, file size is smaller 
than + // in-memory size. + assert(catalogStats.sizeInBytes < relationStats.sizeInBytes) + assert(catalogStats.rowCount == relationStats.rowCount) + } + } + } + } + + test("Hive serde tables should fallback to HDFS for size estimation") { withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true") { withTable("csv_table") { withTempDir { tempDir => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index d3465a641a1a..9063ef066aa8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -174,6 +174,88 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA test("alter datasource table add columns - partitioned - orc") { testAddColumnPartitioned("orc") } + + test("SPARK-22431: illegal nested type") { + val queries = Seq( + "CREATE TABLE t AS SELECT STRUCT('a' AS `$a`, 1 AS b) q", + "CREATE TABLE t(q STRUCT<`$a`:INT, col2:STRING>, i1 INT)", + "CREATE VIEW t AS SELECT STRUCT('a' AS `$a`, 1 AS b) q") + + queries.foreach(query => { + val err = intercept[SparkException] { + spark.sql(query) + }.getMessage + assert(err.contains("Cannot recognize hive type string")) + }) + + withView("v") { + spark.sql("CREATE VIEW v AS SELECT STRUCT('a' AS `a`, 1 AS b) q") + checkAnswer(sql("SELECT q.`a`, q.b FROM v"), Row("a", 1) :: Nil) + + val err = intercept[SparkException] { + spark.sql("ALTER VIEW v AS SELECT STRUCT('a' AS `$a`, 1 AS b) q") + }.getMessage + assert(err.contains("Cannot recognize hive type string")) + } + } + + test("SPARK-22431: table with nested type") { + withTable("t", "x") { + spark.sql("CREATE TABLE t(q STRUCT<`$a`:INT, col2:STRING>, i1 INT) USING PARQUET") + checkAnswer(spark.table("t"), Nil) + spark.sql("CREATE TABLE x (q STRUCT, i1 INT)") + checkAnswer(spark.table("x"), Nil) + } + } + + test("SPARK-22431: view with nested type") { + withView("v") { + spark.sql("CREATE VIEW v AS SELECT STRUCT('a' AS `a`, 1 AS b) q") + checkAnswer(spark.table("v"), Row(Row("a", 1)) :: Nil) + + spark.sql("ALTER VIEW v AS SELECT STRUCT('a' AS `b`, 1 AS b) q1") + val df = spark.table("v") + assert("q1".equals(df.schema.fields(0).name)) + checkAnswer(df, Row(Row("a", 1)) :: Nil) + } + } + + test("SPARK-22431: alter table tests with nested types") { + withTable("t1", "t2", "t3") { + spark.sql("CREATE TABLE t1 (q STRUCT, i1 INT)") + spark.sql("ALTER TABLE t1 ADD COLUMNS (newcol1 STRUCT<`col1`:STRING, col2:Int>)") + val newcol = spark.sql("SELECT * FROM t1").schema.fields(2).name + assert("newcol1".equals(newcol)) + + spark.sql("CREATE TABLE t2(q STRUCT<`a`:INT, col2:STRING>, i1 INT) USING PARQUET") + spark.sql("ALTER TABLE t2 ADD COLUMNS (newcol1 STRUCT<`$col1`:STRING, col2:Int>)") + spark.sql("ALTER TABLE t2 ADD COLUMNS (newcol2 STRUCT<`col1`:STRING, col2:Int>)") + + val df2 = spark.table("t2") + checkAnswer(df2, Nil) + assert("newcol1".equals(df2.schema.fields(2).name)) + assert("newcol2".equals(df2.schema.fields(3).name)) + + spark.sql("CREATE TABLE t3(q STRUCT<`$a`:INT, col2:STRING>, i1 INT) USING PARQUET") + spark.sql("ALTER TABLE t3 ADD COLUMNS (newcol1 STRUCT<`$col1`:STRING, col2:Int>)") + spark.sql("ALTER TABLE t3 ADD COLUMNS (newcol2 STRUCT<`col1`:STRING, col2:Int>)") + + val df3 = spark.table("t3") + checkAnswer(df3, Nil) + assert("newcol1".equals(df3.schema.fields(2).name)) + 
assert("newcol2".equals(df3.schema.fields(3).name)) + } + } + + test("SPARK-22431: negative alter table tests with nested types") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (q STRUCT, i1 INT)") + val err = intercept[SparkException] { + spark.sql("ALTER TABLE t1 ADD COLUMNS (newcol1 STRUCT<`$col1`:STRING, col2:Int>)") + }.getMessage + assert(err.contains("Cannot recognize hive type string:")) + } + } } class HiveDDLSuite diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 3066a4f305f0..dfabf1ec2a22 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils /** @@ -29,21 +31,32 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto import testImplicits._ test("show cost in explain command") { + val explainCostCommand = "EXPLAIN COST SELECT * FROM src" // For readability, we only show optimized plan and physical plan in explain cost command - checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), + checkKeywordsExist(sql(explainCostCommand), "Optimized Logical Plan", "Physical Plan") - checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), + checkKeywordsNotExist(sql(explainCostCommand), "Parsed Logical Plan", "Analyzed Logical Plan") - // Only has sizeInBytes before ANALYZE command - checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes") - checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), "rowCount") + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + // Only has sizeInBytes before ANALYZE command + checkKeywordsExist(sql(explainCostCommand), "sizeInBytes") + checkKeywordsNotExist(sql(explainCostCommand), "rowCount") - // Has both sizeInBytes and rowCount after ANALYZE command - sql("ANALYZE TABLE src COMPUTE STATISTICS") - checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes", "rowCount") + // Has both sizeInBytes and rowCount after ANALYZE command + sql("ANALYZE TABLE src COMPUTE STATISTICS") + checkKeywordsExist(sql(explainCostCommand), "sizeInBytes", "rowCount") + } + + spark.sessionState.catalog.refreshTable(TableIdentifier("src")) + + withSQLConf(SQLConf.CBO_ENABLED.key -> "false") { + // Don't show rowCount if cbo is disabled + checkKeywordsExist(sql(explainCostCommand), "sizeInBytes") + checkKeywordsNotExist(sql(explainCostCommand), "rowCount") + } - // No cost information + // No statistics information if "cost" is not specified checkKeywordsNotExist(sql("EXPLAIN SELECT * FROM src "), "sizeInBytes", "rowCount") }