adding writer support
grundprinzip committed Jun 18, 2024
commit a310cda1954915e563c035dc6b69e52dd39f5b3f
12 changes: 8 additions & 4 deletions python/pyspark/sql/connect/client/core.py
@@ -1013,7 +1013,7 @@ def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:

def execute_command(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], QueryExecution]:
"""
Execute given command.
"""
@@ -1022,11 +1022,15 @@ def execute_command(
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
data, _, _, _, properties = self._execute_and_fetch(req, observations or {})
data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
req, observations or {}
)
# Create a query execution object.
qe = QueryExecution(metrics, observed_metrics)
if data is not None:
return (data.to_pandas(), properties)
return (data.to_pandas(), properties, qe)
else:
return (None, properties)
return (None, properties, qe)

def execute_command_as_iterator(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
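For orientation, a minimal caller-side sketch (not part of this diff; client and command are assumed to be a SparkConnectClient and a pb2.Command constructed as elsewhere in this module) showing how the widened return value is consumed:

# Hypothetical caller of execute_command after this change.
data, properties, qe = client.execute_command(command)
if data is not None:
    print(data.head())  # pandas DataFrame with the command's result rows, if any
print(qe)  # QueryExecution built from the metrics and observed metrics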
10 changes: 8 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -208,7 +208,10 @@ def _repr_html_(self) -> Optional[str]:

@property
def write(self) -> "DataFrameWriter":
return DataFrameWriter(self._plan, self._session)
def cb(qe: "QueryExecution") -> None:
self._query_execution = qe

return DataFrameWriter(self._plan, self._session, cb)
Contributor @WweiL commented on Jun 21, 2024:
Looks like writeStream is not overridden here, so I imagine streaming queries are not supported yet.

In streaming, a query can have multiple data frames. What we do in Scala is access it with query.explain(), which uses this lastExecution:

def lastExecution: IncrementalExecution = getLatestExecutionContext().executionPlan

That is, as its name suggests, the QueryExecution (an IncrementalExecution) of the last execution.

We could also add a similar mechanism to the StreamingQuery object. This sounds like an interesting follow-up that I'm interested in.

Contributor Author replied:
Yes, we should look at streaming as a follow-up.
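A purely hypothetical Python sketch of such a follow-up (none of these names exist today; lastExecution and _last_query_execution are invented here to mirror the Scala accessor quoted above):

from typing import Optional

class StreamingQuery:
    # Hypothetical addition: surface the QueryExecution of the most recently
    # completed micro-batch, analogous to Scala's lastExecution.
    @property
    def lastExecution(self) -> Optional["QueryExecution"]:
        return self._last_query_execution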


@functools.cache
def isEmpty(self) -> bool:
@@ -2170,7 +2173,10 @@ def semanticHash(self) -> int:
)

def writeTo(self, table: str) -> "DataFrameWriterV2":
return DataFrameWriterV2(self._plan, self._session, table)
def cb(qe: "QueryExecution") -> None:
self._query_execution = qe

return DataFrameWriterV2(self._plan, self._session, table, cb)

def offset(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)
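A short usage sketch (assuming a Spark Connect session bound to spark, and the queryExecution property that the tests below read) of what this callback wiring enables on the DataFrame:

df = spark.range(100).groupBy("id").count()
df.write.save("/tmp/out", format="json")  # the writer invokes cb, storing the QueryExecution
qe = df.queryExecution  # now populated with the execution details of the write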
47 changes: 35 additions & 12 deletions python/pyspark/sql/connect/readwriter.py
@@ -19,7 +19,7 @@
check_dependencies(__name__)

from typing import Dict
from typing import Optional, Union, List, overload, Tuple, cast
from typing import Optional, Union, List, overload, Tuple, cast, Callable
from typing import TYPE_CHECKING

from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation, WriteOperationV2
@@ -37,6 +37,7 @@
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType
from pyspark.sql.connect.session import SparkSession
from pyspark.sql.metrics import QueryExecution

__all__ = ["DataFrameReader", "DataFrameWriter"]

@@ -486,10 +487,16 @@ def _jreader(self) -> None:


class DataFrameWriter(OptionUtils):
def __init__(self, plan: "LogicalPlan", session: "SparkSession"):
def __init__(
self,
plan: "LogicalPlan",
session: "SparkSession",
callback: Optional[Callable[['QueryExecution'], None]] = None,
):
self._df: "LogicalPlan" = plan
self._spark: "SparkSession" = session
self._write: "WriteOperation" = WriteOperation(self._df)
self._callback = callback if callback is not None else lambda _: None

def mode(self, saveMode: Optional[str]) -> "DataFrameWriter":
# At the JVM side, the default value of mode is already set to "error".
@@ -649,9 +656,10 @@ def save(
if format is not None:
self.format(format)
self._write.path = path
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

save.__doc__ = PySparkDataFrameWriter.save.__doc__

@@ -660,9 +668,10 @@ def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None:
self.mode("overwrite" if overwrite else "append")
self._write.table_name = tableName
self._write.table_save_method = "insert_into"
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__

@@ -681,9 +690,10 @@ def saveAsTable(
self.format(format)
self._write.table_name = name
self._write.table_save_method = "save_as_table"
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__

@@ -845,11 +855,18 @@ def jdbc(


class DataFrameWriterV2(OptionUtils):
def __init__(self, plan: "LogicalPlan", session: "SparkSession", table: str):
def __init__(
self,
plan: "LogicalPlan",
session: "SparkSession",
table: str,
callback: Optional[Callable[['QueryExecution'], None]] = None,
):
self._df: "LogicalPlan" = plan
self._spark: "SparkSession" = session
self._table_name: str = table
self._write: "WriteOperationV2" = WriteOperationV2(self._df, self._table_name)
self._callback = callback if callback is not None else lambda _: None

def using(self, provider: str) -> "DataFrameWriterV2":
self._write.provider = provider
@@ -884,50 +901,56 @@ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFram

def create(self) -> None:
self._write.mode = "create"
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

create.__doc__ = PySparkDataFrameWriterV2.create.__doc__

def replace(self) -> None:
self._write.mode = "replace"
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__

def createOrReplace(self) -> None:
self._write.mode = "create_or_replace"
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__

def append(self) -> None:
self._write.mode = "append"
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

append.__doc__ = PySparkDataFrameWriterV2.append.__doc__

def overwrite(self, condition: "ColumnOrName") -> None:
self._write.mode = "overwrite"
self._write.overwrite_condition = F._to_col(condition)
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__

def overwritePartitions(self) -> None:
self._write.mode = "overwrite_partitions"
self._spark.client.execute_command(
_, _, qe = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
self._callback(qe)

overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__

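Every write method above repeats the same execute-then-callback pattern, and the no-op lambda default keeps call sites that construct the writers without a callback working unchanged. A sketch (not part of this change) of how the repetition could be factored into a shared helper on both writer classes:

def _execute_write(self) -> None:
    # Run the pending write command and hand its QueryExecution to the callback.
    _, _, qe = self._spark.client.execute_command(
        self._write.command(self._spark.client), self._write.observations
    )
    self._callback(qe)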
7 changes: 5 additions & 2 deletions python/pyspark/sql/connect/session.py
@@ -720,9 +720,12 @@ def sql(
_views.append(SubqueryAlias(df._plan, name))

cmd = SQL(sqlQuery, _args, _named_args, _views)
data, properties = self.client.execute_command(cmd.command(self._client))
data, properties, qe = self.client.execute_command(cmd.command(self._client))
if "sql_command_result" in properties:
return DataFrame(CachedRelation(properties["sql_command_result"]), self)
df = DataFrame(CachedRelation(properties["sql_command_result"]), self)
# A command result contains the execution.
df._query_execution = qe
return df
else:
return DataFrame(cmd, self)

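A brief usage sketch (the session variable and SQL text are illustrative, and whether a given statement produces a sql_command_result depends on the server) of what attaching the QueryExecution to the cached command result enables:

df = spark.sql("CREATE TABLE t AS SELECT * FROM range(10)")
# The command was executed eagerly above, so the returned DataFrame already
# carries the QueryExecution for that run.
qe = df._query_execution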
7 changes: 7 additions & 0 deletions python/pyspark/sql/tests/connect/test_df_debug.py
@@ -42,6 +42,13 @@ def test_df_quey_execution_empty_before_execution(self):
qe = df.queryExecution
self.assertIsNone(qe, "The query execution must be None before the action is executed")

def test_df_query_execution_with_writes(self):
df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
df.write.save("/tmp/test_df_query_execution_with_writes", format="json")

qe = df.queryExecution
self.assertIsNotNone(qe, "The query execution must be set after the write action is executed")

@unittest.skipIf(not have_graphviz, graphviz_requirement_message)
def test_df_query_execution_metrics_to_dot(self):
df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()