From 0b589555862fa1d0142c5949ec9f9cf1552834c9 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Sun, 16 Jun 2024 22:35:40 +0200 Subject: [PATCH 01/16] [SPARK-48638] Add QueryExecution support for DataFrame --- .../sql/connect/utils/MetricGenerator.scala | 1 + dev/requirements.txt | 3 + dev/sparktestsupport/modules.py | 1 + python/pyspark/errors/error-conditions.json | 5 + python/pyspark/sql/classic/dataframe.py | 8 + python/pyspark/sql/connect/client/core.py | 68 ++--- python/pyspark/sql/connect/dataframe.py | 19 +- python/pyspark/sql/dataframe.py | 5 + python/pyspark/sql/metrics.py | 238 ++++++++++++++++++ .../sql/tests/connect/test_df_debug.py | 68 +++++ python/pyspark/sql/tests/test_dataframe.py | 11 +- python/pyspark/testing/connectutils.py | 7 + 12 files changed, 379 insertions(+), 55 deletions(-) create mode 100644 python/pyspark/sql/metrics.py create mode 100644 python/pyspark/sql/tests/connect/test_df_debug.py diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala index e2e412831187..d76bec5454ab 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala @@ -70,6 +70,7 @@ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper { .newBuilder() .setName(p.nodeName) .setPlanId(p.id) + .setParent(parentId) .putAllExecutionMetrics(mv.asJava) .build() Seq(mo) ++ transformChildren(p) diff --git a/dev/requirements.txt b/dev/requirements.txt index d6530d8ce282..88883a963950 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -60,6 +60,9 @@ mypy-protobuf==3.3.0 googleapis-common-protos-stubs==2.2.0 grpc-stubs==1.24.11 +# Debug for Spark and Spark Connect +graphviz==0.20.3 + # TorchDistributor dependencies torch torchvision diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b97ec34b5382..de3bc99826d5 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -1062,6 +1062,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_pandas_udf_window", "pyspark.sql.tests.connect.test_resources", "pyspark.sql.tests.connect.shell.test_progress", + "pyspark.sql.tests.connect.test_df_debug", ], excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 30db37387249..23b0c4cd4f9b 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -149,6 +149,11 @@ "Cannot without ." ] }, + "CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF": { + "message": [ + "Calling property or member is not supported in PySpark Classic, please use Spark Connect instead." + ] + }, "COLLATION_INVALID_PROVIDER" : { "message" : [ "The value does not represent a correct collation provider. Supported providers are: []." 
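For orientation, a hedged usage sketch of what this patch enables on the Python side, mirroring the tests added below (the connection URL is illustrative; `spark` must be a Spark Connect session):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # Metrics become available only after an action has executed the plan.
    df = spark.range(100).repartition(10).groupBy("id").count()
    assert df.queryExecution is None           # nothing executed yet
    df.collect()

    qe = df.queryExecution
    root, nodes = qe.metrics.extract_graph()   # plan-node graph keyed by plan id
    observed = qe.flows                        # [(observation name, {metric: value}), ...]

    # On classic (non-Connect) PySpark, accessing the property instead raises a
    # PySparkValueError with CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF.
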
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index a03467aff194..d4541386e6ad 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -94,6 +94,7 @@ from pyspark.sql.session import SparkSession from pyspark.sql.group import GroupedData from pyspark.sql.observation import Observation + from pyspark.sql.metrics import QueryExecution class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin): @@ -1835,6 +1836,13 @@ def toArrow(self) -> "pa.Table": def toPandas(self) -> "PandasDataFrameLike": return PandasConversionMixin.toPandas(self) + @property + def queryExecution(self) -> Optional["QueryExecution"]: + raise PySparkValueError( + error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF", + message_parameters={"member": "queryExecution"}, + ) + def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject": """ diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 4c638be3b0af..5d91313e79b7 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -61,6 +61,7 @@ from pyspark.loose_version import LooseVersion from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation +from pyspark.sql.metrics import MetricValue, PlanMetrics, QueryExecution, ObservedMetrics from pyspark.sql.connect.client.artifact import ArtifactManager from pyspark.sql.connect.client.logging import logger from pyspark.sql.connect.profiler import ConnectProfilerCollector @@ -447,56 +448,7 @@ def toChannel(self) -> grpc.Channel: return self._secure_channel(self.endpoint, creds) -class MetricValue: - def __init__(self, name: str, value: Union[int, float], type: str): - self._name = name - self._type = type - self._value = value - - def __repr__(self) -> str: - return f"<{self._name}={self._value} ({self._type})>" - - @property - def name(self) -> str: - return self._name - - @property - def value(self) -> Union[int, float]: - return self._value - - @property - def metric_type(self) -> str: - return self._type - - -class PlanMetrics: - def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]): - self._name = name - self._id = id - self._parent_id = parent - self._metrics = metrics - - def __repr__(self) -> str: - return f"Plan({self._name})={self._metrics}" - - @property - def name(self) -> str: - return self._name - - @property - def plan_id(self) -> int: - return self._id - - @property - def parent_plan_id(self) -> int: - return self._parent_id - - @property - def metrics(self) -> List[MetricValue]: - return self._metrics - - -class PlanObservedMetrics: +class PlanObservedMetrics(ObservedMetrics): def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: List[str]): self._name = name self._metrics = metrics @@ -513,6 +465,13 @@ def name(self) -> str: def metrics(self) -> List[pb2.Expression.Literal]: return self._metrics + @property + def pairs(self) -> dict[str, Any]: + result = {} + for x in range(len(self._metrics)): + result[self.keys[x]] = LiteralExpression._to_value(self.metrics[x]) + return result + @property def keys(self) -> List[str]: return self._keys @@ -920,16 +879,19 @@ def to_table_as_iterator( def to_table( self, plan: pb2.Plan, observations: Dict[str, Observation] - ) -> Tuple["pa.Table", Optional[StructType]]: + ) -> Tuple["pa.Table", Optional[StructType], QueryExecution]: """ Return given plan as a PyArrow Table. 
""" logger.info(f"Executing plan {self._proto_to_string(plan)}") req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) - table, schema, _, _, _ = self._execute_and_fetch(req, observations) + table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations) + + # Create a query execution object. + qe = QueryExecution(metrics, observed_metrics) assert table is not None - return table, schema + return table, schema, qe def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd.DataFrame": """ diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index f2705ec7ad71..53421938dbff 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -101,6 +101,7 @@ from pyspark.sql.connect.observation import Observation from pyspark.sql.connect.session import SparkSession from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame + from pyspark.sql.metrics import QueryExecution class DataFrame(ParentDataFrame): @@ -137,6 +138,7 @@ def __init__( # by __repr__ and _repr_html_ while eager evaluation opens. self._support_repr_html = False self._cached_schema: Optional[StructType] = None + self._query_execution: Optional["QueryExecution"] = None def __reduce__(self) -> Tuple: """ @@ -1836,7 +1838,9 @@ def collect(self) -> List[Row]: def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: query = self._plan.to_proto(self._session.client) - table, schema = self._session.client.to_table(query, self._plan.observations) + table, schema, self._query_execution = self._session.client.to_table( + query, self._plan.observations + ) assert table is not None return (table, schema) @@ -2202,6 +2206,19 @@ def rdd(self) -> "RDD[Row]": message_parameters={"feature": "rdd"}, ) + @property + def queryExecution(self) -> Optional["QueryExecution"]: + """ + The queryExecution method allows to introspect information about the actual + query execution after the successful execution. Accessing this member before + the query execution has happened will return None. + + Returns + ------- + An instance of QueryExecution or None when the value is not set yet. + """ + return self._query_execution + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 62c46cfec93c..7a352c44c35d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -64,6 +64,7 @@ ArrowMapIterFunction, DataFrameLike as PandasDataFrameLike, ) + from pyspark.sql.metrics import QueryExecution __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -6281,6 +6282,10 @@ def toPandas(self) -> "PandasDataFrameLike": """ ... + @property + def queryExecution(self) -> Optional["QueryExecution"]: + ... + class DataFrameNaFunctions: """Functionality for working with missing data in :class:`DataFrame`. diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py new file mode 100644 index 000000000000..0f5169c89b47 --- /dev/null +++ b/python/pyspark/sql/metrics.py @@ -0,0 +1,238 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +import dataclasses +from typing import Optional, List, Tuple, Dict, Any, Union, TYPE_CHECKING, Sequence + +from pyspark.errors import PySparkValueError +from pyspark.testing.connectutils import have_graphviz + +if TYPE_CHECKING: + if have_graphviz: + import graphviz # type: ignore + + +class ObservedMetrics(abc.ABC): + @property + @abc.abstractmethod + def name(self) -> str: + ... + + @property + @abc.abstractmethod + def pairs(self) -> Dict[str, Any]: + ... + + @property + @abc.abstractmethod + def keys(self) -> List[str]: + ... + + +class MetricValue: + """The metric values is the Python representation of a plan metric value from the JVM. + However, it does not have any reference to the original value.""" + + def __init__(self, name: str, value: Union[int, float], type: str): + self._name = name + self._type = type + self._value = value + + def __repr__(self) -> str: + return f"<{self._name}={self._value} ({self._type})>" + + @property + def name(self) -> str: + return self._name + + @property + def value(self) -> Union[int, float]: + return self._value + + @property + def metric_type(self) -> str: + return self._type + + +class PlanMetrics: + """Represents a particular plan node and the associated metrics of this node.""" + + def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]): + self._name = name + self._id = id + self._parent_id = parent + self._metrics = metrics + + def __repr__(self) -> str: + return f"Plan({self._name}: {self._id}->{self._parent_id})={self._metrics}" + + @property + def name(self) -> str: + return self._name + + @property + def plan_id(self) -> int: + return self._id + + @property + def parent_plan_id(self) -> int: + return self._parent_id + + @property + def metrics(self) -> List[MetricValue]: + return self._metrics + + +class CollectedMetrics: + @dataclasses.dataclass + class Node: + id: int + name: str = dataclasses.field(default="") + metrics: List[MetricValue] = dataclasses.field(default_factory=list) + children: List[int] = dataclasses.field(default_factory=list) + + def __init__(self, metrics: List[PlanMetrics]): + # Sort the input list + self._metrics = sorted(metrics, key=lambda x: x._parent_id, reverse=False) + + def extract_graph(self) -> Tuple[int, Dict[int, "CollectedMetrics.Node"]]: + """ + Builds the graph of the query plan. The graph is represented as a dictionary where the key + is the node ID and the value is the node itself. The root node is the node that has no + parent. + + Returns + ------- + The root node ID and the graph of all nodes. + """ + all_nodes: Dict[int, CollectedMetrics.Node] = {} + + for m in self._metrics: + # Add yourself to the list if you have to. 
+ if m.plan_id not in all_nodes: + all_nodes[m.plan_id] = CollectedMetrics.Node(m.plan_id, m.name, m.metrics) + else: + all_nodes[m.plan_id].name = m.name + all_nodes[m.plan_id].metrics = m.metrics + + # Now check for the parent of this node if it's in + if m.parent_plan_id not in all_nodes: + all_nodes[m.parent_plan_id] = CollectedMetrics.Node(m.parent_plan_id) + + all_nodes[m.parent_plan_id].children.append(m.plan_id) + + # Next step is to find all the root nodes. Root nodes are never used in children. + # So we start will all node ids as candidates. + candidates = set(all_nodes.keys()) + for k, v in all_nodes.items(): + for c in v.children: + if c in candidates and c != k: + candidates.remove(c) + + assert len(candidates) == 1, f"Expected 1 root node, found {len(candidates)}" + return candidates.pop(), all_nodes + + def toDot(self, filename: Optional[str] = None, out_format: str = "png") -> "graphviz.Digraph": + """ + Converts the collected metrics into a dot representation. Since the graphviz Digraph + implementation provides the ability to render the result graph directory in a + notebook, we return the graph object directly. + + If the graphviz package is not available, a PACKAGE_NOT_INSTALLED error is raised. + + Parameters + ---------- + filename - str, optional + The filename to save the graph to given an output format. The path can be + relative or absolute. + + out_format - str + The output format of the graph. The default is 'png'. + + Returns + ------- + + """ + try: + import graphviz + + dot = graphviz.Digraph( + comment="Query Plan", + node_attr={ + "shape": "box", + "font-size": "10pt", + }, + ) + + root, graph = self.extract_graph() + for k, v in graph.items(): + # Build table rows for the metrics + rows = "\n".join( + [ + ( + f'{x.name}' + f'{x.value} ({x.metric_type})' + ) + for x in v.metrics + ] + ) + + dot.node( + str(k), + """< + + + + + {} +
+ {} +
Metrics
>""".format( + v.name, rows + ), + ) + for c in v.children: + dot.edge(str(k), str(c)) + + if filename: + dot.render(filename, format=out_format, cleanup=True) + return dot + + except ImportError: + raise PySparkValueError( + error_class="PACKAGE_NOT_INSTALLED", + message_parameters={"package_name": "graphviz", "minimum_version": "0.20"}, + ) + + +class QueryExecution: + """The query execution class allows users to inspect the query execution of this particular + data frame. This value is only set in the data frame if it was executed.""" + + def __init__( + self, metrics: Optional[list[PlanMetrics]], obs: Optional[Sequence[ObservedMetrics]] + ): + self._metrics = CollectedMetrics(metrics) if metrics else None + self._observations = obs if obs else [] + + @property + def metrics(self) -> Optional[CollectedMetrics]: + return self._metrics + + @property + def flows(self) -> List[Tuple[str, Dict[str, Any]]]: + return [(f.name, f.pairs) for f in self._observations] diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py new file mode 100644 index 000000000000..01c47fa9ad3c --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_df_debug.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
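A hedged sketch of how the toDot helper defined above in metrics.py might be used; it requires the optional graphviz package (added to dev/requirements.txt in this patch), and the output filename is illustrative:

    # After an action has run, render the per-node metrics as a Graphviz graph.
    qe = df.queryExecution
    dot = qe.metrics.toDot()                            # graphviz.Digraph; displays inline in notebooks
    qe.metrics.toDot("plan_metrics", out_format="svg")  # additionally writes plan_metrics.svg
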
+# + +import unittest + +from pyspark.testing.connectutils import ( + should_test_connect, + have_graphviz, + graphviz_requirement_message, +) +from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase + +if should_test_connect: + from pyspark.sql.connect.dataframe import DataFrame + + +class SparkConnectDataFrameDebug(SparkConnectSQLTestCase): + def test_df_debug_basics(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + x = df.collect() # noqa: F841 + qe = df.queryExecution + + root, graph = qe.metrics.extract_graph() + self.assertIn(root, graph, "The root must be rooted in the graph") + + def test_df_quey_execution_empty_before_execution(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + qe = df.queryExecution + self.assertIsNone(qe, "The query execution must be None before the action is executed") + + @unittest.skipIf(not have_graphviz, graphviz_requirement_message) + def test_df_query_execution_metrics_to_dot(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + x = df.collect() # noqa: F841 + qe = df.queryExecution + + dot = qe.metrics.toDot() + source = dot.source + self.assertIsNotNone(dot, "The dot representation must not be None") + self.assertGreater(len(source), 0, "The dot representation must not be empty") + self.assertIn("digraph", source, "The dot representation must contain the digraph keyword") + self.assertIn("Metrics", source, "The dot representation must contain the Metrics keyword") + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_df_debug import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 36a856b62719..1a54d850bb80 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -37,6 +37,7 @@ AnalysisException, IllegalArgumentException, PySparkTypeError, + PySparkValueError, ) from pyspark.testing.sqlutils import ( ReusedSQLTestCase, @@ -849,7 +850,15 @@ def test_checkpoint_dataframe(self): class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase): - pass + def test_query_execution_unsupported_in_classic(self): + with self.assertRaises(PySparkValueError) as pe: + self.spark.range(1).queryExecution + + self.check_error( + exception=pe.exception, + error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF", + message_parameters={"member": "queryExecution"}, + ) if __name__ == "__main__": diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 191505741eb4..267bb5255a9e 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -45,6 +45,13 @@ googleapis_common_protos_requirement_message = str(e) have_googleapis_common_protos = googleapis_common_protos_requirement_message is None +graphviz_requirement_message = None +try: + import graphviz +except ImportError as e: + graphviz_requirement_message = str(e) +have_graphviz = graphviz_requirement_message is None + from pyspark import Row, SparkConf from pyspark.util import is_remote_only from pyspark.testing.utils import PySparkErrorTestUtils From 07527aabc61fd369fd98e2fc2db0e665f142f90d Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 17 Jun 2024 
21:56:47 +0200 Subject: [PATCH 02/16] fix lint --- python/pyspark/sql/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index 0f5169c89b47..cb8ff5d3041d 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: if have_graphviz: - import graphviz # type: ignore + import graphviz # type: ignore class ObservedMetrics(abc.ABC): From de5541a5ccf58b3a1b9f9a4cd17efa5dd3d610b3 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 17 Jun 2024 22:10:08 +0200 Subject: [PATCH 03/16] review comments --- python/pyspark/sql/connect/dataframe.py | 9 --------- python/pyspark/sql/dataframe.py | 13 +++++++++++++ python/pyspark/sql/metrics.py | 3 ++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 53421938dbff..8eaf59bf8e4f 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -2208,15 +2208,6 @@ def rdd(self) -> "RDD[Row]": @property def queryExecution(self) -> Optional["QueryExecution"]: - """ - The queryExecution method allows to introspect information about the actual - query execution after the successful execution. Accessing this member before - the query execution has happened will return None. - - Returns - ------- - An instance of QueryExecution or None when the value is not set yet. - """ return self._query_execution diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7a352c44c35d..e70a690d0cbb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -6284,6 +6284,19 @@ def toPandas(self) -> "PandasDataFrameLike": @property def queryExecution(self) -> Optional["QueryExecution"]: + """ + Returns a QueryExecution object after the query was executed. + + The queryExecution method allows to introspect information about the actual + query execution after the successful execution. Accessing this member before + the query execution will return None. + + .. versionadded:: 4.0.0 + + Returns + ------- + An instance of QueryExecution or None when the value is not set yet. + """ ... diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index cb8ff5d3041d..f2e72d9a1d90 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -19,9 +19,10 @@ from typing import Optional, List, Tuple, Dict, Any, Union, TYPE_CHECKING, Sequence from pyspark.errors import PySparkValueError -from pyspark.testing.connectutils import have_graphviz if TYPE_CHECKING: + from pyspark.testing.connectutils import have_graphviz + if have_graphviz: import graphviz # type: ignore From 3e95c96bb660341ccce63e71c5da43682988866b Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 17 Jun 2024 22:29:21 +0200 Subject: [PATCH 04/16] fix lint --- python/pyspark/sql/metrics.py | 6 +++--- python/pyspark/testing/connectutils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index f2e72d9a1d90..004e07e4fb75 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -157,16 +157,16 @@ def toDot(self, filename: Optional[str] = None, out_format: str = "png") -> "gra Parameters ---------- - filename - str, optional + filename : str, optional The filename to save the graph to given an output format. The path can be relative or absolute. 
- out_format - str + out_format : str The output format of the graph. The default is 'png'. Returns ------- - + An instance of the graphviz.Digraph object. """ try: import graphviz diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 267bb5255a9e..b3004693724b 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -50,7 +50,7 @@ import graphviz except ImportError as e: graphviz_requirement_message = str(e) -have_graphviz = graphviz_requirement_message is None +have_graphviz: bool = graphviz_requirement_message is None from pyspark import Row, SparkConf from pyspark.util import is_remote_only From a310cda1954915e563c035dc6b69e52dd39f5b3f Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 18 Jun 2024 14:57:24 +0200 Subject: [PATCH 05/16] adding support for writer support --- python/pyspark/sql/connect/client/core.py | 12 +++-- python/pyspark/sql/connect/dataframe.py | 10 +++- python/pyspark/sql/connect/readwriter.py | 47 ++++++++++++++----- python/pyspark/sql/connect/session.py | 7 ++- .../sql/tests/connect/test_df_debug.py | 7 +++ 5 files changed, 63 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 3728985b9609..c4ed1a736a4d 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1013,7 +1013,7 @@ def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str: def execute_command( self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None - ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]: + ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], QueryExecution]: """ Execute given command. """ @@ -1022,11 +1022,15 @@ def execute_command( if self._user_id: req.user_context.user_id = self._user_id req.plan.command.CopyFrom(command) - data, _, _, _, properties = self._execute_and_fetch(req, observations or {}) + data, _, metrics, observed_metrics, properties = self._execute_and_fetch( + req, observations or {} + ) + # Create a query execution object. 
+ qe = QueryExecution(metrics, observed_metrics) if data is not None: - return (data.to_pandas(), properties) + return (data.to_pandas(), properties, qe) else: - return (None, properties) + return (None, properties, qe) def execute_command_as_iterator( self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 8eaf59bf8e4f..1ee6e9ab1054 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -208,7 +208,10 @@ def _repr_html_(self) -> Optional[str]: @property def write(self) -> "DataFrameWriter": - return DataFrameWriter(self._plan, self._session) + def cb(qe: "QueryExecution") -> None: + self._query_execution = qe + + return DataFrameWriter(self._plan, self._session, cb) @functools.cache def isEmpty(self) -> bool: @@ -2170,7 +2173,10 @@ def semanticHash(self) -> int: ) def writeTo(self, table: str) -> "DataFrameWriterV2": - return DataFrameWriterV2(self._plan, self._session, table) + def cb(qe: "QueryExecution") -> None: + self._query_execution = qe + + return DataFrameWriterV2(self._plan, self._session, table, cb) def offset(self, n: int) -> ParentDataFrame: return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index bf7dc4d36905..f49e5795902f 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -19,7 +19,7 @@ check_dependencies(__name__) from typing import Dict -from typing import Optional, Union, List, overload, Tuple, cast +from typing import Optional, Union, List, overload, Tuple, cast, Callable from typing import TYPE_CHECKING from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation, WriteOperationV2 @@ -37,6 +37,7 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType from pyspark.sql.connect.session import SparkSession + from pyspark.sql.metrics import QueryExecution __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -486,10 +487,16 @@ def _jreader(self) -> None: class DataFrameWriter(OptionUtils): - def __init__(self, plan: "LogicalPlan", session: "SparkSession"): + def __init__( + self, + plan: "LogicalPlan", + session: "SparkSession", + callback: Optional[Callable[['QueryExecution'], None]] = None, + ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session self._write: "WriteOperation" = WriteOperation(self._df) + self._callback = callback if callback is not None else lambda _: None def mode(self, saveMode: Optional[str]) -> "DataFrameWriter": # At the JVM side, the default value of mode is already set to "error". 
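The callback threaded through the writer here is what lets a write action report its query execution back to the originating DataFrame. A hedged usage sketch mirroring the write test added later in this series (the output path is illustrative):

    # DataFrame.write installs a callback, so a save() also populates queryExecution.
    df = spark.range(100).repartition(10).groupBy("id").count()
    df.write.save("/tmp/df_debug_example", format="json", mode="overwrite")

    qe = df.queryExecution   # set by the writer callback after the save completes
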
@@ -649,9 +656,10 @@ def save( if format is not None: self.format(format) self._write.path = path - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) save.__doc__ = PySparkDataFrameWriter.save.__doc__ @@ -660,9 +668,10 @@ def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None: self.mode("overwrite" if overwrite else "append") self._write.table_name = tableName self._write.table_save_method = "insert_into" - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__ @@ -681,9 +690,10 @@ def saveAsTable( self.format(format) self._write.table_name = name self._write.table_save_method = "save_as_table" - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__ @@ -845,11 +855,18 @@ def jdbc( class DataFrameWriterV2(OptionUtils): - def __init__(self, plan: "LogicalPlan", session: "SparkSession", table: str): + def __init__( + self, + plan: "LogicalPlan", + session: "SparkSession", + table: str, + callback: Optional[Callable[['QueryExecution'], None]] = None, + ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session self._table_name: str = table self._write: "WriteOperationV2" = WriteOperationV2(self._df, self._table_name) + self._callback = callback if callback is not None else lambda _: None def using(self, provider: str) -> "DataFrameWriterV2": self._write.provider = provider @@ -884,50 +901,56 @@ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFram def create(self) -> None: self._write.mode = "create" - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) create.__doc__ = PySparkDataFrameWriterV2.create.__doc__ def replace(self) -> None: self._write.mode = "replace" - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__ def createOrReplace(self) -> None: self._write.mode = "create_or_replace" - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__ def append(self) -> None: self._write.mode = "append" - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) append.__doc__ = PySparkDataFrameWriterV2.append.__doc__ def overwrite(self, condition: "ColumnOrName") -> None: self._write.mode = "overwrite" self._write.overwrite_condition = F._to_col(condition) - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__ def overwritePartitions(self) -> None: self._write.mode = 
"overwrite_partitions" - self._spark.client.execute_command( + _, _, qe = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(qe) overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__ diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index f359ab829483..20e2a9939abb 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -720,9 +720,12 @@ def sql( _views.append(SubqueryAlias(df._plan, name)) cmd = SQL(sqlQuery, _args, _named_args, _views) - data, properties = self.client.execute_command(cmd.command(self._client)) + data, properties, qe = self.client.execute_command(cmd.command(self._client)) if "sql_command_result" in properties: - return DataFrame(CachedRelation(properties["sql_command_result"]), self) + df = DataFrame(CachedRelation(properties["sql_command_result"]), self) + # A command result contains the execution. + df._query_execution = qe + return df else: return DataFrame(cmd, self) diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py index 01c47fa9ad3c..9eed7e92e02c 100644 --- a/python/pyspark/sql/tests/connect/test_df_debug.py +++ b/python/pyspark/sql/tests/connect/test_df_debug.py @@ -42,6 +42,13 @@ def test_df_quey_execution_empty_before_execution(self): qe = df.queryExecution self.assertIsNone(qe, "The query execution must be None before the action is executed") + def test_df_query_execution_with_writes(self): + df: DataFrame = self.connect.range(100).repation(10).groupBy("id").count() + df.write.save("/tmp/test_df_query_execution_with_writes", format="json") + + qe = df.queryExecution + self.assertIsNone(qe, "The query execution must be None after the write action is executed") + @unittest.skipIf(not have_graphviz, graphviz_requirement_message) def test_df_query_execution_metrics_to_dot(self): df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() From f08f598ae6a4e320ed2d1b3a3e99b6699a573592 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 18 Jun 2024 16:09:24 +0200 Subject: [PATCH 06/16] fixing unpacking bugs --- python/pyspark/sql/connect/client/core.py | 4 ++-- python/pyspark/sql/connect/dataframe.py | 18 ++++++++++++------ python/pyspark/sql/connect/readwriter.py | 4 ++-- python/pyspark/sql/connect/streaming/query.py | 4 ++-- .../sql/connect/streaming/readwriter.py | 2 +- 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index c4ed1a736a4d..2cc4b908806a 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -847,7 +847,7 @@ def _resources(self) -> Dict[str, ResourceInformation]: logger.info("Fetching the resources") cmd = pb2.Command() cmd.get_resources_command.SetInParent() - (_, properties) = self.execute_command(cmd) + (_, properties, _) = self.execute_command(cmd) resources = properties["get_resources_command_result"] return resources @@ -1815,6 +1815,6 @@ def _create_profile(self, profile: pb2.ResourceProfile) -> int: logger.info("Creating the ResourceProfile") cmd = pb2.Command() cmd.create_resource_profile_command.profile.CopyFrom(profile) - (_, properties) = self.execute_command(cmd) + (_, properties, _) = self.execute_command(cmd) profile_id = properties["create_resource_profile_command_result"] return profile_id diff --git 
a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 1ee6e9ab1054..1034f86ef97d 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1980,25 +1980,29 @@ def createTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=False, replace=False ).command(session=self._session.client) - self._session.client.execute_command(command, self._plan.observations) + _, _, qe = self._session.client.execute_command(command, self._plan.observations) + self._query_execution = qe def createOrReplaceTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=False, replace=True ).command(session=self._session.client) - self._session.client.execute_command(command, self._plan.observations) + _, _, qe = self._session.client.execute_command(command, self._plan.observations) + self._query_execution = qe def createGlobalTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=False ).command(session=self._session.client) - self._session.client.execute_command(command, self._plan.observations) + _, _, qe = self._session.client.execute_command(command, self._plan.observations) + self._query_execution = qe def createOrReplaceGlobalTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=True ).command(session=self._session.client) - self._session.client.execute_command(command, self._plan.observations) + _, _, qe = self._session.client.execute_command(command, self._plan.observations) + self._query_execution = qe def cache(self) -> ParentDataFrame: return self.persist() @@ -2183,7 +2187,8 @@ def offset(self, n: int) -> ParentDataFrame: def checkpoint(self, eager: bool = True) -> "DataFrame": cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager) - _, properties = self._session.client.execute_command(cmd.command(self._session.client)) + _, properties, qe = self._session.client.execute_command(cmd.command(self._session.client)) + self._query_execution = qe assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) @@ -2191,7 +2196,8 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": def localCheckpoint(self, eager: bool = True) -> "DataFrame": cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager) - _, properties = self._session.client.execute_command(cmd.command(self._session.client)) + _, properties, qe = self._session.client.execute_command(cmd.command(self._session.client)) + self._query_execution = qe assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index f49e5795902f..c6c2039f910f 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -491,7 +491,7 @@ def __init__( self, plan: "LogicalPlan", session: "SparkSession", - callback: Optional[Callable[['QueryExecution'], None]] = None, + callback: Optional[Callable[["QueryExecution"], None]] = None, ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session @@ -860,7 +860,7 @@ def __init__( plan: "LogicalPlan", session: "SparkSession", table: str, - callback: 
Optional[Callable[['QueryExecution'], None]] = None, + callback: Optional[Callable[["QueryExecution"], None]] = None, ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index cc1e2e220188..19ddb6e87eae 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -182,7 +182,7 @@ def _execute_streaming_query_cmd( cmd.query_id.run_id = self._run_id exec_cmd = pb2.Command() exec_cmd.streaming_query_command.CopyFrom(cmd) - (_, properties) = self._session.client.execute_command(exec_cmd) + (_, properties, _) = self._session.client.execute_command(exec_cmd) return cast(pb2.StreamingQueryCommandResult, properties["streaming_query_command_result"]) @@ -261,7 +261,7 @@ def _execute_streaming_query_manager_cmd( ) -> pb2.StreamingQueryManagerCommandResult: exec_cmd = pb2.Command() exec_cmd.streaming_query_manager_command.CopyFrom(cmd) - (_, properties) = self._session.client.execute_command(exec_cmd) + (_, properties, _) = self._session.client.execute_command(exec_cmd) return cast( pb2.StreamingQueryManagerCommandResult, properties["streaming_query_manager_command_result"], diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index b5bb7f2a0912..9b11bf328b85 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -601,7 +601,7 @@ def _start_internal( self._write_proto.table_name = tableName cmd = self._write_stream.command(self._session.client) - (_, properties) = self._session.client.execute_command(cmd) + (_, properties, _) = self._session.client.execute_command(cmd) start_result = cast( pb2.WriteStreamOperationStartResult, properties["write_stream_operation_start_result"] From 882c12d016738cfb659f121cf7825631f9d0f148 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 18 Jun 2024 22:47:03 +0200 Subject: [PATCH 07/16] adding text support --- python/pyspark/sql/connect/readwriter.py | 1 + python/pyspark/sql/metrics.py | 49 ++++++++++++++++++- .../sql/tests/connect/test_df_debug.py | 14 ++++-- 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index c6c2039f910f..feae2ddac041 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -496,6 +496,7 @@ def __init__( self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session self._write: "WriteOperation" = WriteOperation(self._df) + self._callback = callback if callback is not None else lambda _: None def mode(self, saveMode: Optional[str]) -> "DataFrameWriter": diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index 004e07e4fb75..5a82707a9fa2 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -106,6 +106,41 @@ class Node: metrics: List[MetricValue] = dataclasses.field(default_factory=list) children: List[int] = dataclasses.field(default_factory=list) + def text(self, current: "Node", graph: Dict[int, "Node"], prefix="") -> str: + """ + Converts the current node and its children into a textual representation. This is used + to provide a usable output for the command line or other text-based interfaces. However, + it is recommended to use the Graphviz representation for a more visual representation. 
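To illustrate, a hedged sketch of the kind of text this produces (node names, metric values, and indentation are illustrative; the exact shape follows the formatting code below):

    print(df.queryExecution.metrics.toText())
    # - HashAggregate(numOutputRows: 100 (sum))
    #   - Exchange(numOutputRows: 100 (sum))
    #     - Range(numOutputRows: 100 (sum))
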
+ + Parameters + ---------- + current: Node + Current node in the graph. + graph: Dict[int, Node] + The full graph of all nodes in the executed plan. + prefix: str + String prefix used for generating the output buffer. + + Returns + ------- + The full string representation of the current node as root. + """ + base_metrics = set(["numPartitions", "peakMemory", "numOutputRows", "spillSize"]) + + # Format the metrics of this node: + metric_buffer = [] + for m in current.metrics: + if m.name in base_metrics: + metric_buffer.append(f"{m.name}: {m.value} ({m.metric_type})") + + buffer = f"{prefix}- {current.name}({','.join(metric_buffer)})\n" + for i, child in enumerate(current.children): + c = graph[child] + new_prefix = prefix + " " if i == len(c.children) - 1 else prefix + if current.id != c.id: + buffer += self.text(c, graph, new_prefix) + return buffer + def __init__(self, metrics: List[PlanMetrics]): # Sort the input list self._metrics = sorted(metrics, key=lambda x: x._parent_id, reverse=False) @@ -137,7 +172,7 @@ def extract_graph(self) -> Tuple[int, Dict[int, "CollectedMetrics.Node"]]: all_nodes[m.parent_plan_id].children.append(m.plan_id) # Next step is to find all the root nodes. Root nodes are never used in children. - # So we start will all node ids as candidates. + # So we start with all node ids as candidates. candidates = set(all_nodes.keys()) for k, v in all_nodes.items(): for c in v.children: @@ -147,6 +182,18 @@ def extract_graph(self) -> Tuple[int, Dict[int, "CollectedMetrics.Node"]]: assert len(candidates) == 1, f"Expected 1 root node, found {len(candidates)}" return candidates.pop(), all_nodes + def toText(self) -> str: + """ + Converts the execution graph from a graph into a textual representation + that can be read at the command line for example. + + Returns + ------- + A string representation of the collected metrics. + """ + root, graph = self.extract_graph() + return self.text(graph[root], graph) + def toDot(self, filename: Optional[str] = None, out_format: str = "png") -> "graphviz.Digraph": """ Converts the collected metrics into a dot representation. 
Since the graphviz Digraph diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py index 9eed7e92e02c..980424d9048b 100644 --- a/python/pyspark/sql/tests/connect/test_df_debug.py +++ b/python/pyspark/sql/tests/connect/test_df_debug.py @@ -43,11 +43,17 @@ def test_df_quey_execution_empty_before_execution(self): self.assertIsNone(qe, "The query execution must be None before the action is executed") def test_df_query_execution_with_writes(self): - df: DataFrame = self.connect.range(100).repation(10).groupBy("id").count() - df.write.save("/tmp/test_df_query_execution_with_writes", format="json") - + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df.write.save("/tmp/test_df_query_execution_with_writes", format="json", mode="overwrite") qe = df.queryExecution - self.assertIsNone(qe, "The query execution must be None after the write action is executed") + self.assertIsNotNone( + qe, "The query execution must be None after the write action is executed" + ) + + def test_query_execution_text_format(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df.collect() + self.assertIn("HashAggregate", df.queryExecution.metrics.toText()) @unittest.skipIf(not have_graphviz, graphviz_requirement_message) def test_df_query_execution_metrics_to_dot(self): From f25a9e6e3b6bc8e9727bf381b15775367e59e3d4 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 18 Jun 2024 23:11:30 +0200 Subject: [PATCH 08/16] adding text support --- python/pyspark/sql/connect/client/core.py | 7 +++++-- python/pyspark/sql/connect/dataframe.py | 4 +++- python/pyspark/sql/tests/connect/test_df_debug.py | 5 +++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 2cc4b908806a..53a04f5a109d 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -888,7 +888,9 @@ def to_table( assert table is not None return table, schema, qe - def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd.DataFrame": + def to_pandas( + self, plan: pb2.Plan, observations: Dict[str, Observation] + ) -> Tuple["pd.DataFrame", "QueryExecution"]: """ Return given plan as a pandas DataFrame. 
""" @@ -903,6 +905,7 @@ def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd req, observations, self_destruct=self_destruct ) assert table is not None + qe = QueryExecution(metrics, observed_metrics) schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True) assert schema is not None and isinstance(schema, StructType) @@ -969,7 +972,7 @@ def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd pdf.attrs["metrics"] = metrics if len(observed_metrics) > 0: pdf.attrs["observed_metrics"] = observed_metrics - return pdf + return pdf, qe def _proto_to_string(self, p: google.protobuf.message.Message) -> str: """ diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 1034f86ef97d..0bf85d6087fb 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1854,7 +1854,9 @@ def toArrow(self) -> "pa.Table": def toPandas(self) -> "PandasDataFrameLike": query = self._plan.to_proto(self._session.client) - return self._session.client.to_pandas(query, self._plan.observations) + pdf, qe = self._session.client.to_pandas(query, self._plan.observations) + self._query_execution = qe + return pdf @property def schema(self) -> StructType: diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py index 980424d9048b..18f54fda1ac2 100644 --- a/python/pyspark/sql/tests/connect/test_df_debug.py +++ b/python/pyspark/sql/tests/connect/test_df_debug.py @@ -55,6 +55,11 @@ def test_query_execution_text_format(self): df.collect() self.assertIn("HashAggregate", df.queryExecution.metrics.toText()) + # Different execution mode. + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df.toPandas() + self.assertIn("HashAggregate", df.queryExecution.metrics.toText()) + @unittest.skipIf(not have_graphviz, graphviz_requirement_message) def test_df_query_execution_metrics_to_dot(self): df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() From 4db1c0fffa92398b3229b6c2369f3485d7d376a4 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 18 Jun 2024 23:13:43 +0200 Subject: [PATCH 09/16] Update python/pyspark/errors/error-conditions.json Co-authored-by: allisonwang-db --- python/pyspark/errors/error-conditions.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 23b0c4cd4f9b..dd70e814b1ea 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -151,7 +151,7 @@ }, "CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF": { "message": [ - "Calling property or member is not supported in PySpark Classic, please use Spark Connect instead." + "Calling property or member '' is not supported in PySpark Classic, please use Spark Connect instead." 
] }, "COLLATION_INVALID_PROVIDER" : { From 1588d4316cada5995a867fe35b4ffe098629f341 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Wed, 19 Jun 2024 06:23:54 +0200 Subject: [PATCH 10/16] fixing lint --- python/pyspark/sql/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index 5a82707a9fa2..b2c09ba7bf08 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -106,7 +106,7 @@ class Node: metrics: List[MetricValue] = dataclasses.field(default_factory=list) children: List[int] = dataclasses.field(default_factory=list) - def text(self, current: "Node", graph: Dict[int, "Node"], prefix="") -> str: + def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str="") -> str: """ Converts the current node and its children into a textual representation. This is used to provide a usable output for the command line or other text-based interfaces. However, From 84668cb7650bd36ac844563e7b3eed25c858a9cf Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Wed, 19 Jun 2024 14:44:57 +0200 Subject: [PATCH 11/16] fixing lint --- python/pyspark/sql/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index b2c09ba7bf08..d63218eb28b7 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -106,7 +106,7 @@ class Node: metrics: List[MetricValue] = dataclasses.field(default_factory=list) children: List[int] = dataclasses.field(default_factory=list) - def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str="") -> str: + def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str = "") -> str: """ Converts the current node and its children into a textual representation. This is used to provide a usable output for the command line or other text-based interfaces. 
However, From 08b453eac275c78a2caec9a4ad77d51f93d14acf Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 21 Jun 2024 17:19:05 +0200 Subject: [PATCH 12/16] renaming to ExecutionInfo; --- python/pyspark/sql/classic/dataframe.py | 4 +- python/pyspark/sql/connect/client/core.py | 22 ++++----- python/pyspark/sql/connect/dataframe.py | 48 ++++++++++--------- python/pyspark/sql/connect/readwriter.py | 42 ++++++++-------- python/pyspark/sql/connect/session.py | 4 +- python/pyspark/sql/dataframe.py | 4 +- python/pyspark/sql/metrics.py | 2 +- .../sql/tests/connect/test_df_debug.py | 20 ++++---- python/pyspark/sql/tests/test_dataframe.py | 2 +- 9 files changed, 75 insertions(+), 73 deletions(-) diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index d4541386e6ad..1bedd624603e 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -94,7 +94,7 @@ from pyspark.sql.session import SparkSession from pyspark.sql.group import GroupedData from pyspark.sql.observation import Observation - from pyspark.sql.metrics import QueryExecution + from pyspark.sql.metrics import ExecutionInfo class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin): @@ -1837,7 +1837,7 @@ def toPandas(self) -> "PandasDataFrameLike": return PandasConversionMixin.toPandas(self) @property - def queryExecution(self) -> Optional["QueryExecution"]: + def executionInfo(self) -> Optional["ExecutionInfo"]: raise PySparkValueError( error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF", message_parameters={"member": "queryExecution"}, diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 89505e295fd6..12ba6573bfa8 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -61,7 +61,7 @@ from pyspark.loose_version import LooseVersion from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation -from pyspark.sql.metrics import MetricValue, PlanMetrics, QueryExecution, ObservedMetrics +from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics from pyspark.sql.connect.client.artifact import ArtifactManager from pyspark.sql.connect.client.logging import logger from pyspark.sql.connect.profiler import ConnectProfilerCollector @@ -874,7 +874,7 @@ def to_table_as_iterator( def to_table( self, plan: pb2.Plan, observations: Dict[str, Observation] - ) -> Tuple["pa.Table", Optional[StructType], QueryExecution]: + ) -> Tuple["pa.Table", Optional[StructType], ExecutionInfo]: """ Return given plan as a PyArrow Table. """ @@ -884,13 +884,13 @@ def to_table( table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations) # Create a query execution object. - qe = QueryExecution(metrics, observed_metrics) + ei = ExecutionInfo(metrics, observed_metrics) assert table is not None - return table, schema, qe + return table, schema, ei def to_pandas( self, plan: pb2.Plan, observations: Dict[str, Observation] - ) -> Tuple["pd.DataFrame", "QueryExecution"]: + ) -> Tuple["pd.DataFrame", "ExecutionInfo"]: """ Return given plan as a pandas DataFrame. 
""" @@ -905,7 +905,7 @@ def to_pandas( req, observations, self_destruct=self_destruct ) assert table is not None - qe = QueryExecution(metrics, observed_metrics) + ei = ExecutionInfo(metrics, observed_metrics) schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True) assert schema is not None and isinstance(schema, StructType) @@ -972,7 +972,7 @@ def to_pandas( pdf.attrs["metrics"] = metrics if len(observed_metrics) > 0: pdf.attrs["observed_metrics"] = observed_metrics - return pdf, qe + return pdf, ei def _proto_to_string(self, p: google.protobuf.message.Message) -> str: """ @@ -1016,7 +1016,7 @@ def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str: def execute_command( self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None - ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], QueryExecution]: + ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], ExecutionInfo]: """ Execute given command. """ @@ -1029,11 +1029,11 @@ def execute_command( req, observations or {} ) # Create a query execution object. - qe = QueryExecution(metrics, observed_metrics) + ei = ExecutionInfo(metrics, observed_metrics) if data is not None: - return (data.to_pandas(), properties, qe) + return (data.to_pandas(), properties, ei) else: - return (None, properties, qe) + return (None, properties, ei) def execute_command_as_iterator( self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index ba34c9bf2e37..1aa8fc00cfcc 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -101,7 +101,7 @@ from pyspark.sql.connect.observation import Observation from pyspark.sql.connect.session import SparkSession from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame - from pyspark.sql.metrics import QueryExecution + from pyspark.sql.metrics import ExecutionInfo class DataFrame(ParentDataFrame): @@ -138,7 +138,7 @@ def __init__( # by __repr__ and _repr_html_ while eager evaluation opens. 
self._support_repr_html = False self._cached_schema: Optional[StructType] = None - self._query_execution: Optional["QueryExecution"] = None + self._execution_info: Optional["ExecutionInfo"] = None def __reduce__(self) -> Tuple: """ @@ -208,8 +208,8 @@ def _repr_html_(self) -> Optional[str]: @property def write(self) -> "DataFrameWriter": - def cb(qe: "QueryExecution") -> None: - self._query_execution = qe + def cb(qe: "ExecutionInfo") -> None: + self._execution_info = qe return DataFrameWriter(self._plan, self._session, cb) @@ -1844,7 +1844,7 @@ def collect(self) -> List[Row]: def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: query = self._plan.to_proto(self._session.client) - table, schema, self._query_execution = self._session.client.to_table( + table, schema, self._execution_info = self._session.client.to_table( query, self._plan.observations ) assert table is not None @@ -1857,8 +1857,8 @@ def toArrow(self) -> "pa.Table": def toPandas(self) -> "PandasDataFrameLike": query = self._plan.to_proto(self._session.client) - pdf, qe = self._session.client.to_pandas(query, self._plan.observations) - self._query_execution = qe + pdf, ei = self._session.client.to_pandas(query, self._plan.observations) + self._execution_info = ei return pdf @property @@ -1985,29 +1985,29 @@ def createTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=False, replace=False ).command(session=self._session.client) - _, _, qe = self._session.client.execute_command(command, self._plan.observations) - self._query_execution = qe + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def createOrReplaceTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=False, replace=True ).command(session=self._session.client) - _, _, qe = self._session.client.execute_command(command, self._plan.observations) - self._query_execution = qe + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def createGlobalTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=False ).command(session=self._session.client) - _, _, qe = self._session.client.execute_command(command, self._plan.observations) - self._query_execution = qe + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def createOrReplaceGlobalTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=True ).command(session=self._session.client) - _, _, qe = self._session.client.execute_command(command, self._plan.observations) - self._query_execution = qe + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def cache(self) -> ParentDataFrame: return self.persist() @@ -2182,8 +2182,8 @@ def semanticHash(self) -> int: ) def writeTo(self, table: str) -> "DataFrameWriterV2": - def cb(qe: "QueryExecution") -> None: - self._query_execution = qe + def cb(ei: "ExecutionInfo") -> None: + self._execution_info = ei return DataFrameWriterV2(self._plan, self._session, table, cb) @@ -2192,8 +2192,9 @@ def offset(self, n: int) -> ParentDataFrame: def checkpoint(self, eager: bool = True) -> "DataFrame": cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager) - _, properties, qe = 
self._session.client.execute_command(cmd.command(self._session.client)) - self._query_execution = qe + _, properties, self._execution_info = self._session.client.execute_command( + cmd.command(self._session.client) + ) assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) @@ -2201,8 +2202,9 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": def localCheckpoint(self, eager: bool = True) -> "DataFrame": cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager) - _, properties, qe = self._session.client.execute_command(cmd.command(self._session.client)) - self._query_execution = qe + _, properties, self._execution_info = self._session.client.execute_command( + cmd.command(self._session.client) + ) assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) @@ -2224,8 +2226,8 @@ def rdd(self) -> "RDD[Row]": ) @property - def queryExecution(self) -> Optional["QueryExecution"]: - return self._query_execution + def executionInfo(self) -> Optional["ExecutionInfo"]: + return self._execution_info class DataFrameNaFunctions(ParentDataFrameNaFunctions): diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index feae2ddac041..de62cf65b01e 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -37,7 +37,7 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType from pyspark.sql.connect.session import SparkSession - from pyspark.sql.metrics import QueryExecution + from pyspark.sql.metrics import ExecutionInfo __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -491,7 +491,7 @@ def __init__( self, plan: "LogicalPlan", session: "SparkSession", - callback: Optional[Callable[["QueryExecution"], None]] = None, + callback: Optional[Callable[["ExecutionInfo"], None]] = None, ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session @@ -657,10 +657,10 @@ def save( if format is not None: self.format(format) self._write.path = path - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) save.__doc__ = PySparkDataFrameWriter.save.__doc__ @@ -669,10 +669,10 @@ def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None: self.mode("overwrite" if overwrite else "append") self._write.table_name = tableName self._write.table_save_method = "insert_into" - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__ @@ -691,10 +691,10 @@ def saveAsTable( self.format(format) self._write.table_name = name self._write.table_save_method = "save_as_table" - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__ @@ -861,7 +861,7 @@ def __init__( plan: "LogicalPlan", session: "SparkSession", table: str, - callback: 
Optional[Callable[["QueryExecution"], None]] = None, + callback: Optional[Callable[["ExecutionInfo"], None]] = None, ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session @@ -902,56 +902,56 @@ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFram def create(self) -> None: self._write.mode = "create" - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) create.__doc__ = PySparkDataFrameWriterV2.create.__doc__ def replace(self) -> None: self._write.mode = "replace" - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__ def createOrReplace(self) -> None: self._write.mode = "create_or_replace" - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__ def append(self) -> None: self._write.mode = "append" - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) append.__doc__ = PySparkDataFrameWriterV2.append.__doc__ def overwrite(self, condition: "ColumnOrName") -> None: self._write.mode = "overwrite" self._write.overwrite_condition = F._to_col(condition) - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__ def overwritePartitions(self) -> None: self._write.mode = "overwrite_partitions" - _, _, qe = self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) - self._callback(qe) + self._callback(ei) overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__ diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 20e2a9939abb..8e277b3fc63a 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -720,11 +720,11 @@ def sql( _views.append(SubqueryAlias(df._plan, name)) cmd = SQL(sqlQuery, _args, _named_args, _views) - data, properties, qe = self.client.execute_command(cmd.command(self._client)) + data, properties, ei = self.client.execute_command(cmd.command(self._client)) if "sql_command_result" in properties: df = DataFrame(CachedRelation(properties["sql_command_result"]), self) # A command result contains the execution. 
- df._query_execution = qe + df._execution_info = ei return df else: return DataFrame(cmd, self) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e70a690d0cbb..aee602097728 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -64,7 +64,7 @@ ArrowMapIterFunction, DataFrameLike as PandasDataFrameLike, ) - from pyspark.sql.metrics import QueryExecution + from pyspark.sql.metrics import ExecutionInfo __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -6283,7 +6283,7 @@ def toPandas(self) -> "PandasDataFrameLike": ... @property - def queryExecution(self) -> Optional["QueryExecution"]: + def executionInfo(self) -> Optional["ExecutionInfo"]: """ Returns a QueryExecution object after the query was executed. diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index d63218eb28b7..a8cd853a4f2b 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -267,7 +267,7 @@ def toDot(self, filename: Optional[str] = None, out_format: str = "png") -> "gra ) -class QueryExecution: +class ExecutionInfo: """The query execution class allows users to inspect the query execution of this particular data frame. This value is only set in the data frame if it was executed.""" diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py index 18f54fda1ac2..8a4ec68fda84 100644 --- a/python/pyspark/sql/tests/connect/test_df_debug.py +++ b/python/pyspark/sql/tests/connect/test_df_debug.py @@ -32,41 +32,41 @@ class SparkConnectDataFrameDebug(SparkConnectSQLTestCase): def test_df_debug_basics(self): df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() x = df.collect() # noqa: F841 - qe = df.queryExecution + ei = df.executionInfo - root, graph = qe.metrics.extract_graph() + root, graph = ei.metrics.extract_graph() self.assertIn(root, graph, "The root must be rooted in the graph") def test_df_quey_execution_empty_before_execution(self): df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() - qe = df.queryExecution - self.assertIsNone(qe, "The query execution must be None before the action is executed") + ei = df.executionInfo + self.assertIsNone(ei, "The query execution must be None before the action is executed") def test_df_query_execution_with_writes(self): df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() df.write.save("/tmp/test_df_query_execution_with_writes", format="json", mode="overwrite") - qe = df.queryExecution + ei = df.executionInfo self.assertIsNotNone( - qe, "The query execution must be None after the write action is executed" + ei, "The query execution must be None after the write action is executed" ) def test_query_execution_text_format(self): df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() df.collect() - self.assertIn("HashAggregate", df.queryExecution.metrics.toText()) + self.assertIn("HashAggregate", df.executionInfo.metrics.toText()) # Different execution mode. 
df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() df.toPandas() - self.assertIn("HashAggregate", df.queryExecution.metrics.toText()) + self.assertIn("HashAggregate", df.executionInfo.metrics.toText()) @unittest.skipIf(not have_graphviz, graphviz_requirement_message) def test_df_query_execution_metrics_to_dot(self): df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() x = df.collect() # noqa: F841 - qe = df.queryExecution + ei = df.executionInfo - dot = qe.metrics.toDot() + dot = ei.metrics.toDot() source = dot.source self.assertIsNotNone(dot, "The dot representation must not be None") self.assertGreater(len(source), 0, "The dot representation must not be empty") diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 1a54d850bb80..8cf2b7919f33 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -852,7 +852,7 @@ def test_checkpoint_dataframe(self): class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase): def test_query_execution_unsupported_in_classic(self): with self.assertRaises(PySparkValueError) as pe: - self.spark.range(1).queryExecution + self.spark.range(1).executionInfo self.check_error( exception=pe.exception, From 1eb6281e60b96ef5da109023cdad9775ddc28545 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 24 Jun 2024 09:40:55 +0200 Subject: [PATCH 13/16] updating doc --- python/pyspark/sql/dataframe.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index aee602097728..586c4ddb963e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -6291,6 +6291,9 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: query execution after the successful execution. Accessing this member before the query execution will return None. + If the same DataFrame is executed multiple times, the execution info will be + overwritten by the latest operation. + .. 
versionadded:: 4.0.0 Returns From 8c68b171c1b6bab843a4ab6df672085894ccbabd Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 24 Jun 2024 22:16:25 +0200 Subject: [PATCH 14/16] adjusting format --- python/pyspark/sql/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index a8cd853a4f2b..fcff57aba1ca 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -133,7 +133,7 @@ def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str = "") -> s if m.name in base_metrics: metric_buffer.append(f"{m.name}: {m.value} ({m.metric_type})") - buffer = f"{prefix}- {current.name}({','.join(metric_buffer)})\n" + buffer = f"{prefix}+- {current.name}({','.join(metric_buffer)})\n" for i, child in enumerate(current.children): c = graph[child] new_prefix = prefix + " " if i == len(c.children) - 1 else prefix From 36ca3960c71c7f25adaecf429f7b8b4d70ab3187 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 25 Jun 2024 08:22:26 +0200 Subject: [PATCH 15/16] Update python/pyspark/sql/dataframe.py Co-authored-by: Hyukjin Kwon --- python/pyspark/sql/dataframe.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 586c4ddb963e..625678588bf9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -6299,6 +6299,11 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: Returns ------- An instance of QueryExecution or None when the value is not set yet. + + Notes + ----- + This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws + an exception. """ ... From 0412ad5c5a3ae96e2fb73db441780e613751174c Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 25 Jun 2024 08:46:36 +0200 Subject: [PATCH 16/16] comments --- python/docs/source/getting_started/install.rst | 1 + python/docs/source/reference/pyspark.sql/dataframe.rst | 1 + python/pyspark/sql/metrics.py | 5 +++-- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 21926ae295bf..6cc68cd46b11 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -210,6 +210,7 @@ Package Supported version Note `grpcio` >=1.62.0 Required for Spark Connect `grpcio-status` >=1.62.0 Required for Spark Connect `googleapis-common-protos` >=1.56.4 Required for Spark Connect +`graphviz` >=0.20 Optional for Spark Connect ========================== ================= ========================== Spark SQL diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst index ec39b645b140..d0196baa7a05 100644 --- a/python/docs/source/reference/pyspark.sql/dataframe.rst +++ b/python/docs/source/reference/pyspark.sql/dataframe.rst @@ -55,6 +55,7 @@ DataFrame DataFrame.dropna DataFrame.dtypes DataFrame.exceptAll + DataFrame.executionInfo DataFrame.explain DataFrame.fillna DataFrame.filter diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index fcff57aba1ca..666458295201 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -116,8 +116,9 @@ def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str = "") -> s ---------- current: Node Current node in the graph. - graph: Dict[int, Node] - The full graph of all nodes in the executed plan. 
+        graph: dict
+            A dictionary representing the full graph, mapping each node ID (int) to the
+            node itself. Each node is an instance of :class:`CollectedMetrics.Node`.
         prefix: str
             String prefix used for generating the output buffer.
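
For reference, a minimal usage sketch, separate from the patches themselves, of how the renamed property is expected to behave on a Spark Connect DataFrame. The connection string, data, and variable names are illustrative assumptions, not part of the change:

    from pyspark.sql import SparkSession

    # Assumption: a Spark Connect server is reachable locally on the default port.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df = spark.range(100).repartition(10).groupBy("id").count()

    # No action has run yet, so no execution info has been attached.
    assert df.executionInfo is None

    df.collect()

    # After the action, the client stores the collected plan metrics on the DataFrame.
    info = df.executionInfo
    if info is not None:
        # Text rendering of the executed plan; with the adjusted format each
        # node line is prefixed with "+-".
        print(info.metrics.toText())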
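
Write commands report their execution through the writer callback wired up above, so a save populates the same property. A sketch continuing from the previous snippet; the output path and format are illustrative:

    # The output location below is only an example.
    df.write.mode("overwrite").format("json").save("/tmp/execution_info_example")

    info = df.executionInfo
    if info is not None:
        # Metrics collected for the write command are now attached as well.
        print(info.metrics.toText())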
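
The graphviz package listed as an optional Spark Connect dependency is only needed for the DOT rendering path. A hedged sketch, assuming graphviz is installed and reusing the executed DataFrame from the first snippet; the file name is an arbitrary example:

    try:
        import graphviz  # noqa: F401
        have_graphviz = True
    except ImportError:
        have_graphviz = False

    if have_graphviz and df.executionInfo is not None:
        dot = df.executionInfo.metrics.toDot()
        # dot.source holds the DOT text of the executed plan's metric graph.
        print(dot.source)
        # toDot() also accepts a file name and output format, for example:
        #   df.executionInfo.metrics.toDot("plan_metrics", out_format="png")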
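
Outside of Spark Connect the property is documented as unsupported. A sketch of that failure mode in a separate, classic (non-Connect) session with a local master:

    from pyspark.errors import PySparkValueError
    from pyspark.sql import SparkSession

    classic = SparkSession.builder.master("local[1]").getOrCreate()

    try:
        classic.range(1).executionInfo
    except PySparkValueError as e:
        # Raised with the CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF error condition.
        print(e)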