diff --git a/.travis.yml b/.travis.yml
index d7e9f8c0290e..05b94adeeb93 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -43,7 +43,7 @@ notifications:
# 5. Run maven install before running lint-java.
install:
- export MAVEN_SKIP_RC=1
- - build/mvn -T 4 -q -DskipTests -Pmesos -Pyarn -Pkinesis-asl -Phive -Phive-thriftserver install
+ - build/mvn -T 4 -q -DskipTests -Pkubernetes -Pmesos -Pyarn -Pkinesis-asl -Phive -Phive-thriftserver install
# 6. Run lint-java.
script:
diff --git a/NOTICE b/NOTICE
index f4b64b5c3f47..6ec240efbf12 100644
--- a/NOTICE
+++ b/NOTICE
@@ -448,6 +448,12 @@ Copyright (C) 2011 Google Inc.
Apache Commons Pool
Copyright 1999-2009 The Apache Software Foundation
+This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client)
+Copyright (C) 2015 Red Hat, Inc.
+
+This product includes/uses OkHttp (https://github.com/square/okhttp)
+Copyright (C) 2012 The Android Open Source Project
+
=========================================================================
== NOTICE file corresponding to section 4(d) of the Apache License, ==
== Version 2.0, in this case for the DataNucleus distribution. ==
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackendUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackendUtils.scala
new file mode 100644
index 000000000000..c166d030f2c8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackendUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES}
+import org.apache.spark.util.Utils
+
+private[spark] object SchedulerBackendUtils {
+ val DEFAULT_NUMBER_EXECUTORS = 2
+
+ /**
+ * Gets the initial target number of executors. The result depends on whether dynamic
+ * allocation is enabled: if it is, the configured initial executor count (validated against
+ * the min/max bounds) is returned; otherwise the number of executors requested by the user.
+ */
+ def getInitialTargetExecutorNumber(
+ conf: SparkConf,
+ numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = {
+ if (Utils.isDynamicAllocationEnabled(conf)) {
+ val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS)
+ val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf)
+ val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS)
+ require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors,
+ s"initial executor number $initialNumExecutors must between min executor number " +
+ s"$minNumExecutors and max executor number $maxNumExecutors")
+
+ initialNumExecutors
+ } else {
+ conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors)
+ }
+ }
+}
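
For reference, a minimal Python sketch (not part of the patch; the helper name and the plain-dict conf are assumptions) of the resolution logic documented above. It simplifies `Utils.getDynamicAllocationInitialExecutors`, which also considers `spark.executor.instances` when dynamic allocation is on.

def initial_target_executors(conf, default=2):
    # Mirrors SchedulerBackendUtils.getInitialTargetExecutorNumber (simplified sketch).
    if conf.get("spark.dynamicAllocation.enabled", "false").lower() == "true":
        min_n = int(conf.get("spark.dynamicAllocation.minExecutors", "0"))
        max_n = int(conf.get("spark.dynamicAllocation.maxExecutors", str(2**31 - 1)))
        initial = int(conf.get("spark.dynamicAllocation.initialExecutors", str(min_n)))
        if not (min_n <= initial <= max_n):
            raise ValueError("initial executor number %d must be between %d and %d"
                             % (initial, min_n, max_n))
        return initial
    # Without dynamic allocation, fall back to the user-requested executor count.
    return int(conf.get("spark.executor.instances", str(default)))
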
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index dacc89f5f27d..44f990ec8a5a 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -532,6 +532,14 @@ def __hash__(self):
sbt_test_goals=["mesos/test"]
)
+kubernetes = Module(
+ name="kubernetes",
+ dependencies=[],
+ source_file_regexes=["resource-managers/kubernetes/core"],
+ build_profile_flags=["-Pkubernetes"],
+ sbt_test_goals=["kubernetes/test"]
+)
+
# The root module is a dummy module which is used to run all of the tests.
# No other modules should directly depend on this module.
root = Module(
diff --git a/docs/configuration.md b/docs/configuration.md
index 9b9583d9165e..e42f866c4056 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1438,10 +1438,10 @@ Apart from these, the following properties are also available, and may be useful
spark.scheduler.minRegisteredResourcesRatio |
- 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode |
+ 0.8 for Kubernetes mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode |
The minimum ratio of registered resources (registered resources / total expected resources)
- (resources are executors in yarn mode, CPU cores in standalone mode and Mesos coarsed-grained
+ (resources are executors in YARN and Kubernetes modes, CPU cores in standalone mode and Mesos coarse-grained
mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] )
to wait for before scheduling begins. Specified as a double between 0.0 and 1.0.
Regardless of whether the minimum ratio of resources has been reached,
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 5f9821378b27..983770d50683 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1716,6 +1716,8 @@ options.
Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type.
+ - In PySpark, Pandas 0.19.2 or higher is now required to use Pandas-related functionality, such as `toPandas` and `createDataFrame` from a Pandas DataFrame.
+ - In PySpark, timestamp values used by Pandas-related functionality now respect the session time zone. To restore the old behavior, set `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details.
## Upgrading From Spark SQL 2.1 to 2.2
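
As a usage sketch of the notes above (assumes a running `SparkSession` named `spark`; not part of the patch):

from datetime import datetime

df = spark.createDataFrame([(1, datetime(2017, 11, 1, 12, 0, 0))], ["id", "ts"])

# toPandas() now renders timestamps in the session time zone ...
spark.conf.set("spark.sql.session.timeZone", "America/New_York")
pdf_session_tz = df.toPandas()

# ... unless the legacy behavior (driver-local time zone) is explicitly requested.
spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
pdf_local_tz = df.toPandas()
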
diff --git a/pom.xml b/pom.xml
index 3b2c629f8ec3..7bc66e7d1954 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2664,6 +2664,13 @@
+    <profile>
+      <id>kubernetes</id>
+      <modules>
+        <module>resource-managers/kubernetes/core</module>
+      </modules>
+    </profile>
+
    <profile>
      <id>hive-thriftserver</id>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index c726ec25478a..75703380cdb4 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -53,11 +53,11 @@ object BuildCommons {
"tags", "sketch", "kvstore"
).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects
- val optionallyEnabledProjects@Seq(mesos, yarn,
+ val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn,
streamingFlumeSink, streamingFlume,
streamingKafka, sparkGangliaLgpl, streamingKinesisAsl,
dockerIntegrationTests, hadoopCloud) =
- Seq("mesos", "yarn",
+ Seq("kubernetes", "mesos", "yarn",
"streaming-flume-sink", "streaming-flume",
"streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl",
"docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _))
@@ -671,9 +671,9 @@ object Unidoc {
publish := {},
unidocProjectFilter in(ScalaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010),
unidocProjectFilter in(JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010),
unidocAllClasspaths in (ScalaUnidoc, unidoc) := {
ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index b95de2c80439..37e7cf3fa662 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -206,11 +206,12 @@ def __repr__(self):
return "ArrowSerializer"
-def _create_batch(series):
+def _create_batch(series, timezone):
"""
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
+ :param timezone: A timezone to respect when handling timestamp values
:return: Arrow RecordBatch
"""
@@ -227,7 +228,7 @@ def _create_batch(series):
def cast_series(s, t):
if type(t) == pa.TimestampType:
# NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
- return _check_series_convert_timestamps_internal(s.fillna(0))\
+ return _check_series_convert_timestamps_internal(s.fillna(0), timezone)\
.values.astype('datetime64[us]', copy=False)
# NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
elif t is not None and t == pa.date32():
@@ -253,6 +254,10 @@ class ArrowStreamPandasSerializer(Serializer):
Serializes Pandas.Series as Arrow data with Arrow streaming format.
"""
+ def __init__(self, timezone):
+ super(ArrowStreamPandasSerializer, self).__init__()
+ self._timezone = timezone
+
def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -262,7 +267,7 @@ def dump_stream(self, iterator, stream):
writer = None
try:
for series in iterator:
- batch = _create_batch(series)
+ batch = _create_batch(series, self._timezone)
if writer is None:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
@@ -280,7 +285,7 @@ def load_stream(self, stream):
reader = pa.open_stream(stream)
for batch in reader:
# NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
- pdf = _check_dataframe_localize_timestamps(batch.to_pandas())
+ pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone)
yield [c for _, c in pdf.iteritems()]
def __repr__(self):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 406686e6df72..9864dc98c1f3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -39,6 +39,7 @@
from pyspark.sql.streaming import DataStreamWriter
from pyspark.sql.types import IntegralType
from pyspark.sql.types import *
+from pyspark.util import _exception_message
__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
@@ -1881,6 +1882,13 @@ def toPandas(self):
1 5 Bob
"""
import pandas as pd
+
+ if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
+ == "true":
+ timezone = self.sql_ctx.getConf("spark.sql.session.timeZone")
+ else:
+ timezone = None
+
if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
try:
from pyspark.sql.types import _check_dataframe_localize_timestamps
@@ -1889,13 +1897,13 @@ def toPandas(self):
if tables:
table = pyarrow.concat_tables(tables)
pdf = table.to_pandas()
- return _check_dataframe_localize_timestamps(pdf)
+ return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except ImportError as e:
msg = "note: pyarrow must be installed and available on calling Python process " \
"if using spark.sql.execution.arrow.enabled=true"
- raise ImportError("%s\n%s" % (e.message, msg))
+ raise ImportError("%s\n%s" % (_exception_message(e), msg))
else:
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
@@ -1913,7 +1921,17 @@ def toPandas(self):
for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)
- return pdf
+
+ if timezone is None:
+ return pdf
+ else:
+ from pyspark.sql.types import _check_series_convert_timestamps_local_tz
+ for field in self.schema:
+ # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+ if isinstance(field.dataType, TimestampType):
+ pdf[field.name] = \
+ _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
+ return pdf
def _collectAsArrow(self):
"""
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index a75bdf8078dd..1ad974e9aa4c 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -828,8 +828,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
set, it uses the default value, ``,``.
:param quote: sets the single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
- value, ``"``. If you would like to turn off quotations, you need to set an
- empty string.
+ value, ``"``. If an empty string is set, it uses ``u0000`` (null character).
:param escape: sets the single character used for escaping quotes inside an already
quoted value. If None is set, it uses the default value, ``\``
:param escapeQuotes: a flag indicating whether values containing quotes should always
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 47c58bb28221..e2435e09af23 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -34,8 +34,9 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \
- _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
+from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \
+ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
+ _parse_datatype_string
from pyspark.sql.utils import install_exception_handler
__all__ = ["SparkSession"]
@@ -444,11 +445,34 @@ def _get_numpy_record_dtype(self, rec):
record_type_list.append((str(col_names[i]), curr_type))
return np.dtype(record_type_list) if has_rec_fix else None
- def _convert_from_pandas(self, pdf):
+ def _convert_from_pandas(self, pdf, schema, timezone):
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
:return list of records
"""
+ if timezone is not None:
+ from pyspark.sql.types import _check_series_convert_timestamps_tz_local
+ copied = False
+ if isinstance(schema, StructType):
+ for field in schema:
+ # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+ if isinstance(field.dataType, TimestampType):
+ s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
+ if not copied and s is not pdf[field.name]:
+ # Copy once if the series is modified to prevent the original Pandas
+ # DataFrame from being updated
+ pdf = pdf.copy()
+ copied = True
+ pdf[field.name] = s
+ else:
+ for column, series in pdf.iteritems():
+ s = _check_series_convert_timestamps_tz_local(pdf[column], timezone)
+ if not copied and s is not pdf[column]:
+ # Copy once if the series is modified to prevent the original Pandas
+ # DataFrame from being updated
+ pdf = pdf.copy()
+ copied = True
+ pdf[column] = s
# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)
@@ -462,15 +486,19 @@ def _convert_from_pandas(self, pdf):
# Convert list of numpy records to python lists
return [r.tolist() for r in np_records]
- def _create_from_pandas_with_arrow(self, pdf, schema):
+ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowSerializer, _create_batch
- from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
- from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+ from pyspark.sql.types import from_arrow_schema, to_arrow_type, \
+ _old_pandas_exception_message, TimestampType
+ try:
+ from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+ except ImportError as e:
+ raise ImportError(_old_pandas_exception_message(e))
# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
@@ -488,7 +516,8 @@ def _create_from_pandas_with_arrow(self, pdf, schema):
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
# Create Arrow record batches
- batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
+ batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
+ timezone)
for pdf_slice in pdf_slices]
# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
@@ -606,6 +635,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
+ if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
+ == "true":
+ timezone = self.conf.get("spark.sql.session.timeZone")
+ else:
+ timezone = None
# If no schema supplied by user then get the names of columns only
if schema is None:
@@ -614,11 +648,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
and len(data) > 0:
try:
- return self._create_from_pandas_with_arrow(data, schema)
+ return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
# Fallback to create DataFrame without arrow if raise some exception
- data = self._convert_from_pandas(data)
+ data = self._convert_from_pandas(data, schema, timezone)
if isinstance(schema, StructType):
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
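
For the input direction handled by `_convert_from_pandas` above, a rough pandas-only sketch (hypothetical helper; assumes a single timestamp column) of reinterpreting naive timestamps from the session time zone and copying the frame at most once so the caller's data is never mutated:

import pandas as pd

def reinterpret_from_session_tz(pdf, timezone, column):
    s = pdf[column]
    if pd.api.types.is_datetime64_dtype(s.dtype):
        # Element-wise so NaT is passed through untouched (see the comment in types.py below).
        s = s.apply(lambda ts: ts.tz_localize(timezone).tz_convert('tzlocal()')
                    .tz_localize(None) if ts is not pd.NaT else pd.NaT)
    if s is not pdf[column]:
        pdf = pdf.copy()  # copy once; leave the caller's DataFrame untouched
        pdf[column] = s
    return pdf
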
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 762afe0d730f..b4d32d8de8a2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -49,9 +49,14 @@
import unittest
_have_pandas = False
+_have_old_pandas = False
try:
import pandas
- _have_pandas = True
+ try:
+ import pandas.api
+ _have_pandas = True
+ except:
+ _have_old_pandas = True
except:
# No Pandas, but that's okay, we'll skip those tests
pass
@@ -2565,21 +2570,38 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
.mode("overwrite").saveAsTable("pyspark_bucket"))
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
- @unittest.skipIf(not _have_pandas, "Pandas not installed")
- def test_to_pandas(self):
+ def _to_pandas(self):
+ from datetime import datetime, date
import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
- .add("c", BooleanType()).add("d", FloatType())
+ .add("c", BooleanType()).add("d", FloatType())\
+ .add("dt", DateType()).add("ts", TimestampType())
data = [
- (1, "foo", True, 3.0), (2, "foo", True, 5.0),
- (3, "bar", False, -1.0), (4, "bar", False, 6.0),
+ (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+ (2, "foo", True, 5.0, None, None),
+ (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
+ (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
]
df = self.spark.createDataFrame(data, schema)
- types = df.toPandas().dtypes
+ return df.toPandas()
+
+ @unittest.skipIf(not _have_pandas, "Pandas not installed")
+ def test_to_pandas(self):
+ import numpy as np
+ pdf = self._to_pandas()
+ types = pdf.dtypes
self.assertEquals(types[0], np.int32)
self.assertEquals(types[1], np.object)
self.assertEquals(types[2], np.bool)
self.assertEquals(types[3], np.float32)
+ self.assertEquals(types[4], 'datetime64[ns]')
+ self.assertEquals(types[5], 'datetime64[ns]')
+
+ @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
+ def test_to_pandas_old(self):
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+ self._to_pandas()
@unittest.skipIf(not _have_pandas, "Pandas not installed")
def test_to_pandas_avoid_astype(self):
@@ -2614,6 +2636,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
+ @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
+ def test_create_dataframe_from_old_pandas(self):
+ import pandas as pd
+ from datetime import datetime
+ pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
+ "d": [pd.Timestamp.now().date()]})
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+ self.spark.createDataFrame(pdf)
+
class HiveSparkSubmitTests(SparkSubmitTests):
@@ -3103,7 +3135,7 @@ def __init__(self, **kwargs):
_make_type_verifier(data_type, nullable=False)(obj)
-@unittest.skipIf(not _have_arrow, "Arrow not installed")
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class ArrowTests(ReusedSQLTestCase):
@classmethod
@@ -3169,16 +3201,47 @@ def test_null_conversion(self):
null_counts = pdf.isnull().sum().tolist()
self.assertTrue(all([c == 1 for c in null_counts]))
- def test_toPandas_arrow_toggle(self):
- df = self.spark.createDataFrame(self.data, schema=self.schema)
+ def _toPandas_arrow_toggle(self, df):
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
pdf = df.toPandas()
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
pdf_arrow = df.toPandas()
+ return pdf, pdf_arrow
+
+ def test_toPandas_arrow_toggle(self):
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
self.assertFramesEqual(pdf_arrow, pdf)
+ def test_toPandas_respect_session_timezone(self):
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
+ try:
+ timezone = "America/New_York"
+ self.spark.conf.set("spark.sql.session.timeZone", timezone)
+ self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
+ try:
+ pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+ self.assertFramesEqual(pdf_arrow_la, pdf_la)
+ finally:
+ self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+ pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
+ self.assertFramesEqual(pdf_arrow_ny, pdf_ny)
+
+ self.assertFalse(pdf_ny.equals(pdf_la))
+
+ from pyspark.sql.types import _check_series_convert_timestamps_local_tz
+ pdf_la_corrected = pdf_la.copy()
+ for field in self.schema:
+ if isinstance(field.dataType, TimestampType):
+ pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
+ pdf_la_corrected[field.name], timezone)
+ self.assertFramesEqual(pdf_ny, pdf_la_corrected)
+ finally:
+ self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
+
def test_pandas_round_trip(self):
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -3192,16 +3255,50 @@ def test_filtered_frame(self):
self.assertEqual(pdf.columns[0], "i")
self.assertTrue(pdf.empty)
- def test_createDataFrame_toggle(self):
- pdf = self.create_pandas_data_frame()
+ def _createDataFrame_toggle(self, pdf, schema=None):
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
- df_no_arrow = self.spark.createDataFrame(pdf)
+ df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
- df_arrow = self.spark.createDataFrame(pdf)
+ df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+ return df_no_arrow, df_arrow
+
+ def test_createDataFrame_toggle(self):
+ pdf = self.create_pandas_data_frame()
+ df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+ def test_createDataFrame_respect_session_timezone(self):
+ from datetime import timedelta
+ pdf = self.create_pandas_data_frame()
+ orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
+ try:
+ timezone = "America/New_York"
+ self.spark.conf.set("spark.sql.session.timeZone", timezone)
+ self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
+ try:
+ df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+ result_la = df_no_arrow_la.collect()
+ result_arrow_la = df_arrow_la.collect()
+ self.assertEqual(result_la, result_arrow_la)
+ finally:
+ self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+ df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
+ result_ny = df_no_arrow_ny.collect()
+ result_arrow_ny = df_arrow_ny.collect()
+ self.assertEqual(result_ny, result_arrow_ny)
+
+ self.assertNotEqual(result_ny, result_la)
+
+ # Correct result_la by adjusting for the 3-hour difference between Los Angeles and New York
+ result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
+ for k, v in row.asDict().items()})
+ for row in result_la]
+ self.assertEqual(result_ny, result_la_corrected)
+ finally:
+ self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
+
def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(pdf, schema=self.schema)
@@ -3385,6 +3482,27 @@ def foo(k, v):
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class VectorizedUDFTests(ReusedSQLTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedSQLTestCase.setUpClass()
+
+ # Synchronize default timezone between Python and Java
+ cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
+ tz = "America/Los_Angeles"
+ os.environ["TZ"] = tz
+ time.tzset()
+
+ cls.sc.environment["TZ"] = tz
+ cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+ @classmethod
+ def tearDownClass(cls):
+ del os.environ["TZ"]
+ if cls.tz_prev is not None:
+ os.environ["TZ"] = cls.tz_prev
+ time.tzset()
+ ReusedSQLTestCase.tearDownClass()
+
def test_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10).select(
@@ -3621,29 +3739,37 @@ def test_vectorized_udf_timestamps(self):
data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
(1, datetime(2012, 2, 2, 2, 2, 2)),
(2, None),
- (3, datetime(2100, 4, 4, 4, 4, 4))]
+ (3, datetime(2100, 3, 3, 3, 3, 3))]
+
df = self.spark.createDataFrame(data, schema=schema)
# Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))
- @pandas_udf(returnType=BooleanType())
+ @pandas_udf(returnType=StringType())
def check_data(idx, timestamp, timestamp_copy):
+ import pandas as pd
+ msgs = []
is_equal = timestamp.isnull() # use this array to check values are equal
for i in range(len(idx)):
# Check that timestamps are as expected in the UDF
- is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
- timestamp[i].to_pydatetime() == data[idx[i]][1]
- return is_equal
-
- result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
- col("timestamp_copy"))).collect()
+ if (is_equal[i] and data[idx[i]][1] is None) or \
+ timestamp[i].to_pydatetime() == data[idx[i]][1]:
+ msgs.append(None)
+ else:
+ msgs.append(
+ "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')"
+ % (timestamp[i], idx[i], data[idx[i]][1]))
+ return pd.Series(msgs)
+
+ result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
+ col("timestamp_copy"))).collect()
# Check that collection values are correct
self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
- self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected
+ self.assertIsNone(result[i][3]) # "check_data" col
def test_vectorized_udf_return_timestamp_tz(self):
from pyspark.sql.functions import pandas_udf, col
@@ -3683,6 +3809,48 @@ def check_records_per_batch(x):
else:
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
+ def test_vectorized_udf_timestamps_respect_session_timezone(self):
+ from pyspark.sql.functions import pandas_udf, col
+ from datetime import datetime
+ import pandas as pd
+ schema = StructType([
+ StructField("idx", LongType(), True),
+ StructField("timestamp", TimestampType(), True)])
+ data = [(1, datetime(1969, 1, 1, 1, 1, 1)),
+ (2, datetime(2012, 2, 2, 2, 2, 2)),
+ (3, None),
+ (4, datetime(2100, 3, 3, 3, 3, 3))]
+ df = self.spark.createDataFrame(data, schema=schema)
+
+ f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType())
+ internal_value = pandas_udf(
+ lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
+
+ orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
+ try:
+ timezone = "America/New_York"
+ self.spark.conf.set("spark.sql.session.timeZone", timezone)
+ self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
+ try:
+ df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+ .withColumn("internal_value", internal_value(col("timestamp")))
+ result_la = df_la.select(col("idx"), col("internal_value")).collect()
+ # Correct result_la by adjusting for the 3-hour difference between Los Angeles and New York
+ diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+ result_la_corrected = \
+ df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
+ finally:
+ self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+
+ df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+ .withColumn("internal_value", internal_value(col("timestamp")))
+ result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
+
+ self.assertNotEqual(result_ny, result_la)
+ self.assertEqual(result_ny, result_la_corrected)
+ finally:
+ self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
+
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedSQLTestCase):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index fe62f60dd6d0..78abc32a35a1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -35,6 +35,7 @@
from pyspark import SparkContext
from pyspark.serializers import CloudPickleSerializer
+from pyspark.util import _exception_message
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
@@ -1678,37 +1679,105 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])
-def _check_dataframe_localize_timestamps(pdf):
+def _old_pandas_exception_message(e):
+ """ Create an error message for importing old Pandas.
"""
- Convert timezone aware timestamps to timezone-naive in local time
+ msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process"
+ return "%s\n%s" % (_exception_message(e), msg)
+
+
+def _check_dataframe_localize_timestamps(pdf, timezone):
+ """
+ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
:param pdf: pandas.DataFrame
- :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive
+ :param timezone: the timezone to convert to. If None, then use the local timezone
+ :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
- from pandas.api.types import is_datetime64tz_dtype
+ try:
+ from pandas.api.types import is_datetime64tz_dtype
+ except ImportError as e:
+ raise ImportError(_old_pandas_exception_message(e))
+ tz = timezone or 'tzlocal()'
for column, series in pdf.iteritems():
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
- pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
+ pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
return pdf
-def _check_series_convert_timestamps_internal(s):
+def _check_series_convert_timestamps_internal(s, timezone):
"""
- Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
+ Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
+ Spark internal storage
+
:param s: a pandas.Series
+ :param timezone: the timezone to convert from. If None, then use the local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
- from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+ try:
+ from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+ except ImportError as e:
+ raise ImportError(_old_pandas_exception_message(e))
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
- return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
+ tz = timezone or 'tzlocal()'
+ return s.dt.tz_localize(tz).dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
else:
return s
+def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
+ """
+ Convert timestamp to timezone-naive in the specified timezone or local timezone
+
+ :param s: a pandas.Series
+ :param from_timezone: the timezone to convert from. if None then use local timezone
+ :param to_timezone: the timezone to convert to. if None then use local timezone
+ :return pandas.Series where if it is a timestamp, has been converted to tz-naive
+ """
+ try:
+ import pandas as pd
+ from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
+ except ImportError as e:
+ raise ImportError(_old_pandas_exception_message(e))
+ from_tz = from_timezone or 'tzlocal()'
+ to_tz = to_timezone or 'tzlocal()'
+ # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+ if is_datetime64tz_dtype(s.dtype):
+ return s.dt.tz_convert(to_tz).dt.tz_localize(None)
+ elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
+ # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
+ return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None)
+ if ts is not pd.NaT else pd.NaT)
+ else:
+ return s
+
+
+def _check_series_convert_timestamps_local_tz(s, timezone):
+ """
+ Convert timestamp to timezone-naive in the specified timezone or local timezone
+
+ :param s: a pandas.Series
+ :param timezone: the timezone to convert to. if None then use local timezone
+ :return pandas.Series where if it is a timestamp, has been converted to tz-naive
+ """
+ return _check_series_convert_timestamps_localize(s, None, timezone)
+
+
+def _check_series_convert_timestamps_tz_local(s, timezone):
+ """
+ Convert timestamp to timezone-naive in the specified timezone or local timezone
+
+ :param s: a pandas.Series
+ :param timezone: the timezone to convert from. if None then use local timezone
+ :return pandas.Series where if it is a timestamp, has been converted to tz-naive
+ """
+ return _check_series_convert_timestamps_localize(s, timezone, None)
+
+
def _test():
import doctest
from pyspark.context import SparkContext
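
A brief usage sketch of the two wrapper helpers added above (illustrative values; pandas and this patched PySpark are assumed importable):

import pandas as pd
from pyspark.sql.types import (_check_series_convert_timestamps_local_tz,
                               _check_series_convert_timestamps_tz_local)

s = pd.Series([pd.Timestamp("2017-11-01 12:00:00"), pd.NaT])

# local time zone -> session time zone (used by DataFrame.toPandas)
to_session = _check_series_convert_timestamps_local_tz(s, "America/New_York")

# session time zone -> local time zone (used by SparkSession.createDataFrame on pandas input)
to_local = _check_series_convert_timestamps_tz_local(s, "America/New_York")
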
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 939643071943..e6737ae1c128 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -150,7 +150,8 @@ def read_udfs(pickleSer, infile, eval_type):
if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \
or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
- ser = ArrowStreamPandasSerializer()
+ timezone = utf8_deserializer.loads(infile)
+ ser = ArrowStreamPandasSerializer(timezone)
else:
ser = BatchedSerializer(PickleSerializer(), 100)
diff --git a/python/setup.py b/python/setup.py
index 02612ff8a724..310670e697a8 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -201,7 +201,7 @@ def _supports_symlinks():
extras_require={
'ml': ['numpy>=1.7'],
'mllib': ['numpy>=1.7'],
- 'sql': ['pandas>=0.13.0']
+ 'sql': ['pandas>=0.19.2']
},
classifiers=[
'Development Status :: 5 - Production/Stable',
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
new file mode 100644
index 000000000000..7d35aea8a414
--- /dev/null
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -0,0 +1,100 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.11</artifactId>
+    <version>2.3.0-SNAPSHOT</version>
+    <relativePath>../../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-kubernetes_2.11</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Kubernetes</name>
+  <properties>
+    <sbt.project.name>kubernetes</sbt.project.name>
+    <kubernetes.client.version>3.0.0</kubernetes.client.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>io.fabric8</groupId>
+      <artifactId>kubernetes-client</artifactId>
+      <version>${kubernetes.client.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.core</groupId>
+          <artifactId>*</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.fasterxml.jackson.dataformat</groupId>
+          <artifactId>jackson-dataformat-yaml</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+
+    <dependency>
+      <groupId>com.fasterxml.jackson.dataformat</groupId>
+      <artifactId>jackson-dataformat-yaml</artifactId>
+      <version>${fasterxml.jackson.version}</version>
+    </dependency>
+
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>com.squareup.okhttp3</groupId>
+      <artifactId>okhttp</artifactId>
+      <version>3.8.1</version>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+
+</project>
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
new file mode 100644
index 000000000000..f0742b91987b
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.ConfigBuilder
+import org.apache.spark.network.util.ByteUnit
+
+private[spark] object Config extends Logging {
+
+ val KUBERNETES_NAMESPACE =
+ ConfigBuilder("spark.kubernetes.namespace")
+ .doc("The namespace that will be used for running the driver and executor pods. When using " +
+ "spark-submit in cluster mode, this can also be passed to spark-submit via the " +
+ "--kubernetes-namespace command line argument.")
+ .stringConf
+ .createWithDefault("default")
+
+ val EXECUTOR_DOCKER_IMAGE =
+ ConfigBuilder("spark.kubernetes.executor.docker.image")
+ .doc("Docker image to use for the executors. Specify this using the standard Docker tag " +
+ "format.")
+ .stringConf
+ .createOptional
+
+ val DOCKER_IMAGE_PULL_POLICY =
+ ConfigBuilder("spark.kubernetes.docker.image.pullPolicy")
+ .doc("Kubernetes image pull policy. Valid values are Always, Never, and IfNotPresent.")
+ .stringConf
+ .checkValues(Set("Always", "Never", "IfNotPresent"))
+ .createWithDefault("IfNotPresent")
+
+ val APISERVER_AUTH_DRIVER_CONF_PREFIX =
+ "spark.kubernetes.authenticate.driver"
+ val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX =
+ "spark.kubernetes.authenticate.driver.mounted"
+ val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken"
+ val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile"
+ val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile"
+ val CLIENT_CERT_FILE_CONF_SUFFIX = "clientCertFile"
+ val CA_CERT_FILE_CONF_SUFFIX = "caCertFile"
+
+ val KUBERNETES_SERVICE_ACCOUNT_NAME =
+ ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName")
+ .doc("Service account that is used when running the driver pod. The driver pod uses " +
+ "this service account when requesting executor pods from the API server. If specific " +
+ "credentials are given for the driver pod to use, the driver will favor " +
+ "using those credentials instead.")
+ .stringConf
+ .createOptional
+
+ // Note that while we set a default for this when we start up the
+ // scheduler, the specific default value is dynamically determined
+ // based on the executor memory.
+ val KUBERNETES_EXECUTOR_MEMORY_OVERHEAD =
+ ConfigBuilder("spark.kubernetes.executor.memoryOverhead")
+ .doc("The amount of off-heap memory (in megabytes) to be allocated per executor. This " +
+ "is memory that accounts for things like VM overheads, interned strings, other native " +
+ "overheads, etc. This tends to grow with the executor size. (typically 6-10%).")
+ .bytesConf(ByteUnit.MiB)
+ .createOptional
+
+ val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label."
+ val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation."
+
+ val KUBERNETES_DRIVER_POD_NAME =
+ ConfigBuilder("spark.kubernetes.driver.pod.name")
+ .doc("Name of the driver pod.")
+ .stringConf
+ .createOptional
+
+ val KUBERNETES_EXECUTOR_POD_NAME_PREFIX =
+ ConfigBuilder("spark.kubernetes.executor.podNamePrefix")
+ .doc("Prefix to use in front of the executor pod names.")
+ .internal()
+ .stringConf
+ .createWithDefault("spark")
+
+ val KUBERNETES_ALLOCATION_BATCH_SIZE =
+ ConfigBuilder("spark.kubernetes.allocation.batch.size")
+ .doc("Number of pods to launch at once in each round of executor allocation.")
+ .intConf
+ .checkValue(value => value > 0, "Allocation batch size should be a positive integer")
+ .createWithDefault(5)
+
+ val KUBERNETES_ALLOCATION_BATCH_DELAY =
+ ConfigBuilder("spark.kubernetes.allocation.batch.delay")
+ .doc("Number of seconds to wait between each round of executor allocation.")
+ .longConf
+ .checkValue(value => value > 0, "Allocation batch delay should be a positive integer")
+ .createWithDefault(1)
+
+ val KUBERNETES_EXECUTOR_LIMIT_CORES =
+ ConfigBuilder("spark.kubernetes.executor.limit.cores")
+ .doc("Specify the hard cpu limit for a single executor pod")
+ .stringConf
+ .createOptional
+
+ val KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS =
+ ConfigBuilder("spark.kubernetes.executor.lostCheck.maxAttempts")
+ .doc("Maximum number of attempts allowed for checking the reason of an executor loss " +
+ "before it is assumed that the executor failed.")
+ .intConf
+ .checkValue(value => value > 0, "Maximum attempts of checks of executor lost reason " +
+ "must be a positive integer")
+ .createWithDefault(10)
+
+ val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector."
+}
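
Not part of the patch: a hedged PySpark sketch of how a few of the options defined above might be supplied (the master URL, namespace, and image name are placeholders):

from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .master("k8s://https://kubernetes.example.com:443")   # placeholder API server URL
         .config("spark.kubernetes.namespace", "spark-jobs")   # placeholder namespace
         .config("spark.kubernetes.executor.docker.image", "example/spark-executor:latest")
         .config("spark.kubernetes.docker.image.pullPolicy", "IfNotPresent")
         .config("spark.kubernetes.allocation.batch.size", "5")
         .config("spark.executor.instances", "2")
         .getOrCreate())
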
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala
new file mode 100644
index 000000000000..01717479fddd
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.k8s
+
+import org.apache.spark.SparkConf
+
+private[spark] object ConfigurationUtils {
+
+ /**
+ * Extract and parse Spark configuration properties with a given name prefix and
+ * return the result as a Map. Keys must not have more than one value.
+ *
+ * @param sparkConf Spark configuration
+ * @param prefix the given property name prefix
+ * @return a Map storing the configuration property keys and values
+ */
+ def parsePrefixedKeyValuePairs(
+ sparkConf: SparkConf,
+ prefix: String): Map[String, String] = {
+ sparkConf.getAllWithPrefix(prefix).toMap
+ }
+
+ def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = {
+ opt1.foreach { _ => require(opt2.isEmpty, errMessage) }
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
new file mode 100644
index 000000000000..4ddeefb15a89
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+private[spark] object Constants {
+
+ // Labels
+ val SPARK_APP_ID_LABEL = "spark-app-selector"
+ val SPARK_EXECUTOR_ID_LABEL = "spark-exec-id"
+ val SPARK_ROLE_LABEL = "spark-role"
+ val SPARK_POD_DRIVER_ROLE = "driver"
+ val SPARK_POD_EXECUTOR_ROLE = "executor"
+
+ // Default and fixed ports
+ val DEFAULT_DRIVER_PORT = 7078
+ val DEFAULT_BLOCKMANAGER_PORT = 7079
+ val BLOCK_MANAGER_PORT_NAME = "blockmanager"
+ val EXECUTOR_PORT_NAME = "executor"
+
+ // Environment Variables
+ val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT"
+ val ENV_DRIVER_URL = "SPARK_DRIVER_URL"
+ val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES"
+ val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY"
+ val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID"
+ val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID"
+ val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP"
+ val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH"
+ val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH"
+ val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_"
+
+ // Miscellaneous
+ val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc"
+ val MEMORY_OVERHEAD_FACTOR = 0.10
+ val MEMORY_OVERHEAD_MIN_MIB = 384L
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala
new file mode 100644
index 000000000000..1e3f055e0576
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import java.io.File
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient}
+import io.fabric8.kubernetes.client.utils.HttpClientUtils
+import okhttp3.Dispatcher
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Spark-opinionated builder for Kubernetes clients. It uses a prefix plus common suffixes to
+ * parse configuration keys, similar to the manner in which Spark's SecurityManager parses SSL
+ * options for different components.
+ */
+private[spark] object SparkKubernetesClientFactory {
+
+ def createKubernetesClient(
+ master: String,
+ namespace: Option[String],
+ kubernetesAuthConfPrefix: String,
+ sparkConf: SparkConf,
+ defaultServiceAccountToken: Option[File],
+ defaultServiceAccountCaCert: Option[File]): KubernetesClient = {
+ val oauthTokenFileConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX"
+ val oauthTokenConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX"
+ val oauthTokenFile = sparkConf.getOption(oauthTokenFileConf)
+ .map(new File(_))
+ .orElse(defaultServiceAccountToken)
+ val oauthTokenValue = sparkConf.getOption(oauthTokenConf)
+ ConfigurationUtils.requireNandDefined(
+ oauthTokenFile,
+ oauthTokenValue,
+ s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a " +
+ s"value $oauthTokenConf.")
+
+ val caCertFile = sparkConf
+ .getOption(s"$kubernetesAuthConfPrefix.$CA_CERT_FILE_CONF_SUFFIX")
+ .orElse(defaultServiceAccountCaCert.map(_.getAbsolutePath))
+ val clientKeyFile = sparkConf
+ .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX")
+ val clientCertFile = sparkConf
+ .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX")
+ val dispatcher = new Dispatcher(
+ ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher"))
+ val config = new ConfigBuilder()
+ .withApiVersion("v1")
+ .withMasterUrl(master)
+ .withWebsocketPingInterval(0)
+ .withOption(oauthTokenValue) {
+ (token, configBuilder) => configBuilder.withOauthToken(token)
+ }.withOption(oauthTokenFile) {
+ (file, configBuilder) =>
+ configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8))
+ }.withOption(caCertFile) {
+ (file, configBuilder) => configBuilder.withCaCertFile(file)
+ }.withOption(clientKeyFile) {
+ (file, configBuilder) => configBuilder.withClientKeyFile(file)
+ }.withOption(clientCertFile) {
+ (file, configBuilder) => configBuilder.withClientCertFile(file)
+ }.withOption(namespace) {
+ (ns, configBuilder) => configBuilder.withNamespace(ns)
+ }.build()
+ val baseHttpClient = HttpClientUtils.createHttpClient(config)
+ val httpClientWithCustomDispatcher = baseHttpClient.newBuilder()
+ .dispatcher(dispatcher)
+ .build()
+ new DefaultKubernetesClient(httpClientWithCustomDispatcher, config)
+ }
+
+ private implicit class OptionConfigurableConfigBuilder(val configBuilder: ConfigBuilder)
+ extends AnyVal {
+
+ def withOption[T]
+ (option: Option[T])
+ (configurator: ((T, ConfigBuilder) => ConfigBuilder)): ConfigBuilder = {
+ option.map { opt =>
+ configurator(opt, configBuilder)
+ }.getOrElse(configBuilder)
+ }
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala
new file mode 100644
index 000000000000..f79155b117b6
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model._
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.ConfigurationUtils
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.util.Utils
+
+/**
+ * A factory class for configuring and creating executor pods.
+ */
+private[spark] trait ExecutorPodFactory {
+
+ /**
+ * Configure and construct an executor pod with the given parameters.
+ */
+ def createExecutorPod(
+ executorId: String,
+ applicationId: String,
+ driverUrl: String,
+ executorEnvs: Seq[(String, String)],
+ driverPod: Pod,
+ nodeToLocalTaskCount: Map[String, Int]): Pod
+}
+
+private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf)
+ extends ExecutorPodFactory {
+
+ private val executorExtraClasspath =
+ sparkConf.get(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH)
+
+ private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs(
+ sparkConf,
+ KUBERNETES_EXECUTOR_LABEL_PREFIX)
+ require(
+ !executorLabels.contains(SPARK_APP_ID_LABEL),
+ s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.")
+ require(
+ !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL),
+ s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" +
+ " Spark.")
+ require(
+ !executorLabels.contains(SPARK_ROLE_LABEL),
+ s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.")
+
+ private val executorAnnotations =
+ ConfigurationUtils.parsePrefixedKeyValuePairs(
+ sparkConf,
+ KUBERNETES_EXECUTOR_ANNOTATION_PREFIX)
+ private val nodeSelector =
+ ConfigurationUtils.parsePrefixedKeyValuePairs(
+ sparkConf,
+ KUBERNETES_NODE_SELECTOR_PREFIX)
+
+ private val executorDockerImage = sparkConf
+ .get(EXECUTOR_DOCKER_IMAGE)
+ .getOrElse(throw new SparkException("Must specify the executor Docker image"))
+ private val dockerImagePullPolicy = sparkConf.get(DOCKER_IMAGE_PULL_POLICY)
+ private val blockManagerPort = sparkConf
+ .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT)
+
+ private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX)
+
+ private val executorMemoryMiB = sparkConf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY)
+ private val executorMemoryString = sparkConf.get(
+ org.apache.spark.internal.config.EXECUTOR_MEMORY.key,
+ org.apache.spark.internal.config.EXECUTOR_MEMORY.defaultValueString)
+
+ private val memoryOverheadMiB = sparkConf
+ .get(KUBERNETES_EXECUTOR_MEMORY_OVERHEAD)
+ .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt,
+ MEMORY_OVERHEAD_MIN_MIB))
+ private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB
+
+ private val executorCores = sparkConf.getDouble("spark.executor.cores", 1)
+ private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES)
+
+ override def createExecutorPod(
+ executorId: String,
+ applicationId: String,
+ driverUrl: String,
+ executorEnvs: Seq[(String, String)],
+ driverPod: Pod,
+ nodeToLocalTaskCount: Map[String, Int]): Pod = {
+ val name = s"$executorPodNamePrefix-exec-$executorId"
+
+ // The hostname must be no longer than 63 characters, so take the last 63 characters of the
+ // pod name as the hostname. This preserves uniqueness since the end of the name contains the
+ // executorId.
+ val hostname = name.substring(Math.max(0, name.length - 63))
+ val resolvedExecutorLabels = Map(
+ SPARK_EXECUTOR_ID_LABEL -> executorId,
+ SPARK_APP_ID_LABEL -> applicationId,
+ SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++
+ executorLabels
+ val executorMemoryQuantity = new QuantityBuilder(false)
+ .withAmount(s"${executorMemoryMiB}Mi")
+ .build()
+ val executorMemoryLimitQuantity = new QuantityBuilder(false)
+ .withAmount(s"${executorMemoryWithOverhead}Mi")
+ .build()
+ val executorCpuQuantity = new QuantityBuilder(false)
+ .withAmount(executorCores.toString)
+ .build()
+ val executorExtraClasspathEnv = executorExtraClasspath.map { cp =>
+ new EnvVarBuilder()
+ .withName(ENV_EXECUTOR_EXTRA_CLASSPATH)
+ .withValue(cp)
+ .build()
+ }
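+ // For example (also exercised in ExecutorPodFactorySuite), EXECUTOR_JAVA_OPTIONS set to
+ // "foo=bar" surfaces in the container as the environment variable SPARK_JAVA_OPT_0=foo=bar;
+ // each whitespace-delimited option gets its own indexed variable.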
+ val executorExtraJavaOptionsEnv = sparkConf
+ .get(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS)
+ .map { opts =>
+ val delimitedOpts = Utils.splitCommandString(opts)
+ delimitedOpts.zipWithIndex.map {
+ case (opt, index) =>
+ new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build()
+ }
+ }.getOrElse(Seq.empty[EnvVar])
+ val executorEnv = (Seq(
+ (ENV_DRIVER_URL, driverUrl),
+ // The executor backend expects an integral value for executor cores, so round it up to an int.
+ (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString),
+ (ENV_EXECUTOR_MEMORY, executorMemoryString),
+ (ENV_APPLICATION_ID, applicationId),
+ (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs)
+ .map(env => new EnvVarBuilder()
+ .withName(env._1)
+ .withValue(env._2)
+ .build()
+ ) ++ Seq(
+ new EnvVarBuilder()
+ .withName(ENV_EXECUTOR_POD_IP)
+ .withValueFrom(new EnvVarSourceBuilder()
+ .withNewFieldRef("v1", "status.podIP")
+ .build())
+ .build()
+ ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq
+ val requiredPorts = Seq(
+ (BLOCK_MANAGER_PORT_NAME, blockManagerPort))
+ .map { case (name, port) =>
+ new ContainerPortBuilder()
+ .withName(name)
+ .withContainerPort(port)
+ .build()
+ }
+
+ val executorContainer = new ContainerBuilder()
+ .withName("executor")
+ .withImage(executorDockerImage)
+ .withImagePullPolicy(dockerImagePullPolicy)
+ .withNewResources()
+ .addToRequests("memory", executorMemoryQuantity)
+ .addToLimits("memory", executorMemoryLimitQuantity)
+ .addToRequests("cpu", executorCpuQuantity)
+ .endResources()
+ .addAllToEnv(executorEnv.asJava)
+ .withPorts(requiredPorts.asJava)
+ .build()
+
+ val executorPod = new PodBuilder()
+ .withNewMetadata()
+ .withName(name)
+ .withLabels(resolvedExecutorLabels.asJava)
+ .withAnnotations(executorAnnotations.asJava)
+ .withOwnerReferences()
+ .addNewOwnerReference()
+ .withController(true)
+ .withApiVersion(driverPod.getApiVersion)
+ .withKind(driverPod.getKind)
+ .withName(driverPod.getMetadata.getName)
+ .withUid(driverPod.getMetadata.getUid)
+ .endOwnerReference()
+ .endMetadata()
+ .withNewSpec()
+ .withHostname(hostname)
+ .withRestartPolicy("Never")
+ .withNodeSelector(nodeSelector.asJava)
+ .endSpec()
+ .build()
+
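+ // When KUBERNETES_EXECUTOR_LIMIT_CORES is set (e.g. to a Kubernetes CPU quantity such as
+ // "1" or "500m"), add an explicit CPU limit on top of the CPU request derived from
+ // spark.executor.cores; otherwise only the request is set.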
+ val containerWithExecutorLimitCores = executorLimitCores.map { limitCores =>
+ val executorCpuLimitQuantity = new QuantityBuilder(false)
+ .withAmount(limitCores)
+ .build()
+ new ContainerBuilder(executorContainer)
+ .editResources()
+ .addToLimits("cpu", executorCpuLimitQuantity)
+ .endResources()
+ .build()
+ }.getOrElse(executorContainer)
+
+ new PodBuilder(executorPod)
+ .editSpec()
+ .addToContainers(containerWithExecutorLimitCores)
+ .endSpec()
+ .build()
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
new file mode 100644
index 000000000000..68ca6a762217
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.io.File
+
+import io.fabric8.kubernetes.client.Config
+
+import org.apache.spark.SparkContext
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
+import org.apache.spark.util.ThreadUtils
+
+private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging {
+
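+ // Claims master URLs that start with the "k8s" prefix (typically of the form
+ // k8s://<api server URL>); any other URL falls through to the other registered
+ // cluster managers.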
+ override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s")
+
+ override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = {
+ new TaskSchedulerImpl(sc)
+ }
+
+ override def createSchedulerBackend(
+ sc: SparkContext,
+ masterURL: String,
+ scheduler: TaskScheduler): SchedulerBackend = {
+ val sparkConf = sc.getConf
+
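+ // The driver runs inside the cluster, so it reaches the API server through the internal
+ // master URL and authenticates with the service account token and CA certificate that
+ // Kubernetes mounts into the driver pod.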
+ val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient(
+ KUBERNETES_MASTER_INTERNAL_URL,
+ Some(sparkConf.get(KUBERNETES_NAMESPACE)),
+ APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX,
+ sparkConf,
+ Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)),
+ Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH)))
+
+ val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf)
+ val allocatorExecutor = ThreadUtils
+ .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator")
+ val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool(
+ "kubernetes-executor-requests")
+ new KubernetesClusterSchedulerBackend(
+ scheduler.asInstanceOf[TaskSchedulerImpl],
+ sc.env.rpcEnv,
+ executorPodFactory,
+ kubernetesClient,
+ allocatorExecutor,
+ requestExecutorsService)
+ }
+
+ override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
+ scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
new file mode 100644
index 000000000000..e79c987852db
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.io.Closeable
+import java.net.InetAddress
+import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference}
+import javax.annotation.concurrent.GuardedBy
+
+import io.fabric8.kubernetes.api.model._
+import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher}
+import io.fabric8.kubernetes.client.Watcher.Action
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
+
+import org.apache.spark.SparkException
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv}
+import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils}
+import org.apache.spark.util.Utils
+
+private[spark] class KubernetesClusterSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ rpcEnv: RpcEnv,
+ executorPodFactory: ExecutorPodFactory,
+ kubernetesClient: KubernetesClient,
+ allocatorExecutor: ScheduledExecutorService,
+ requestExecutorsService: ExecutorService)
+ extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {
+
+ import KubernetesClusterSchedulerBackend._
+
+ private val EXECUTOR_ID_COUNTER = new AtomicLong(0L)
+ private val RUNNING_EXECUTOR_PODS_LOCK = new Object
+ @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK")
+ private val runningExecutorsToPods = new mutable.HashMap[String, Pod]
+ private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]()
+ private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]()
+ private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]()
+
+ private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE)
+
+ private val kubernetesDriverPodName = conf
+ .get(KUBERNETES_DRIVER_POD_NAME)
+ .getOrElse(throw new SparkException("Must specify the driver pod name"))
+ private implicit val requestExecutorContext = ExecutionContext.fromExecutorService(
+ requestExecutorsService)
+
+ private val driverPod = kubernetesClient.pods()
+ .inNamespace(kubernetesNamespace)
+ .withName(kubernetesDriverPodName)
+ .get()
+
+ protected override val minRegisteredRatio =
+ if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
+ 0.8
+ } else {
+ super.minRegisteredRatio
+ }
+
+ private val executorWatchResource = new AtomicReference[Closeable]
+ private val totalExpectedExecutors = new AtomicInteger(0)
+
+ private val driverUrl = RpcEndpointAddress(
+ conf.get("spark.driver.host"),
+ conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
+
+ private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf)
+
+ private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY)
+
+ private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE)
+
+ private val executorLostReasonCheckMaxAttempts = conf.get(
+ KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS)
+
+ private val allocatorRunnable = new Runnable {
+
+ // Maintains a map of executor id to count of checks performed to learn the loss reason
+ // for an executor.
+ private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int]
+
+ override def run(): Unit = {
+ handleDisconnectedExecutors()
+
+ val executorsToAllocate = mutable.Map[String, Pod]()
+ val currentTotalRegisteredExecutors = totalRegisteredExecutors.get
+ val currentTotalExpectedExecutors = totalExpectedExecutors.get
+ val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts()
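+ // Each pass requests at most podAllocationSize new executors. For example, with a batch
+ // size of 5, 10 expected executors and 3 already running, this pass creates
+ // min(10 - 3, 5) = 5 pods; any remainder is picked up on later passes once these register.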
+ RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+ if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) {
+ logDebug("Waiting for pending executors before scaling")
+ } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) {
+ logDebug("Maximum allowed executor limit reached. Not scaling up further.")
+ } else {
+ for (_ <- 0 until math.min(
+ currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) {
+ val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString
+ val executorPod = executorPodFactory.createExecutorPod(
+ executorId,
+ applicationId(),
+ driverUrl,
+ conf.getExecutorEnv,
+ driverPod,
+ currentNodeToLocalTaskCount)
+ executorsToAllocate(executorId) = executorPod
+ logInfo(
+ s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}")
+ }
+ }
+ }
+
+ val allocatedExecutors = executorsToAllocate.mapValues { pod =>
+ Utils.tryLog {
+ kubernetesClient.pods().create(pod)
+ }
+ }
+
+ RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+ allocatedExecutors.map {
+ case (executorId, attemptedAllocatedExecutor) =>
+ attemptedAllocatedExecutor.map { successfullyAllocatedExecutor =>
+ runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor)
+ }
+ }
+ }
+ }
+
+ def handleDisconnectedExecutors(): Unit = {
+ // For each disconnected executor, synchronize with the loss reasons that may have been found
+ // by the executor pod watcher. If the loss reason was discovered by the watcher,
+ // inform the parent class with removeExecutor.
+ disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach {
+ case (executorId, executorPod) =>
+ val knownExitReason = Option(podsWithKnownExitReasons.remove(
+ executorPod.getMetadata.getName))
+ knownExitReason.fold {
+ removeExecutorOrIncrementLossReasonCheckCount(executorId)
+ } { executorExited =>
+ logWarning(s"Removing executor $executorId with loss reason " + executorExited.message)
+ removeExecutor(executorId, executorExited)
+ // We don't delete the pod running the executor that has an exit condition caused by
+ // the application from the Kubernetes API server. This allows users to debug later on
+ // through commands such as "kubectl logs <pod name>" and
+ // "kubectl describe pod <pod name>". Note that exited containers have terminated and
+ // therefore won't take up CPU and memory resources.
+ // Otherwise, the executor pod is marked to be deleted from the API server.
+ if (executorExited.exitCausedByApp) {
+ logInfo(s"Executor $executorId exited because of the application.")
+ deleteExecutorFromDataStructures(executorId)
+ } else {
+ logInfo(s"Executor $executorId failed because of a framework error.")
+ deleteExecutorFromClusterAndDataStructures(executorId)
+ }
+ }
+ }
+ }
+
+ def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = {
+ val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0)
+ if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) {
+ removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons."))
+ deleteExecutorFromClusterAndDataStructures(executorId)
+ } else {
+ executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1)
+ }
+ }
+
+ def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = {
+ deleteExecutorFromDataStructures(executorId).foreach { pod =>
+ kubernetesClient.pods().delete(pod)
+ }
+ }
+
+ def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = {
+ disconnectedPodsByExecutorIdPendingRemoval.remove(executorId)
+ executorReasonCheckAttemptCounts -= executorId
+ podsWithKnownExitReasons.remove(executorId)
+ RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+ runningExecutorsToPods.remove(executorId).orElse {
+ logWarning(s"Unable to remove pod for unknown executor $executorId")
+ None
+ }
+ }
+ }
+ }
+
+ override def sufficientResourcesRegistered(): Boolean = {
+ totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio
+ }
+
+ override def start(): Unit = {
+ super.start()
+ executorWatchResource.set(
+ kubernetesClient
+ .pods()
+ .withLabel(SPARK_APP_ID_LABEL, applicationId())
+ .watch(new ExecutorPodsWatcher()))
+
+ allocatorExecutor.scheduleWithFixedDelay(
+ allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS)
+
+ if (!Utils.isDynamicAllocationEnabled(conf)) {
+ doRequestTotalExecutors(initialExecutors)
+ }
+ }
+
+ override def stop(): Unit = {
+ // stop allocation of new resources and caches.
+ allocatorExecutor.shutdown()
+ allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS)
+
+ // send stop message to executors so they shut down cleanly
+ super.stop()
+
+ try {
+ val resource = executorWatchResource.getAndSet(null)
+ if (resource != null) {
+ resource.close()
+ }
+ } catch {
+ case e: Throwable => logWarning("Failed to close the executor pod watcher", e)
+ }
+
+ // then delete the executor pods
+ Utils.tryLogNonFatalError {
+ deleteExecutorPodsOnStop()
+ executorPodsByIPs.clear()
+ }
+ Utils.tryLogNonFatalError {
+ logInfo("Closing kubernetes client")
+ kubernetesClient.close()
+ }
+ }
+
+ /**
+ * @return A map of K8s cluster nodes to the number of tasks that could benefit from data
+ * locality if an executor launches on the cluster node.
+ */
+ private def getNodesWithLocalTaskCounts() : Map[String, Int] = {
+ val nodeToLocalTaskCount = synchronized {
+ mutable.Map[String, Int]() ++ hostToLocalTaskCount
+ }
+
+ for (pod <- executorPodsByIPs.values().asScala) {
+ // Remove cluster nodes that are running our executors already.
+ // TODO: This prefers spreading out executors across nodes. If users want to consolidate
+ // executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut flag that
+ // Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html
+ nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty ||
+ nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty ||
+ nodeToLocalTaskCount.remove(
+ InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty
+ }
+ nodeToLocalTaskCount.toMap[String, Int]
+ }
+
+ override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] {
+ totalExpectedExecutors.set(requestedTotal)
+ true
+ }
+
+ override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] {
+ val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+ executorIds.flatMap { executorId =>
+ runningExecutorsToPods.remove(executorId) match {
+ case Some(pod) =>
+ disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod)
+ Some(pod)
+
+ case None =>
+ logWarning(s"Unable to remove pod for unknown executor $executorId")
+ None
+ }
+ }
+ }
+
+ kubernetesClient.pods().delete(podsToDelete: _*)
+ true
+ }
+
+ private def deleteExecutorPodsOnStop(): Unit = {
+ val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+ val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*)
+ runningExecutorsToPods.clear()
+ runningExecutorPodsCopy
+ }
+ kubernetesClient.pods().delete(executorPodsToDelete: _*)
+ }
+
+ private class ExecutorPodsWatcher extends Watcher[Pod] {
+
+ private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1
+
+ override def eventReceived(action: Action, pod: Pod): Unit = {
+ val podName = pod.getMetadata.getName
+ val podIP = pod.getStatus.getPodIP
+
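+ // A MODIFIED event for a running, non-terminating pod records the executor by its pod IP.
+ // DELETED and ERROR events record an exit reason for the pod and mark the executor as
+ // disconnected so the allocator thread can report the loss and, if needed, replace it.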
+ action match {
+ case Action.MODIFIED if (pod.getStatus.getPhase == "Running"
+ && pod.getMetadata.getDeletionTimestamp == null) =>
+ val clusterNodeName = pod.getSpec.getNodeName
+ logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.")
+ executorPodsByIPs.put(podIP, pod)
+
+ case Action.DELETED | Action.ERROR =>
+ val executorId = getExecutorId(pod)
+ logDebug(s"Executor pod $podName at IP $podIP was at $action.")
+ if (podIP != null) {
+ executorPodsByIPs.remove(podIP)
+ }
+
+ val executorExitReason = if (action == Action.ERROR) {
+ logWarning(s"Received error event of executor pod $podName. Reason: " +
+ pod.getStatus.getReason)
+ executorExitReasonOnError(pod)
+ } else if (action == Action.DELETED) {
+ logWarning(s"Received delete event of executor pod $podName. Reason: " +
+ pod.getStatus.getReason)
+ executorExitReasonOnDelete(pod)
+ } else {
+ throw new IllegalStateException(
+ s"Unknown action that should only be DELETED or ERROR: $action")
+ }
+ podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason)
+
+ if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) {
+ log.warn(s"Executor with id $executorId was not marked as disconnected, but the " +
+ s"watch received an event of type $action for this executor. The executor may " +
+ "have failed to start in the first place and never registered with the driver.")
+ }
+ disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod)
+
+ case _ => logDebug(s"Received event of executor pod $podName: " + action)
+ }
+ }
+
+ override def onClose(cause: KubernetesClientException): Unit = {
+ logDebug("Executor pod watch closed.", cause)
+ }
+
+ private def getExecutorExitStatus(pod: Pod): Int = {
+ val containerStatuses = pod.getStatus.getContainerStatuses
+ if (!containerStatuses.isEmpty) {
+ // We assume the first container represents the pod status. This assumption may not hold
+ // true in the future. Revisit this if side-car containers start running inside executor
+ // pods.
+ getExecutorExitStatus(containerStatuses.get(0))
+ } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS
+ }
+
+ private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = {
+ Option(containerStatus.getState).map { containerState =>
+ Option(containerState.getTerminated).map { containerStateTerminated =>
+ containerStateTerminated.getExitCode.intValue()
+ }.getOrElse(UNKNOWN_EXIT_CODE)
+ }.getOrElse(UNKNOWN_EXIT_CODE)
+ }
+
+ private def isPodAlreadyReleased(pod: Pod): Boolean = {
+ val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL)
+ RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+ !runningExecutorsToPods.contains(executorId)
+ }
+ }
+
+ private def executorExitReasonOnError(pod: Pod): ExecutorExited = {
+ val containerExitStatus = getExecutorExitStatus(pod)
+ // If the pod was already released from our bookkeeping, the container was most likely
+ // actively killed by the driver.
+ if (isPodAlreadyReleased(pod)) {
+ ExecutorExited(containerExitStatus, exitCausedByApp = false,
+ s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " +
+ "request.")
+ } else {
+ val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " +
+ s"exited with exit status code $containerExitStatus."
+ ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason)
+ }
+ }
+
+ private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = {
+ val exitMessage = if (isPodAlreadyReleased(pod)) {
+ s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request."
+ } else {
+ s"Pod ${pod.getMetadata.getName} deleted or lost."
+ }
+ ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage)
+ }
+
+ private def getExecutorId(pod: Pod): String = {
+ val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL)
+ require(executorId != null, "Unexpected pod metadata; expected all executor pods " +
+ s"to have label $SPARK_EXECUTOR_ID_LABEL.")
+ executorId
+ }
+ }
+
+ override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
+ new KubernetesDriverEndpoint(rpcEnv, properties)
+ }
+
+ private class KubernetesDriverEndpoint(
+ rpcEnv: RpcEnv,
+ sparkProperties: Seq[(String, String)])
+ extends DriverEndpoint(rpcEnv, sparkProperties) {
+
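+ // When an executor's RPC endpoint disconnects, remember its pod so the allocator thread
+ // can resolve the loss reason reported by the pod watcher, or give up after
+ // KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS unsuccessful checks.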
+ override def onDisconnected(rpcAddress: RpcAddress): Unit = {
+ addressToExecutorId.get(rpcAddress).foreach { executorId =>
+ if (disableExecutor(executorId)) {
+ RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+ runningExecutorsToPods.get(executorId).foreach { pod =>
+ disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod)
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+private object KubernetesClusterSchedulerBackend {
+ private val UNKNOWN_EXIT_CODE = -1
+}
diff --git a/resource-managers/kubernetes/core/src/test/resources/log4j.properties b/resource-managers/kubernetes/core/src/test/resources/log4j.properties
new file mode 100644
index 000000000000..ad95fadb7c0c
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/resources/log4j.properties
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from a few verbose libraries.
+log4j.logger.com.sun.jersey=WARN
+log4j.logger.org.apache.hadoop=WARN
+log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.mortbay=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala
new file mode 100644
index 000000000000..1c7717c23809
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model.{Pod, _}
+import org.mockito.MockitoAnnotations
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+
+class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach {
+ private val driverPodName: String = "driver-pod"
+ private val driverPodUid: String = "driver-uid"
+ private val executorPrefix: String = "base"
+ private val executorImage: String = "executor-image"
+ private val driverPod = new PodBuilder()
+ .withNewMetadata()
+ .withName(driverPodName)
+ .withUid(driverPodUid)
+ .endMetadata()
+ .withNewSpec()
+ .withNodeName("some-node")
+ .endSpec()
+ .withNewStatus()
+ .withHostIP("192.168.99.100")
+ .endStatus()
+ .build()
+ private var baseConf: SparkConf = _
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ baseConf = new SparkConf()
+ .set(KUBERNETES_DRIVER_POD_NAME, driverPodName)
+ .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix)
+ .set(EXECUTOR_DOCKER_IMAGE, executorImage)
+ }
+
+ test("basic executor pod has reasonable defaults") {
+ val factory = new ExecutorPodFactoryImpl(baseConf)
+ val executor = factory.createExecutorPod(
+ "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
+
+ // The executor pod name and default labels.
+ assert(executor.getMetadata.getName === s"$executorPrefix-exec-1")
+ assert(executor.getMetadata.getLabels.size() === 3)
+ assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1")
+
+ // There is exactly 1 container with no volume mounts and default memory limits.
+ // Default memory limit is 1024M + 384M (minimum overhead constant).
+ assert(executor.getSpec.getContainers.size() === 1)
+ assert(executor.getSpec.getContainers.get(0).getImage === executorImage)
+ assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty)
+ assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1)
+ assert(executor.getSpec.getContainers.get(0).getResources
+ .getLimits.get("memory").getAmount === "1408Mi")
+
+ // The pod has no node selector and no volumes.
+ assert(executor.getSpec.getNodeSelector.isEmpty)
+ assert(executor.getSpec.getVolumes.isEmpty)
+
+ checkEnv(executor, Map())
+ checkOwnerReferences(executor, driverPodUid)
+ }
+
+ test("executor pod hostnames get truncated to 63 characters") {
+ val conf = baseConf.clone()
+ conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX,
+ "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple")
+
+ val factory = new ExecutorPodFactoryImpl(conf)
+ val executor = factory.createExecutorPod(
+ "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
+
+ assert(executor.getSpec.getHostname.length === 63)
+ }
+
+ test("classpath and extra java options get translated into environment variables") {
+ val conf = baseConf.clone()
+ conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar")
+ conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz")
+
+ val factory = new ExecutorPodFactoryImpl(conf)
+ val executor = factory.createExecutorPod(
+ "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]())
+
+ checkEnv(executor,
+ Map("SPARK_JAVA_OPT_0" -> "foo=bar",
+ "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz",
+ "qux" -> "quux"))
+ checkOwnerReferences(executor, driverPodUid)
+ }
+
+ // There is always exactly one controller reference, and it points to the driver pod.
+ private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = {
+ assert(executor.getMetadata.getOwnerReferences.size() === 1)
+ assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid)
+ assert(executor.getMetadata.getOwnerReferences.get(0).getController === true)
+ }
+
+ // Check that the expected environment variables are present.
+ private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = {
+ val defaultEnvs = Map(
+ ENV_EXECUTOR_ID -> "1",
+ ENV_DRIVER_URL -> "dummy",
+ ENV_EXECUTOR_CORES -> "1",
+ ENV_EXECUTOR_MEMORY -> "1g",
+ ENV_APPLICATION_ID -> "dummy",
+ ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars
+
+ assert(executor.getSpec.getContainers.size() === 1)
+ assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size)
+ val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map {
+ x => (x.getName, x.getValue)
+ }.toMap
+ assert(defaultEnvs === mapEnvs)
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
new file mode 100644
index 000000000000..3febb2f47cfd
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit}
+
+import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList}
+import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher}
+import io.fabric8.kubernetes.client.Watcher.Action
+import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource}
+import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations}
+import org.mockito.Matchers.{any, eq => mockitoEq}
+import org.mockito.Mockito.{doNothing, never, times, verify, when}
+import org.scalatest.BeforeAndAfter
+import org.scalatest.mockito.MockitoSugar._
+import scala.collection.JavaConverters._
+import scala.concurrent.Future
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.rpc._
+import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.util.ThreadUtils
+
+class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter {
+
+ private val APP_ID = "test-spark-app"
+ private val DRIVER_POD_NAME = "spark-driver-pod"
+ private val NAMESPACE = "test-namespace"
+ private val SPARK_DRIVER_HOST = "localhost"
+ private val SPARK_DRIVER_PORT = 7077
+ private val POD_ALLOCATION_INTERVAL = 60L
+ private val DRIVER_URL = RpcEndpointAddress(
+ SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
+ private val FIRST_EXECUTOR_POD = new PodBuilder()
+ .withNewMetadata()
+ .withName("pod1")
+ .endMetadata()
+ .withNewSpec()
+ .withNodeName("node1")
+ .endSpec()
+ .withNewStatus()
+ .withHostIP("192.168.99.100")
+ .endStatus()
+ .build()
+ private val SECOND_EXECUTOR_POD = new PodBuilder()
+ .withNewMetadata()
+ .withName("pod2")
+ .endMetadata()
+ .withNewSpec()
+ .withNodeName("node2")
+ .endSpec()
+ .withNewStatus()
+ .withHostIP("192.168.99.101")
+ .endStatus()
+ .build()
+
+ private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
+ private type LABELED_PODS = FilterWatchListDeletable[
+ Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]]
+ private type IN_NAMESPACE_PODS = NonNamespaceOperation[
+ Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
+
+ @Mock
+ private var sparkContext: SparkContext = _
+
+ @Mock
+ private var listenerBus: LiveListenerBus = _
+
+ @Mock
+ private var taskSchedulerImpl: TaskSchedulerImpl = _
+
+ @Mock
+ private var allocatorExecutor: ScheduledExecutorService = _
+
+ @Mock
+ private var requestExecutorsService: ExecutorService = _
+
+ @Mock
+ private var executorPodFactory: ExecutorPodFactory = _
+
+ @Mock
+ private var kubernetesClient: KubernetesClient = _
+
+ @Mock
+ private var podOperations: PODS = _
+
+ @Mock
+ private var podsWithLabelOperations: LABELED_PODS = _
+
+ @Mock
+ private var podsInNamespace: IN_NAMESPACE_PODS = _
+
+ @Mock
+ private var podsWithDriverName: PodResource[Pod, DoneablePod] = _
+
+ @Mock
+ private var rpcEnv: RpcEnv = _
+
+ @Mock
+ private var driverEndpointRef: RpcEndpointRef = _
+
+ @Mock
+ private var executorPodsWatch: Watch = _
+
+ @Mock
+ private var successFuture: Future[Boolean] = _
+
+ private var sparkConf: SparkConf = _
+ private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _
+ private var allocatorRunnable: ArgumentCaptor[Runnable] = _
+ private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _
+ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _
+
+ private val driverPod = new PodBuilder()
+ .withNewMetadata()
+ .withName(DRIVER_POD_NAME)
+ .addToLabels(SPARK_APP_ID_LABEL, APP_ID)
+ .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE)
+ .endMetadata()
+ .build()
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ sparkConf = new SparkConf()
+ .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME)
+ .set(KUBERNETES_NAMESPACE, NAMESPACE)
+ .set("spark.driver.host", SPARK_DRIVER_HOST)
+ .set("spark.driver.port", SPARK_DRIVER_PORT.toString)
+ .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL)
+ executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]])
+ allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable])
+ requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable])
+ driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint])
+ when(sparkContext.conf).thenReturn(sparkConf)
+ when(sparkContext.listenerBus).thenReturn(listenerBus)
+ when(taskSchedulerImpl.sc).thenReturn(sparkContext)
+ when(kubernetesClient.pods()).thenReturn(podOperations)
+ when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations)
+ when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture()))
+ .thenReturn(executorPodsWatch)
+ when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace)
+ when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName)
+ when(podsWithDriverName.get()).thenReturn(driverPod)
+ when(allocatorExecutor.scheduleWithFixedDelay(
+ allocatorRunnable.capture(),
+ mockitoEq(0L),
+ mockitoEq(POD_ALLOCATION_INTERVAL),
+ mockitoEq(TimeUnit.SECONDS))).thenReturn(null)
+ // Creating Futures in Scala backed by a Java executor service resolves to running
+ // ExecutorService#execute (as opposed to submit)
+ doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture())
+ when(rpcEnv.setupEndpoint(
+ mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture()))
+ .thenReturn(driverEndpointRef)
+
+ // Used by the CoarseGrainedSchedulerBackend when making RPC calls.
+ when(driverEndpointRef.ask[Boolean]
+ (any(classOf[Any]))
+ (any())).thenReturn(successFuture)
+ when(successFuture.failed).thenReturn(Future[Throwable] {
+ // emulate behavior of the Future.failed method.
+ throw new NoSuchElementException()
+ }(ThreadUtils.sameThread))
+ }
+
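+ // The mocked executor services never run anything on their own; each test drives the
+ // captured allocator and request runnables by calling run() explicitly, and feeds watch
+ // events through the captured Watcher, so the asynchronous machinery is exercised
+ // deterministically.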
+ test("Basic lifecycle expectations when starting and stopping the scheduler.") {
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ assert(executorPodsWatcherArgument.getValue != null)
+ assert(allocatorRunnable.getValue != null)
+ scheduler.stop()
+ verify(executorPodsWatch).close()
+ }
+
+ test("Static allocation should request executors upon first allocator run.") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ requestExecutorRunnable.getValue.run()
+ val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+ when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
+ allocatorRunnable.getValue.run()
+ verify(podOperations).create(firstResolvedPod)
+ verify(podOperations).create(secondResolvedPod)
+ }
+
+ test("Killing executors deletes the executor pods") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ requestExecutorRunnable.getValue.run()
+ val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+ when(podOperations.create(any(classOf[Pod])))
+ .thenAnswer(AdditionalAnswers.returnsFirstArg())
+ allocatorRunnable.getValue.run()
+ scheduler.doKillExecutors(Seq("2"))
+ requestExecutorRunnable.getAllValues.asScala.last.run()
+ verify(podOperations).delete(secondResolvedPod)
+ verify(podOperations, never()).delete(firstResolvedPod)
+ }
+
+ test("Executors should be requested in batches.") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ requestExecutorRunnable.getValue.run()
+ when(podOperations.create(any(classOf[Pod])))
+ .thenAnswer(AdditionalAnswers.returnsFirstArg())
+ val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+ allocatorRunnable.getValue.run()
+ verify(podOperations).create(firstResolvedPod)
+ verify(podOperations, never()).create(secondResolvedPod)
+ val registerFirstExecutorMessage = RegisterExecutor(
+ "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String])
+ when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+ driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+ .apply(registerFirstExecutorMessage)
+ allocatorRunnable.getValue.run()
+ verify(podOperations).create(secondResolvedPod)
+ }
+
+ test("Scaled down executors should be cleaned up") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+
+ // The scheduler backend spins up one executor pod.
+ requestExecutorRunnable.getValue.run()
+ when(podOperations.create(any(classOf[Pod])))
+ .thenAnswer(AdditionalAnswers.returnsFirstArg())
+ val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ allocatorRunnable.getValue.run()
+ val executorEndpointRef = mock[RpcEndpointRef]
+ when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
+ val registerFirstExecutorMessage = RegisterExecutor(
+ "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
+ when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+ driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+ .apply(registerFirstExecutorMessage)
+
+ // Request a total of 0 executors and trigger pod deletion from the driver.
+ scheduler.doRequestTotalExecutors(0)
+ requestExecutorRunnable.getAllValues.asScala.last.run()
+ scheduler.doKillExecutors(Seq("1"))
+ requestExecutorRunnable.getAllValues.asScala.last.run()
+ verify(podOperations, times(1)).delete(resolvedPod)
+ driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
+
+ val exitedPod = exitPod(resolvedPod, 0)
+ executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod)
+ allocatorRunnable.getValue.run()
+
+ // No more deletion attempts of the executors.
+ // This is graceful termination and should not be detected as a failure.
+ verify(podOperations, times(1)).delete(resolvedPod)
+ verify(driverEndpointRef, times(1)).ask[Boolean](
+ RemoveExecutor("1", ExecutorExited(
+ 0,
+ exitCausedByApp = false,
+ s"Container in pod ${exitedPod.getMetadata.getName} exited from" +
+ s" explicit termination request.")))
+ }
+
+ test("Executors that fail should not be deleted.") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
+ requestExecutorRunnable.getValue.run()
+ allocatorRunnable.getValue.run()
+ val executorEndpointRef = mock[RpcEndpointRef]
+ when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
+ val registerFirstExecutorMessage = RegisterExecutor(
+ "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
+ when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+ driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+ .apply(registerFirstExecutorMessage)
+ driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
+ executorPodsWatcherArgument.getValue.eventReceived(
+ Action.ERROR, exitPod(firstResolvedPod, 1))
+
+ // A replacement executor should be created but the error pod should persist.
+ val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+ scheduler.doRequestTotalExecutors(1)
+ requestExecutorRunnable.getValue.run()
+ allocatorRunnable.getAllValues.asScala.last.run()
+ verify(podOperations, never()).delete(firstResolvedPod)
+ verify(driverEndpointRef).ask[Boolean](
+ RemoveExecutor("1", ExecutorExited(
+ 1,
+ exitCausedByApp = true,
+ s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" +
+ " exit status code 1.")))
+ }
+
+ test("Executors disconnected due to unknown reasons are deleted and replaced.") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+ val executorLostReasonCheckMaxAttempts = sparkConf.get(
+ KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS)
+
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
+ requestExecutorRunnable.getValue.run()
+ allocatorRunnable.getValue.run()
+ val executorEndpointRef = mock[RpcEndpointRef]
+ when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
+ val registerFirstExecutorMessage = RegisterExecutor(
+ "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
+ when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+ driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+ .apply(registerFirstExecutorMessage)
+
+ driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
+ 1 to executorLostReasonCheckMaxAttempts foreach { _ =>
+ allocatorRunnable.getValue.run()
+ verify(podOperations, never()).delete(FIRST_EXECUTOR_POD)
+ }
+
+ val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+ allocatorRunnable.getValue.run()
+ verify(podOperations).delete(firstResolvedPod)
+ verify(driverEndpointRef).ask[Boolean](
+ RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons.")))
+ }
+
+ test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ when(podOperations.create(firstResolvedPod))
+ .thenThrow(new RuntimeException("test"))
+ requestExecutorRunnable.getValue.run()
+ allocatorRunnable.getValue.run()
+ verify(podOperations, times(1)).create(firstResolvedPod)
+ val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD)
+ allocatorRunnable.getValue.run()
+ verify(podOperations).create(recreatedResolvedPod)
+ }
+
+ test("Executors that are initially created but the watch notices them fail are rebuilt" +
+ " in the next batch.") {
+ sparkConf
+ .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+ .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+ val scheduler = newSchedulerBackend()
+ scheduler.start()
+ val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+ when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg())
+ requestExecutorRunnable.getValue.run()
+ allocatorRunnable.getValue.run()
+ verify(podOperations, times(1)).create(firstResolvedPod)
+ executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod)
+ val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD)
+ allocatorRunnable.getValue.run()
+ verify(podOperations).create(recreatedResolvedPod)
+ }
+
+ private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = {
+ new KubernetesClusterSchedulerBackend(
+ taskSchedulerImpl,
+ rpcEnv,
+ executorPodFactory,
+ kubernetesClient,
+ allocatorExecutor,
+ requestExecutorsService) {
+
+ override def applicationId(): String = APP_ID
+ }
+ }
+
+ private def exitPod(basePod: Pod, exitCode: Int): Pod = {
+ new PodBuilder(basePod)
+ .editStatus()
+ .addNewContainerStatus()
+ .withNewState()
+ .withNewTerminated()
+ .withExitCode(exitCode)
+ .endTerminated()
+ .endState()
+ .endContainerStatus()
+ .endStatus()
+ .build()
+ }
+
+ private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = {
+ val resolvedPod = new PodBuilder(expectedPod)
+ .editMetadata()
+ .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString)
+ .endMetadata()
+ .build()
+ when(executorPodFactory.createExecutorPod(
+ executorId.toString,
+ APP_ID,
+ DRIVER_URL,
+ sparkConf.getExecutorEnv,
+ driverPod,
+ Map.empty)).thenReturn(resolvedPod)
+ resolvedPod
+ }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 7052fb347106..506adb363aa9 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -41,6 +41,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId
+import org.apache.spark.scheduler.cluster.SchedulerBackendUtils
import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
/**
@@ -109,7 +110,7 @@ private[yarn] class YarnAllocator(
sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L)
@volatile private var targetNumExecutors =
- YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf)
+ SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf)
private var currentNodeBlacklist = Set.empty[String]
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 3d9f99f57bed..9c1472cb50e3 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -133,8 +133,6 @@ object YarnSparkHadoopUtil {
val ANY_HOST = "*"
- val DEFAULT_NUMBER_EXECUTORS = 2
-
// All RM requests are issued with same priority : we do not (yet) have any distinction between
// request types (like map/reduce in hadoop for example)
val RM_REQUEST_PRIORITY = Priority.newInstance(1)
@@ -279,27 +277,5 @@ object YarnSparkHadoopUtil {
securityMgr.getModifyAclsGroups)
)
}
-
- /**
- * Getting the initial target number of executors depends on whether dynamic allocation is
- * enabled.
- * If not using dynamic allocation it gets the number of executors requested by the user.
- */
- def getInitialTargetExecutorNumber(
- conf: SparkConf,
- numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = {
- if (Utils.isDynamicAllocationEnabled(conf)) {
- val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS)
- val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf)
- val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS)
- require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors,
- s"initial executor number $initialNumExecutors must between min executor number " +
- s"$minNumExecutors and max executor number $maxNumExecutors")
-
- initialNumExecutors
- } else {
- conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors)
- }
- }
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index d482376d14dd..b722cc401bb7 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -52,7 +52,7 @@ private[spark] class YarnClientSchedulerBackend(
logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" "))
val args = new ClientArguments(argsArrayBuf.toArray)
- totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf)
+ totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf)
client = new Client(args, conf)
bindToYarn(client.submitApplication(), None)
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
index 4f3d5ebf403e..e2d477be329c 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -34,7 +34,7 @@ private[spark] class YarnClusterSchedulerBackend(
val attemptId = ApplicationMaster.getAttemptId
bindToYarn(attemptId.getApplicationId(), Some(attemptId))
super.start()
- totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf)
+ totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sc.conf)
}
override def getDriverLogUrls: Option[Map[String, String]] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index b87bbb487467..95b6fbb0cd61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.types.StructType
@@ -366,10 +367,17 @@ case class CatalogStatistics(
* Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based
* on column names.
*/
- def toPlanStats(planOutput: Seq[Attribute]): Statistics = {
- val matched = planOutput.flatMap(a => colStats.get(a.name).map(a -> _))
- Statistics(sizeInBytes = sizeInBytes, rowCount = rowCount,
- attributeStats = AttributeMap(matched))
+ def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = {
+ if (cboEnabled && rowCount.isDefined) {
+ val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _)))
+ // Estimate size as number of rows * row size.
+ val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats)
+ Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats)
+ } else {
+ // When CBO is disabled or the table doesn't have other statistics, we apply the size-only
+ // estimation strategy and only propagate sizeInBytes in statistics.
+ Statistics(sizeInBytes = sizeInBytes)
+ }
}
/** Readable string representation for the CatalogStatistics. */
@@ -452,7 +460,7 @@ case class HiveTableRelation(
)
override def computeStats(): Statistics = {
- tableMeta.stats.map(_.toPlanStats(output)).getOrElse {
+ tableMeta.stats.map(_.toPlanStats(output, conf.cboEnabled)).getOrElse {
throw new IllegalStateException("table stats must be specified.")
}
}
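The new `toPlanStats(planOutput, cboEnabled)` only derives a size from the row count when CBO is on and a row count exists; otherwise it propagates sizeInBytes alone. A rough, self-contained sketch of the size-from-row-count idea (the real computation is EstimationUtils.getOutputSize, which also consults per-column statistics):

    import org.apache.spark.sql.types.StructType

    // Approximation only: per-row size is modelled as 8 bytes of overhead plus the default
    // size of each column's data type; the result is clamped to at least 1 byte, matching
    // the test expectation later in this patch that an empty table reports sizeInBytes = 1.
    def approxOutputSize(schema: StructType, rowCount: BigInt): BigInt = {
      val rowSize = 8 + schema.map(_.dataType.defaultSize).sum
      (rowCount * rowSize).max(BigInt(1))
    }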
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index f8644c2cd672..8d06804ce1e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -87,8 +87,7 @@ class EquivalentExpressions {
def childrenToRecurse: Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case i: If => i.predicate :: Nil
- // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
- case c: CaseWhenCodegen => c.children.head :: Nil
+ case c: CaseWhen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
case other => other.children
}
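With CaseWhenCodegen gone, common-subexpression elimination treats every `CaseWhen` the same way: only the first branch condition (`children.head`) is recursed into, because it is the only sub-expression guaranteed to be evaluated. An illustrative snippet, assuming the catalyst DSL test helpers:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}

    val a = 'a.int
    val cw = CaseWhen(Seq((a > 0, a + 1), (a > 10, a + 2)), Some(Literal(0)))
    // childrenToRecurse yields only `a > 0` (cw.children.head); the branch values and the
    // second condition may never be evaluated, so they are not safe CSE candidates.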
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index e5a1096bba71..d98f7b3d8efe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression {
}
"""
}
- val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval))
+ val codes = ctx.splitExpressions(evalChildren.map(updateEval))
ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
@@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
"""
}
- val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval))
+ val codes = ctx.splitExpressions(evalChildren.map(updateEval))
ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 0498e61819f4..668c816b3fd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -781,15 +781,18 @@ class CodegenContext {
* beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
* instead, because classes have a constant pool limit of 65,536 named values.
*
- * @param row the variable name of row that is used by expressions
+ * Note that we will extract the current inputs of this context and pass them to the generated
+ * functions. The input is `INPUT_ROW` for the normal codegen path, and `currentVars` for the
+ * whole stage codegen path. The whole stage codegen path is not supported yet.
+ *
* @param expressions the codes to evaluate expressions.
*/
- def splitExpressions(row: String, expressions: Seq[String]): String = {
- if (row == null || currentVars != null) {
- // Cannot split these expressions because they are not created from a row object.
+ def splitExpressions(expressions: Seq[String]): String = {
+ // TODO: support whole stage codegen
+ if (INPUT_ROW == null || currentVars != null) {
return expressions.mkString("\n")
}
- splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", row) :: Nil)
+ splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil)
}
/**
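The parameterless `splitExpressions(expressions)` now reads the row name from `ctx.INPUT_ROW` itself and returns the code unsplit under whole stage codegen, so callers in the surrounding hunks simply drop the row argument. Conceptual sketch of the splitting idea, not the real CodegenContext API:

    // Split a long list of generated statements into helper functions that each take the
    // input row, then call them in order; this keeps every generated method under the JVM's
    // 64KB bytecode limit. Hypothetical standalone helper for illustration.
    def splitIntoFunctions(
        statements: Seq[String],
        funcName: String,
        rowName: String,
        maxPerFunc: Int = 16): (Seq[String], String) = {
      val groups = statements.grouped(maxPerFunc).toSeq
      val defs = groups.zipWithIndex.map { case (body, i) =>
        s"""private void ${funcName}_$i(InternalRow $rowName) {
           |  ${body.mkString("\n  ")}
           |}""".stripMargin
      }
      val calls = groups.indices.map(i => s"${funcName}_$i($rowName);").mkString("\n")
      (defs, calls)   // the real generator registers the definitions on the context
    }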
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 802e8bdb1ca3..5fdbda51b4ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -91,8 +91,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
- val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
- val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
+ val allProjections = ctx.splitExpressions(projectionCodes)
+ val allUpdates = ctx.splitExpressions(updates)
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 1e4ac3f2afd5..5d35cce1a91c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -45,7 +45,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
ctx: CodegenContext,
input: String,
schema: StructType): ExprCode = {
- val tmp = ctx.freshName("tmp")
+ // Puts `input` in a local variable to avoid re-evaluating it if it is a statement.
+ val tmpInput = ctx.freshName("tmpInput")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
// These expressions could be split into multiple functions
@@ -54,17 +55,21 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
- val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
+ val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt)
s"""
- if (!$tmp.isNullAt($i)) {
+ if (!$tmpInput.isNullAt($i)) {
${converter.code}
$values[$i] = ${converter.value};
}
"""
}
- val allFields = ctx.splitExpressions(tmp, fieldWriters)
+ val allFields = ctx.splitExpressions(
+ expressions = fieldWriters,
+ funcName = "writeFields",
+ arguments = Seq("InternalRow" -> tmpInput)
+ )
val code = s"""
- final InternalRow $tmp = $input;
+ final InternalRow $tmpInput = $input;
$values = new Object[${schema.length}];
$allFields
final InternalRow $output = new $rowClass($values);
@@ -78,20 +83,22 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
ctx: CodegenContext,
input: String,
elementType: DataType): ExprCode = {
- val tmp = ctx.freshName("tmp")
+ // Puts `input` in a local variable to avoid re-evaluating it if it is a statement.
+ val tmpInput = ctx.freshName("tmpInput")
val output = ctx.freshName("safeArray")
val values = ctx.freshName("values")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
val arrayClass = classOf[GenericArrayData].getName
- val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType)
+ val elementConverter = convertToSafe(
+ ctx, ctx.getValue(tmpInput, elementType, index), elementType)
val code = s"""
- final ArrayData $tmp = $input;
- final int $numElements = $tmp.numElements();
+ final ArrayData $tmpInput = $input;
+ final int $numElements = $tmpInput.numElements();
final Object[] $values = new Object[$numElements];
for (int $index = 0; $index < $numElements; $index++) {
- if (!$tmp.isNullAt($index)) {
+ if (!$tmpInput.isNullAt($index)) {
${elementConverter.code}
$values[$index] = ${elementConverter.value};
}
@@ -107,14 +114,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
input: String,
keyType: DataType,
valueType: DataType): ExprCode = {
- val tmp = ctx.freshName("tmp")
+ val tmpInput = ctx.freshName("tmpInput")
val output = ctx.freshName("safeMap")
val mapClass = classOf[ArrayBasedMapData].getName
- val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
- val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType)
+ val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType)
+ val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType)
val code = s"""
- final MapData $tmp = $input;
+ final MapData $tmpInput = $input;
${keyConverter.code}
${valueConverter.code}
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
@@ -152,7 +159,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
- val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
+ val allExpressions = ctx.splitExpressions(expressionCodes)
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
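All three converters in GenerateSafeProjection now copy `input` into a fresh local (`tmpInput`) before using it, so an input that is itself an expression (for example `row.getStruct(0, 3)`) is evaluated exactly once. Shape of the emitted Java, sketched as a Scala string builder (the `tmpInput` name and the generic accessor are placeholders, not the real generated code):

    def writeStructSafely(input: String, numFields: Int): String =
      s"""
         |final InternalRow tmpInput = $input;   // evaluate the input expression only once
         |final Object[] values = new Object[$numFields];
         |for (int i = 0; i < $numFields; i++) {
         |  if (!tmpInput.isNullAt(i)) {
         |    values[i] = tmpInput.get(i, null); // placeholder for the per-type accessor
         |  }
         |}
       """.stripMargin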
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 4bd50aee0551..b022457865d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -36,7 +36,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case NullType => true
case t: AtomicType => true
case _: CalendarIntervalType => true
- case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
+ case t: StructType => t.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
@@ -49,25 +49,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
input: String,
fieldTypes: Seq[DataType],
bufferHolder: String): String = {
+ // Puts `input` in a local variable to avoid re-evaluating it if it is a statement.
+ val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
- val javaType = ctx.javaType(dt)
- val isNullVar = ctx.freshName("isNull")
- val valueVar = ctx.freshName("value")
- val defaultValue = ctx.defaultValue(dt)
- val readValue = ctx.getValue(input, dt, i.toString)
- val code =
- s"""
- boolean $isNullVar = $input.isNullAt($i);
- $javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
- """
- ExprCode(code, isNullVar, valueVar)
+ ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString))
}
s"""
- if ($input instanceof UnsafeRow) {
- ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
+ final InternalRow $tmpInput = $input;
+ if ($tmpInput instanceof UnsafeRow) {
+ ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)}
} else {
- ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
+ ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)}
}
"""
}
@@ -167,9 +160,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
}
+ val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) {
+ // TODO: support whole stage codegen
+ writeFields.mkString("\n")
+ } else {
+ assert(row != null, "the input row name cannot be null when generating code to write it.")
+ ctx.splitExpressions(
+ expressions = writeFields,
+ funcName = "writeFields",
+ arguments = Seq("InternalRow" -> row))
+ }
+
s"""
$resetWriter
- ${ctx.splitExpressions(row, writeFields)}
+ $writeFieldsCode
""".trim
}
@@ -179,13 +183,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
input: String,
elementType: DataType,
bufferHolder: String): String = {
+ // Puts `input` in a local variable to avoid re-evaluating it if it is a statement.
+ val tmpInput = ctx.freshName("tmpInput")
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
val arrayWriter = ctx.freshName("arrayWriter")
ctx.addMutableState(arrayWriterClass, arrayWriter,
s"$arrayWriter = new $arrayWriterClass();")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
- val element = ctx.freshName("element")
val et = elementType match {
case udt: UserDefinedType[_] => udt.sqlType
@@ -201,6 +206,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
val tmpCursor = ctx.freshName("tmpCursor")
+ val element = ctx.getValue(tmpInput, et, index)
val writeElement = et match {
case t: StructType =>
s"""
@@ -233,17 +239,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
s"""
- if ($input instanceof UnsafeArrayData) {
- ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
+ final ArrayData $tmpInput = $input;
+ if ($tmpInput instanceof UnsafeArrayData) {
+ ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)}
} else {
- final int $numElements = $input.numElements();
+ final int $numElements = $tmpInput.numElements();
$arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
for (int $index = 0; $index < $numElements; $index++) {
- if ($input.isNullAt($index)) {
+ if ($tmpInput.isNullAt($index)) {
$arrayWriter.setNull$primitiveTypeName($index);
} else {
- final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
}
}
@@ -258,19 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
keyType: DataType,
valueType: DataType,
bufferHolder: String): String = {
- val keys = ctx.freshName("keys")
- val values = ctx.freshName("values")
+ // Puts `input` in a local variable to avoid re-evaluating it if it is a statement.
+ val tmpInput = ctx.freshName("tmpInput")
val tmpCursor = ctx.freshName("tmpCursor")
-
// Writes out unsafe map according to the format described in `UnsafeMapData`.
s"""
- if ($input instanceof UnsafeMapData) {
- ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
+ final MapData $tmpInput = $input;
+ if ($tmpInput instanceof UnsafeMapData) {
+ ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)}
} else {
- final ArrayData $keys = $input.keyArray();
- final ArrayData $values = $input.valueArray();
-
// preserve 8 bytes to write the key array numBytes later.
$bufferHolder.grow(8);
$bufferHolder.cursor += 8;
@@ -278,11 +281,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Remember the current cursor so that we can write numBytes of key array later.
final int $tmpCursor = $bufferHolder.cursor;
- ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
+ ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)}
// Write the numBytes of key array into the first 8 bytes.
Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
- ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
+ ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)}
}
"""
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 2a00d57ee130..57a7f2e20773 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
- code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess,
+ code = preprocess + ctx.splitExpressions(assigns) + postprocess,
value = arrayData,
isNull = "false")
}
@@ -216,10 +216,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
s"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
- ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)}
+ ${ctx.splitExpressions(assignKeys)}
$postprocessKeyData
$preprocessValueData
- ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)}
+ ${ctx.splitExpressions(assignValues)}
$postprocessValueData
final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
"""
@@ -351,24 +351,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
ctx.addMutableState("Object[]", values, s"$values = null;")
-
- ev.copy(code = s"""
- $values = new Object[${valExprs.size}];""" +
- ctx.splitExpressions(
- ctx.INPUT_ROW,
- valExprs.zipWithIndex.map { case (e, i) =>
- val eval = e.genCode(ctx)
- eval.code + s"""
+ val valuesCode = ctx.splitExpressions(
+ valExprs.zipWithIndex.map { case (e, i) =>
+ val eval = e.genCode(ctx)
+ s"""
+ ${eval.code}
if (${eval.isNull}) {
$values[$i] = null;
} else {
$values[$i] = ${eval.value};
}"""
- }) +
+ })
+
+ ev.copy(code =
s"""
- final InternalRow ${ev.value} = new $rowClass($values);
- $values = null;
- """, isNull = "false")
+ |$values = new Object[${valExprs.size}];
+ |$valuesCode
+ |final InternalRow ${ev.value} = new $rowClass($values);
+ |$values = null;
+ """.stripMargin, isNull = "false")
}
override def prettyName: String = "named_struct"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 6195be3a258c..43e643178c89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -88,14 +88,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
/**
- * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * When a = true, returns b; when c = true, returns d; else returns e.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
-abstract class CaseWhenBase(
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
+ arguments = """
+ Arguments:
+ * expr1, expr3 - the branch condition expressions should all be boolean type.
+ * expr2, expr4, expr5 - the branch value expressions and else value expression should all be
+ same type or coercible to a common type.
+ """,
+ examples = """
+ Examples:
+ > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
+ 1
+ > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
+ 2
+ > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END;
+ NULL
+ """)
+// scalastyle:on line.size.limit
+case class CaseWhen(
branches: Seq[(Expression, Expression)],
- elseValue: Option[Expression])
+ elseValue: Option[Expression] = None)
extends Expression with Serializable {
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
@@ -158,111 +178,99 @@ abstract class CaseWhenBase(
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
"CASE" + cases + elseCase + " END"
}
-}
-
-
-/**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * When a = true, returns b; when c = true, returns d; else returns e.
- *
- * @param branches seq of (branch condition, branch value)
- * @param elseValue optional value for the else branch
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
- arguments = """
- Arguments:
- * expr1, expr3 - the branch condition expressions should all be boolean type.
- * expr2, expr4, expr5 - the branch value expressions and else value expression should all be
- same type or coercible to a common type.
- """,
- examples = """
- Examples:
- > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
- 1
- > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
- 2
- > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
- NULL
- """)
-// scalastyle:on line.size.limit
-case class CaseWhen(
- val branches: Seq[(Expression, Expression)],
- val elseValue: Option[Expression] = None)
- extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- super[CodegenFallback].doGenCode(ctx, ev)
- }
-
- def toCodegen(): CaseWhenCodegen = {
- CaseWhenCodegen(branches, elseValue)
- }
-}
-
-/**
- * CaseWhen expression used when code generation condition is satisfied.
- * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
- *
- * @param branches seq of (branch condition, branch value)
- * @param elseValue optional value for the else branch
- */
-case class CaseWhenCodegen(
- val branches: Seq[(Expression, Expression)],
- val elseValue: Option[Expression] = None)
- extends CaseWhenBase(branches, elseValue) with Serializable {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- // Generate code that looks like:
- //
- // condA = ...
- // if (condA) {
- // valueA
- // } else {
- // condB = ...
- // if (condB) {
- // valueB
- // } else {
- // condC = ...
- // if (condC) {
- // valueC
- // } else {
- // elseValue
- // }
- // }
- // }
+ // This variable tracks whether a branch condition has been met so far.
+ // It is initialized to `false` and set to `true` as soon as a condition evaluates
+ // to `true`, so that the evaluation of the remaining conditions can be skipped
+ // once a match has been found.
+ val conditionMet = ctx.freshName("caseWhenConditionMet")
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ ctx.addMutableState(ctx.javaType(dataType), ev.value)
+
+ // these blocks are meant to be inside a
+ // do {
+ // ...
+ // } while (false);
+ // loop
val cases = branches.map { case (condExpr, valueExpr) =>
val cond = condExpr.genCode(ctx)
val res = valueExpr.genCode(ctx)
s"""
- ${cond.code}
- if (!${cond.isNull} && ${cond.value}) {
- ${res.code}
- ${ev.isNull} = ${res.isNull};
- ${ev.value} = ${res.value};
- }
- """
+ |${cond.code}
+ |if (!${cond.isNull} && ${cond.value}) {
+ | ${res.code}
+ | ${ev.isNull} = ${res.isNull};
+ | ${ev.value} = ${res.value};
+ | $conditionMet = true;
+ | continue;
+ |}
+ """.stripMargin
}
- var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
-
- elseValue.foreach { elseExpr =>
+ val elseCode = elseValue.map { elseExpr =>
val res = elseExpr.genCode(ctx)
- generatedCode +=
- s"""
- ${res.code}
- ${ev.isNull} = ${res.isNull};
- ${ev.value} = ${res.value};
- """
+ s"""
+ |${res.code}
+ |${ev.isNull} = ${res.isNull};
+ |${ev.value} = ${res.value};
+ """.stripMargin
}
- generatedCode += "}\n" * cases.size
+ val allConditions = cases ++ elseCode
+
+ val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+ allConditions.mkString("\n")
+ } else {
+ // This generates code like:
+ // conditionMet = caseWhen_1(i);
+ // if(conditionMet) {
+ // continue;
+ // }
+ // conditionMet = caseWhen_2(i);
+ // if(conditionMet) {
+ // continue;
+ // }
+ // ...
+ // and the declared methods are:
+ // private boolean caseWhen_1234() {
+ // boolean conditionMet = false;
+ // do {
+ // // here the evaluation of the conditions
+ // } while (false);
+ // return conditionMet;
+ // }
+ ctx.splitExpressions(allConditions, "caseWhen",
+ ("InternalRow", ctx.INPUT_ROW) :: Nil,
+ returnType = ctx.JAVA_BOOLEAN,
+ makeSplitFunction = {
+ func =>
+ s"""
+ ${ctx.JAVA_BOOLEAN} $conditionMet = false;
+ do {
+ $func
+ } while (false);
+ return $conditionMet;
+ """
+ },
+ foldFunctions = { funcCalls =>
+ funcCalls.map { funcCall =>
+ s"""
+ $conditionMet = $funcCall;
+ if ($conditionMet) {
+ continue;
+ }"""
+ }.mkString
+ })
+ }
ev.copy(code = s"""
- boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- $generatedCode""")
+ ${ev.isNull} = true;
+ ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${ctx.JAVA_BOOLEAN} $conditionMet = false;
+ do {
+ $code
+ } while (false);""")
}
}
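The rewritten CaseWhen codegen evaluates branches inside a `do { ... } while (false)` block and uses `continue` to stop as soon as a condition holds, which also works when the branches are split across helper functions. The runtime semantics, restated as plain Scala for clarity (illustration only):

    // Conditions are evaluated in order; the first one that is true short-circuits the rest,
    // otherwise the else value (if any) is used. A None condition models a NULL predicate.
    def evalCaseWhen[T](
        branches: Seq[(() => Option[Boolean], () => T)],
        elseValue: Option[() => T]): Option[T] = {
      branches.collectFirst {
        case (cond, value) if cond().contains(true) => value()
      }.orElse(elseValue.map(_()))
    }

    // evalCaseWhen(Seq((() => Some(false), () => 1), (() => Some(true), () => 2)),
    //              Some(() => 0)) == Some(2)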
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 8618f4908607..f1aa13066926 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];")
val values = children.tail
val dataTypes = values.take(numFields).map(_.dataType)
- val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
+ val code = ctx.splitExpressions(Seq.tabulate(numRows) { row =>
val fields = Seq.tabulate(numFields) { col =>
val index = row * numFields + col
if (index < values.length) values(index) else Literal(null, dataTypes(col))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 9e0786e36791..c3289b829993 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = "false"
- val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child =>
+ val childrenHash = ctx.splitExpressions(children.map { child =>
val childGen = child.genCode(ctx)
childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, ev.value, ctx)
@@ -330,9 +330,9 @@ abstract class HashExpression[E] extends Expression {
} else {
val bytes = ctx.freshName("bytes")
s"""
- final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
- ${genHashBytes(bytes, result)}
- """
+ |final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+ |${genHashBytes(bytes, result)}
+ """.stripMargin
}
}
@@ -392,7 +392,10 @@ abstract class HashExpression[E] extends Expression {
val hashes = fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}
- ctx.splitExpressions(input, hashes)
+ ctx.splitExpressions(
+ expressions = hashes,
+ funcName = "getHash",
+ arguments = Seq("InternalRow" -> input))
}
@tailrec
@@ -608,12 +611,17 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = "false"
val childHash = ctx.freshName("childHash")
- val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child =>
+ val childrenHash = ctx.splitExpressions(children.map { child =>
val childGen = child.genCode(ctx)
- childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
+ val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, childHash, ctx)
- } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" +
- s"\n$childHash = 0;"
+ }
+ s"""
+ |${childGen.code}
+ |$codeToComputeHash
+ |${ev.value} = (31 * ${ev.value}) + $childHash;
+ |$childHash = 0;
+ """.stripMargin
})
ctx.addMutableState(ctx.javaType(dataType), ev.value)
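The HiveHash hunk only restructures the generated snippet for readability; the combining step is unchanged. Its effect per child, written out as ordinary Scala:

    // Hive-compatible ordered combination: fold each child's hash into the running value.
    def combineHiveHashes(childHashes: Seq[Int]): Int =
      childHashes.foldLeft(0)((acc, h) => 31 * acc + h)

    // combineHiveHashes(Seq(1, 2, 3)) == 31 * (31 * 1 + 2) + 3 == 1026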
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 5eaf3f220277..173e171910b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -91,7 +91,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
- ${ctx.splitExpressions(ctx.INPUT_ROW, evals)}""")
+ ${ctx.splitExpressions(evals)}""")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 006d37f38d6c..e2bc79d98b33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
"""
}
}
- val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
+ val argCode = ctx.splitExpressions(argCodes)
(argCode, argValues.mkString(", "), resultIsNull)
}
@@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
"""
}
- val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
+ val childrenCode = ctx.splitExpressions(childrenCodes)
val schemaField = ctx.addReferenceObj("schema", schema)
val code = s"""
@@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
${javaBeanInstance}.$setterMethod(${fieldGen.value});
"""
}
- val initializeCode = ctx.splitExpressions(ctx.INPUT_ROW, initialize.toSeq)
+ val initializeCode = ctx.splitExpressions(initialize.toSeq)
val code = s"""
${instanceGen.code}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 1c599af2a01d..ee5cf925d3ce 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -208,7 +208,7 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip
- val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
+ val codes = ctx.splitExpressions(evals.map(_.code))
val varargCounts = ctx.splitExpressions(
expressions = varargCount,
funcName = "varargCountsConcatWs",
@@ -1372,10 +1372,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val pattern = children.head.genCode(ctx)
val argListGen = children.tail.map(x => (x.dataType, x.genCode(ctx)))
- val argListCode = argListGen.map(_._2.code + "\n")
-
- val argListString = argListGen.foldLeft("")((s, v) => {
- val nullSafeString =
+ val argList = ctx.freshName("argLists")
+ val numArgLists = argListGen.length
+ val argListCode = argListGen.zipWithIndex.map { case(v, index) =>
+ val value =
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
@@ -1383,8 +1383,19 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
} else {
s"(${v._2.isNull}) ? null : ${v._2.value}"
}
- s + "," + nullSafeString
- })
+ s"""
+ ${v._2.code}
+ $argList[$index] = $value;
+ """
+ }
+ val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ ctx.splitExpressions(
+ expressions = argListCode,
+ funcName = "valueFormatString",
+ arguments = ("InternalRow", ctx.INPUT_ROW) :: ("Object[]", argList) :: Nil)
+ } else {
+ argListCode.mkString("\n")
+ }
val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
@@ -1395,10 +1406,11 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
boolean ${ev.isNull} = ${pattern.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
- ${argListCode.mkString}
$stringBuffer $sb = new $stringBuffer();
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
- $form.format(${pattern.value}.toString() $argListString);
+ Object[] $argList = new Object[$numArgLists];
+ $argListCodes
+ $form.format(${pattern.value}.toString(), $argList);
${ev.value} = UTF8String.fromString($sb.toString());
}""")
}
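FormatString previously spliced every argument into a single comma-separated call, which could push the generated method past the 64KB limit (see the SPARK-22603 test below); it now fills an Object[] and passes it to the formatter in one call. What the generated code effectively does at runtime, sketched in Scala:

    import java.util.{Formatter, Locale}

    def formatString(pattern: String, args: Array[AnyRef]): String = {
      val sb = new java.lang.StringBuilder()
      new Formatter(sb, Locale.US).format(pattern, args: _*)
      sb.toString
    }

    // Nullable primitives are boxed so that null arguments print as "null":
    // formatString("aa%d%s", Array[AnyRef](Int.box(12), null)) == "aa12null"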
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3a3ccd5ff5e6..0d961bf2e6e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -138,8 +138,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
Batch("Check Cartesian Products", Once,
CheckCartesianProducts) ::
- Batch("OptimizeCodegen", Once,
- OptimizeCodegen) ::
Batch("RewriteSubquery", Once,
RewritePredicateSubquery,
CollapseProject) :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 523b53b39d6b..785e815b4118 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -552,21 +552,6 @@ object FoldablePropagation extends Rule[LogicalPlan] {
}
-/**
- * Optimizes expressions by replacing according to CodeGen configuration.
- */
-object OptimizeCodegen extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case e: CaseWhen if canCodegen(e) => e.toCodegen()
- }
-
- private def canCodegen(e: CaseWhen): Boolean = {
- val numBranches = e.branches.size + e.elseValue.size
- numBranches <= SQLConf.get.maxCaseBranchesForCodegen
- }
-}
-
-
/**
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
index d701a956887a..5e1c4e0bd606 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import org.apache.spark.sql.catalyst.expressions.AttributeMap
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
-import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4eda9f337953..8abb4262d735 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -599,12 +599,6 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
- val MAX_CASES_BRANCHES = buildConf("spark.sql.codegen.maxCaseBranches")
- .internal()
- .doc("The maximum number of switches supported with codegen.")
- .intConf
- .createWithDefault(20)
-
val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines")
.internal()
.doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.")
@@ -1004,6 +998,15 @@ object SQLConf {
.intConf
.createWithDefault(10000)
+ val PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE =
+ buildConf("spark.sql.execution.pandas.respectSessionTimeZone")
+ .internal()
+ .doc("When true, make Pandas DataFrame with timestamp type respecting session local " +
+ "timezone when converting to/from Pandas DataFrame. This configuration will be " +
+ "deprecated in the future releases.")
+ .booleanConf
+ .createWithDefault(true)
+
val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
.internal()
.doc("When true, the apply function of the rule verifies whether the right node of the" +
@@ -1140,8 +1143,6 @@ class SQLConf extends Serializable with Logging {
def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
- def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
-
def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)
@@ -1324,6 +1325,8 @@ class SQLConf extends Serializable with Logging {
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
+ def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE)
+
def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
/** ********************** SQLConf functionality methods ************ */
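Usage sketch for the new internal flag (names as added above; `spark` is assumed to be an existing SparkSession). It defaults to true, so this only matters when the legacy timestamp behaviour is wanted:

    // Disable session-timezone-aware conversion for pandas interop, run the conversion,
    // then restore the default.
    spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
    // ... toPandas() / pandas_udf work with the previous timestamp semantics ...
    spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")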
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 6e33087b4c6c..a4198f826ced 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -77,7 +77,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("SPARK-13242: case-when expression with large number of branches (or cases)") {
- val cases = 50
+ val cases = 500
val clauses = 20
// Generate an individual case
@@ -88,13 +88,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
(condition, Literal(n))
}
- val expression = CaseWhen((1 to cases).map(generateCase(_)))
+ val expression = CaseWhen((1 to cases).map(generateCase))
val plan = GenerateMutableProjection.generate(Seq(expression))
- val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
+ val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"$clauses:$cases")))
val actual = plan(input).toSeq(Seq(expression.dataType))
- assert(actual(0) == cases)
+ assert(actual.head == cases)
}
test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index c76139475687..54cde77176e2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -518,6 +518,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}
+ test("SPARK-22603: FormatString should not generate codes beyond 64KB") {
+ val N = 4500
+ val args = (1 to N).map(i => Literal.create(i.toString, StringType))
+ val format = "%s" * N
+ val expected = (1 to N).map(i => i.toString).mkString
+ checkEvaluation(FormatString(Literal(format) +: args: _*), expected)
+ }
+
test("INSTR") {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
deleted file mode 100644
index b1157f3e3edd..000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.Literal._
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules._
-
-
-class OptimizeCodegenSuite extends PlanTest {
-
- object Optimize extends RuleExecutor[LogicalPlan] {
- val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Nil
- }
-
- protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
- val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
- val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
- comparePlans(actual, correctAnswer)
- }
-
- test("Codegen only when the number of branches is small.") {
- assertEquivalent(
- CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
- CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())
-
- assertEquivalent(
- CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)),
- CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)))
- }
-
- test("Nested CaseWhen Codegen.") {
- assertEquivalent(
- CaseWhen(
- Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral), Literal(3))),
- CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
- CaseWhen(
- Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral).toCodegen(), Literal(3))),
- CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
- }
-
- test("Multiple CaseWhen in one operator.") {
- val plan = OneRowRelation()
- .select(
- CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
- CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
- CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
- CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
- val correctAnswer = OneRowRelation()
- .select(
- CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
- CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
- CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
- CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
- val optimized = Optimize.execute(plan)
- comparePlans(optimized, correctAnswer)
- }
-
- test("Multiple CaseWhen in different operators") {
- val plan = OneRowRelation()
- .select(
- CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
- CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
- CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
- .where(
- LessThan(
- CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
- CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
- ).analyze
- val correctAnswer = OneRowRelation()
- .select(
- CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
- CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
- CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
- .where(
- LessThan(
- CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
- CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
- ).analyze
- val optimized = Optimize.execute(plan)
- comparePlans(optimized, correctAnswer)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index e3fa2ced760e..35abeccfd514 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -592,7 +592,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `sep` (default `,`): sets the single character as a separator for each
* field and value.
* `quote` (default `"`): sets the single character used for escaping quoted values where
- * the separator can be part of the value.
+ * the separator can be part of the value. If an empty string is set, it uses `u0000`
+ * (null character).
* `escape` (default `\`): sets the single character used for escaping quotes inside
* an already quoted value.
* `escapeQuotes` (default `true`): a flag indicating whether values containing
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 236995708a12..8d715f634298 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -41,7 +41,7 @@ case class LogicalRelation(
override def computeStats(): Statistics = {
catalogTable
- .flatMap(_.stats.map(_.toPlanStats(output)))
+ .flatMap(_.stats.map(_.toPlanStats(output, conf.cboEnabled)))
.getOrElse(Statistics(sizeInBytes = relation.sizeInBytes))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index e27210117a1e..c06bc7b66ff3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -63,6 +63,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
private val batchSize = conf.arrowMaxRecordsPerBatch
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
@@ -81,7 +82,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
val columnarBatchIter = new ArrowPythonRunner(
funcs, bufferSize, reuseWorker,
- PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone)
+ PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema,
+ sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(batchIter, context.partitionId(), context)
new Iterator[InternalRow] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 94c05b9b5e49..9a94d771a01b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -44,7 +44,8 @@ class ArrowPythonRunner(
evalType: Int,
argOffsets: Array[Array[Int]],
schema: StructType,
- timeZoneId: String)
+ timeZoneId: String,
+ respectTimeZone: Boolean)
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
funcs, bufferSize, reuseWorker, evalType, argOffsets) {
@@ -58,6 +59,11 @@ class ArrowPythonRunner(
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+ if (respectTimeZone) {
+ PythonRDD.writeUTF(timeZoneId, dataOut)
+ } else {
+ dataOut.writeInt(SpecialLengths.NULL)
+ }
}
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
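The command stream written to the Python worker now carries either the timezone ID or a null marker. A minimal sketch of that framing with plain JDK streams (PythonRDD.writeUTF emits a length-prefixed UTF-8 string; the exact null sentinel is SpecialLengths.NULL, passed in here as a parameter rather than hard-coded):

    import java.io.DataOutputStream

    def writeTimeZone(dataOut: DataOutputStream, timeZoneId: Option[String], nullMarker: Int): Unit =
      timeZoneId match {
        case Some(tz) =>
          val bytes = tz.getBytes("UTF-8")
          dataOut.writeInt(bytes.length)   // length prefix, as PythonRDD.writeUTF does
          dataOut.write(bytes)
        case None =>
          dataOut.writeInt(nullMarker)     // tells the worker no timezone was sent
      }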
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index ee495814b825..59db66bd7adf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -78,6 +78,7 @@ case class FlatMapGroupsInPandasExec(
val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
val schema = StructType(child.schema.drop(groupingAttributes.length))
val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
inputRDD.mapPartitionsInternal { iter =>
val grouped = if (groupingAttributes.isEmpty) {
@@ -95,7 +96,8 @@ case class FlatMapGroupsInPandasExec(
val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
- PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, sessionLocalTimeZone)
+ PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema,
+ sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(grouped, context.partitionId(), context)
columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
index d077836da847..e49546830286 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
@@ -90,7 +90,7 @@ class FlatMapGroupsWithState_StateManager(
val deser = stateEncoder.resolveAndBind().deserializer.transformUp {
case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal)
}
- CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen()
+ CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser)
}
// Converters for translating state between rows and Java objects
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index fdd25330c5e6..6ae307bce10c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -480,7 +480,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
if (tableMetadata.tableType == CatalogTableType.VIEW) {
// Temp or persistent views: refresh (or invalidate) any metadata/data cached
// in the plan recursively.
- table.queryExecution.analyzed.foreach(_.refresh())
+ table.queryExecution.analyzed.refresh()
} else {
// Non-temp tables: refresh the metadata cache.
sessionCatalog.refreshTable(tableIdent)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 644e72c893ce..72a5cc98fbec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
@@ -2158,4 +2158,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val mean = result.select("DecimalCol").where($"summary" === "mean")
assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
}
+
+ test("SPARK-22520: support code generation for large CaseWhen") {
+ val N = 30
+ var expr1 = when($"id" === lit(0), 0)
+ var expr2 = when($"id" === lit(0), 10)
+ (1 to N).foreach { i =>
+ expr1 = expr1.when($"id" === lit(i), -i)
+ expr2 = expr2.when($"id" === lit(i + 10), i)
+ }
+ val df = spark.range(1).select(expr1, expr2.otherwise(0))
+ checkAnswer(df, Row(0, 10) :: Nil)
+ assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala
index f6df077ec572..65ccc1915882 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, H
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.internal.StaticSQLConf
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.Decimal
@@ -223,11 +223,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat))
// Check relation statistics
- assert(relation.stats.sizeInBytes == 0)
- assert(relation.stats.rowCount == Some(0))
- assert(relation.stats.attributeStats.size == 1)
- val (attribute, colStat) = relation.stats.attributeStats.head
- assert(attribute.name == "c1")
- assert(colStat == emptyColStat)
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
+ assert(relation.stats.sizeInBytes == 1)
+ assert(relation.stats.rowCount == Some(0))
+ assert(relation.stats.attributeStats.size == 1)
+ val (attribute, colStat) = relation.stats.attributeStats.head
+ assert(attribute.name == "c1")
+ assert(colStat == emptyColStat)
+ }
+ relation.invalidateStatsCache()
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
+ assert(relation.stats.sizeInBytes == 0)
+ assert(relation.stats.rowCount.isEmpty)
+ assert(relation.stats.attributeStats.isEmpty)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 878f435c75cb..fdb9b2f51f9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -117,6 +117,21 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo
}
}
+ test("SPARK-22431: table with nested type col with special char") {
+ withTable("t") {
+ spark.sql("CREATE TABLE t(q STRUCT<`$a`:INT, col2:STRING>, i1 INT) USING PARQUET")
+ checkAnswer(spark.table("t"), Nil)
+ }
+ }
+
+ test("SPARK-22431: view with nested type") {
+ withView("t", "v") {
+ spark.sql("CREATE VIEW t AS SELECT STRUCT('a' AS `$a`, 1 AS b) q")
+ checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil)
+ spark.sql("CREATE VIEW v AS SELECT STRUCT('a' AS `a`, 1 AS b) q")
+ checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil)
+ }
+ }
}
abstract class DDLSuite extends QueryTest with SQLTestUtils {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index f9d75fc1788d..8b1521bacea4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -221,35 +221,6 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
.sessionState.conf.warehousePath.stripSuffix("/"))
}
- test("MAX_CASES_BRANCHES") {
- withTable("tab1") {
- spark.range(10).write.saveAsTable("tab1")
- val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1"
- val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1"
-
- withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") {
- assert(!sql(sql_one_branch_caseWhen)
- .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
- assert(!sql(sql_two_branch_caseWhen)
- .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
- }
-
- withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") {
- assert(sql(sql_one_branch_caseWhen)
- .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
- assert(!sql(sql_two_branch_caseWhen)
- .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
- }
-
- withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") {
- assert(sql(sql_one_branch_caseWhen)
- .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
- assert(sql(sql_two_branch_caseWhen)
- .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
- }
- }
- }
-
test("static SQL conf comes from SparkConf") {
val previousValue = sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)
try {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index b5a5890d47b0..47ce6ba83866 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -488,6 +488,7 @@ private[hive] class HiveClientImpl(
}
override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState {
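+ // Fail fast if any column's type cannot be expressed as a parseable Hive type string.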
+ verifyColumnDataType(table.dataSchema)
client.createTable(toHiveTable(table, Some(userName)), ignoreIfExists)
}
@@ -507,6 +508,7 @@ private[hive] class HiveClientImpl(
// these properties are still available to the others that share the same Hive metastore.
// If users explicitly alter these Hive-specific properties through ALTER TABLE DDL, we respect
// these user-specified values.
+ verifyColumnDataType(table.dataSchema)
val hiveTable = toHiveTable(
table.copy(properties = table.ignoredProperties ++ table.properties), Some(userName))
// Do not use `table.qualifiedName` here because this may be a rename
@@ -520,6 +522,7 @@ private[hive] class HiveClientImpl(
newDataSchema: StructType,
schemaProps: Map[String, String]): Unit = withHiveState {
val oldTable = client.getTable(dbName, tableName)
+ verifyColumnDataType(newDataSchema)
val hiveCols = newDataSchema.map(toHiveColumn)
oldTable.setFields(hiveCols.asJava)
@@ -872,15 +875,19 @@ private[hive] object HiveClientImpl {
new FieldSchema(c.name, typeString, c.getComment().orNull)
}
- /** Builds the native StructField from Hive's FieldSchema. */
- def fromHiveColumn(hc: FieldSchema): StructField = {
- val columnType = try {
+ /** Get the Spark SQL native DataType from Hive's FieldSchema. */
+ private def getSparkSQLDataType(hc: FieldSchema): DataType = {
+ try {
CatalystSqlParser.parseDataType(hc.getType)
} catch {
case e: ParseException =>
throw new SparkException("Cannot recognize hive type string: " + hc.getType, e)
}
+ }
+ /** Builds the native StructField from Hive's FieldSchema. */
+ def fromHiveColumn(hc: FieldSchema): StructField = {
+ val columnType = getSparkSQLDataType(hc)
val metadata = if (hc.getType != columnType.catalogString) {
new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build()
} else {
@@ -895,6 +902,10 @@ private[hive] object HiveClientImpl {
Option(hc.getComment).map(field.withComment).getOrElse(field)
}
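+ /**
+ * Throws a SparkException ("Cannot recognize hive type string") if any column's data type
+ * cannot be parsed back from its Hive type string representation.
+ */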
+ private def verifyColumnDataType(schema: StructType): Unit = {
+ schema.foreach(col => getSparkSQLDataType(toHiveColumn(col)))
+ }
+
private def toInputFormat(name: String) =
Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]]
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 7427948fe138..0cdd9305c6b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -41,7 +41,35 @@ import org.apache.spark.sql.types._
class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton {
- test("Hive serde tables should fallback to HDFS for size estimation") {
+
+ test("size estimation for relations is based on row size * number of rows") {
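+ // After ANALYZE, the catalog holds file-based stats; with CBO enabled the relation
+ // re-estimates its size from rowCount and row width, which can exceed the on-disk size.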
+ val dsTbl = "rel_est_ds_table"
+ val hiveTbl = "rel_est_hive_table"
+ withTable(dsTbl, hiveTbl) {
+ spark.range(1000L).write.format("parquet").saveAsTable(dsTbl)
+ spark.range(1000L).write.format("hive").saveAsTable(hiveTbl)
+
+ Seq(dsTbl, hiveTbl).foreach { tbl =>
+ sql(s"ANALYZE TABLE $tbl COMPUTE STATISTICS")
+ val catalogStats = getCatalogStatistics(tbl)
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
+ val relationStats = spark.table(tbl).queryExecution.optimizedPlan.stats
+ assert(relationStats.sizeInBytes == catalogStats.sizeInBytes)
+ assert(relationStats.rowCount.isEmpty)
+ }
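+ // Refresh to drop the cached relation so its stats are re-derived under the new CBO setting.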
+ spark.sessionState.catalog.refreshTable(TableIdentifier(tbl))
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
+ val relationStats = spark.table(tbl).queryExecution.optimizedPlan.stats
+ // Because the Parquet files are compressed, the on-disk size in this test is
+ // smaller than the in-memory size.
+ assert(catalogStats.sizeInBytes < relationStats.sizeInBytes)
+ assert(catalogStats.rowCount == relationStats.rowCount)
+ }
+ }
+ }
+ }
+
+ test("Hive serde tables should fallback to HDFS for size estimation") {
withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true") {
withTable("csv_table") {
withTempDir { tempDir =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index d3465a641a1a..9063ef066aa8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -174,6 +174,88 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA
test("alter datasource table add columns - partitioned - orc") {
testAddColumnPartitioned("orc")
}
+
+ test("SPARK-22431: illegal nested type") {
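+ // A nested field name such as `$a` cannot be round-tripped through a Hive type string,
+ // so creating these tables and views through the Hive metastore should fail.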
+ val queries = Seq(
+ "CREATE TABLE t AS SELECT STRUCT('a' AS `$a`, 1 AS b) q",
+ "CREATE TABLE t(q STRUCT<`$a`:INT, col2:STRING>, i1 INT)",
+ "CREATE VIEW t AS SELECT STRUCT('a' AS `$a`, 1 AS b) q")
+
+ queries.foreach { query =>
+ val err = intercept[SparkException] {
+ spark.sql(query)
+ }.getMessage
+ assert(err.contains("Cannot recognize hive type string"))
+ }
+
+ withView("v") {
+ spark.sql("CREATE VIEW v AS SELECT STRUCT('a' AS `a`, 1 AS b) q")
+ checkAnswer(sql("SELECT q.`a`, q.b FROM v"), Row("a", 1) :: Nil)
+
+ val err = intercept[SparkException] {
+ spark.sql("ALTER VIEW v AS SELECT STRUCT('a' AS `$a`, 1 AS b) q")
+ }.getMessage
+ assert(err.contains("Cannot recognize hive type string"))
+ }
+ }
+
+ test("SPARK-22431: table with nested type") {
+ withTable("t", "x") {
+ spark.sql("CREATE TABLE t(q STRUCT<`$a`:INT, col2:STRING>, i1 INT) USING PARQUET")
+ checkAnswer(spark.table("t"), Nil)
+ spark.sql("CREATE TABLE x (q STRUCT, i1 INT)")
+ checkAnswer(spark.table("x"), Nil)
+ }
+ }
+
+ test("SPARK-22431: view with nested type") {
+ withView("v") {
+ spark.sql("CREATE VIEW v AS SELECT STRUCT('a' AS `a`, 1 AS b) q")
+ checkAnswer(spark.table("v"), Row(Row("a", 1)) :: Nil)
+
+ spark.sql("ALTER VIEW v AS SELECT STRUCT('a' AS `b`, 1 AS b) q1")
+ val df = spark.table("v")
+ assert(df.schema.fields(0).name == "q1")
+ checkAnswer(df, Row(Row("a", 1)) :: Nil)
+ }
+ }
+
+ test("SPARK-22431: alter table tests with nested types") {
+ withTable("t1", "t2", "t3") {
+ spark.sql("CREATE TABLE t1 (q STRUCT, i1 INT)")
+ spark.sql("ALTER TABLE t1 ADD COLUMNS (newcol1 STRUCT<`col1`:STRING, col2:Int>)")
+ val newcol = spark.sql("SELECT * FROM t1").schema.fields(2).name
+ assert(newcol == "newcol1")
+
+ spark.sql("CREATE TABLE t2(q STRUCT<`a`:INT, col2:STRING>, i1 INT) USING PARQUET")
+ spark.sql("ALTER TABLE t2 ADD COLUMNS (newcol1 STRUCT<`$col1`:STRING, col2:Int>)")
+ spark.sql("ALTER TABLE t2 ADD COLUMNS (newcol2 STRUCT<`col1`:STRING, col2:Int>)")
+
+ val df2 = spark.table("t2")
+ checkAnswer(df2, Nil)
+ assert(df2.schema.fields(2).name == "newcol1")
+ assert(df2.schema.fields(3).name == "newcol2")
+
+ spark.sql("CREATE TABLE t3(q STRUCT<`$a`:INT, col2:STRING>, i1 INT) USING PARQUET")
+ spark.sql("ALTER TABLE t3 ADD COLUMNS (newcol1 STRUCT<`$col1`:STRING, col2:Int>)")
+ spark.sql("ALTER TABLE t3 ADD COLUMNS (newcol2 STRUCT<`col1`:STRING, col2:Int>)")
+
+ val df3 = spark.table("t3")
+ checkAnswer(df3, Nil)
+ assert(df3.schema.fields(2).name == "newcol1")
+ assert(df3.schema.fields(3).name == "newcol2")
+ }
+ }
+
+ test("SPARK-22431: negative alter table tests with nested types") {
+ withTable("t1") {
+ spark.sql("CREATE TABLE t1 (q STRUCT, i1 INT)")
+ val err = intercept[SparkException] {
+ spark.sql("ALTER TABLE t1 ADD COLUMNS (newcol1 STRUCT<`$col1`:STRING, col2:Int>)")
+ }.getMessage
+ assert(err.contains("Cannot recognize hive type string:"))
+ }
+ }
}
class HiveDDLSuite
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index 3066a4f305f0..dfabf1ec2a22 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -18,8 +18,10 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
/**
@@ -29,21 +31,32 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
import testImplicits._
test("show cost in explain command") {
+ val explainCostCommand = "EXPLAIN COST SELECT * FROM src"
// For readability, we only show optimized plan and physical plan in explain cost command
- checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "),
+ checkKeywordsExist(sql(explainCostCommand),
"Optimized Logical Plan", "Physical Plan")
- checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "),
+ checkKeywordsNotExist(sql(explainCostCommand),
"Parsed Logical Plan", "Analyzed Logical Plan")
- // Only has sizeInBytes before ANALYZE command
- checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes")
- checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), "rowCount")
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
+ // Only has sizeInBytes before ANALYZE command
+ checkKeywordsExist(sql(explainCostCommand), "sizeInBytes")
+ checkKeywordsNotExist(sql(explainCostCommand), "rowCount")
- // Has both sizeInBytes and rowCount after ANALYZE command
- sql("ANALYZE TABLE src COMPUTE STATISTICS")
- checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes", "rowCount")
+ // Has both sizeInBytes and rowCount after ANALYZE command
+ sql("ANALYZE TABLE src COMPUTE STATISTICS")
+ checkKeywordsExist(sql(explainCostCommand), "sizeInBytes", "rowCount")
+ }
+
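+ // Refresh so the cached relation is dropped and its stats are re-read with CBO disabled.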
+ spark.sessionState.catalog.refreshTable(TableIdentifier("src"))
+
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
+ // Don't show rowCount if CBO is disabled
+ checkKeywordsExist(sql(explainCostCommand), "sizeInBytes")
+ checkKeywordsNotExist(sql(explainCostCommand), "rowCount")
+ }
- // No cost information
+ // No statistics information if "cost" is not specified
checkKeywordsNotExist(sql("EXPLAIN SELECT * FROM src "), "sizeInBytes", "rowCount")
}