Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_python_connect.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
python packaging/connect/setup.py sdist
cd dist
pip install pyspark-connect-*.tar.gz
pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' torch torchvision torcheval deepspeed unittest-xml-reporting
pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' torch torchvision torcheval deepspeed unittest-xml-reporting parameterized
- name: Run tests
env:
SPARK_TESTING: 1
Expand Down
1 change: 1 addition & 0 deletions dev/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ memory-profiler>=0.61.0
# PySpark test dependencies
unittest-xml-reporting
openpyxl
parameterized

# PySpark test dependencies (optional)
coverage
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/client/reattach.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,10 @@ def _call_iter(self, iter_fun: Callable) -> Any:
return iter_fun()
except grpc.RpcError as e:
status = rpc_status.from_call(cast(grpc.Call, e))
if status is not None and "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message:
if status is not None and (
"INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message
or "INVALID_HANDLE.SESSION_NOT_FOUND" in status.message
):
if self._last_returned_response_id is not None:
raise PySparkRuntimeError(
error_class="RESPONSE_ALREADY_RECEIVED",
Expand Down
88 changes: 84 additions & 4 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
import unittest
import uuid
from collections.abc import Generator
from typing import Optional, Any
from typing import Optional, Any, Union
from parameterized import parameterized

from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import eventually

if should_test_connect:
import grpc
from google.rpc import status_pb2
import pandas as pd
import pyarrow as pa
from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
Expand All @@ -33,7 +35,7 @@
DefaultPolicy,
)
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.errors import RetriesExceeded
from pyspark.errors import PySparkRuntimeError, RetriesExceeded
import pyspark.sql.connect.proto as proto

class TestPolicy(DefaultPolicy):
Expand All @@ -50,18 +52,29 @@ def __init__(self):
class TestException(grpc.RpcError, grpc.Call):
"""Exception mock to test retryable exceptions."""

def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
def __init__(
self,
msg,
code=grpc.StatusCode.INTERNAL,
trailing_status: Union[status_pb2.Status, None] = None,
):
self.msg = msg
self._code = code
self._trailer: dict[str, Any] = {}
if trailing_status is not None:
self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString()

def code(self):
return self._code

def __str__(self):
return self.msg

def details(self):
return self.msg

def trailing_metadata(self):
return ()
return None if not self._trailer else self._trailer.items()

class ResponseGenerator(Generator):
"""This class is used to generate values that are returned by the streaming
Expand Down Expand Up @@ -340,6 +353,73 @@ def check():

eventually(timeout=1, catch_assertions=True)(check)()

@parameterized.expand(
[
("session", "INVALID_HANDLE.SESSION_NOT_FOUND"),
("operation", "INVALID_HANDLE.OPERATION_NOT_FOUND"),
]
)
def test_not_found_recovers(self, _, error_msg: str):
"""SPARK-48056: Assert that the client recovers from session or operation not
found error if no partial responses were previously received.
"""

def not_found():
raise TestException(
error_msg,
grpc.StatusCode.UNAVAILABLE,
trailing_status=status_pb2.Status(code=14, message=error_msg, details=""),
)

stub = self._stub_with([not_found, self.finished])
ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])

for _ in ite:
pass

def checks():
self.assertEquals(2, stub.execute_calls)
self.assertEquals(0, stub.attach_calls)
self.assertEquals(0, stub.release_calls)
self.assertEquals(0, stub.release_until_calls)

eventually(timeout=1, catch_assertions=True)(checks)()

@parameterized.expand(
[
("session", "INVALID_HANDLE.SESSION_NOT_FOUND"),
("operation", "INVALID_HANDLE.OPERATION_NOT_FOUND"),
]
)
def test_not_found_fails(self, _, error_msg: str):
"""SPARK-48056: Assert that the client fails from session or operation not found error
if a partial response was previously received.
"""

def not_found():
raise TestException(
error_msg,
grpc.StatusCode.UNAVAILABLE,
trailing_status=status_pb2.Status(code=14, message=error_msg, details=""),
)

stub = self._stub_with([self.response], [not_found])

with self.assertRaises(PySparkRuntimeError) as e:
ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])
for _ in ite:
pass

self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage())

def checks():
self.assertEquals(1, stub.execute_calls)
self.assertEquals(1, stub.attach_calls)
self.assertEquals(0, stub.release_calls)
self.assertEquals(0, stub.release_until_calls)

eventually(timeout=1, catch_assertions=True)(checks)()


if __name__ == "__main__":
from pyspark.sql.tests.connect.client.test_client import * # noqa: F401
Expand Down