Skip to content
Next Next commit
crappy start
  • Loading branch information
nija-at committed Apr 29, 2024
commit 181d0e8b94be1bda1872b95377f0a5b4567c827a
61 changes: 58 additions & 3 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import unittest
import uuid
from collections.abc import Generator
from typing import Optional, Any
from typing import Optional, Any, Union

from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import eventually

if should_test_connect:
import grpc
from google.rpc import status_pb2
import pandas as pd
import pyarrow as pa
from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
Expand All @@ -37,6 +38,7 @@
import pyspark.sql.connect.proto as proto

class TestPolicy(DefaultPolicy):
__test__ = False
def __init__(self):
super().__init__(
max_retries=3,
Expand All @@ -47,21 +49,36 @@ def __init__(self):
min_jitter_threshold=10,
)

class NoRetry(DefaultPolicy):
def __init__(self):
super().__init__(
max_retries=0
)

class TestException(grpc.RpcError, grpc.Call):
"""Exception mock to test retryable exceptions."""

def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
def __init__(
self,
msg,
code=grpc.StatusCode.INTERNAL,
trailing_metadata: Union[dict[str, Any], None] = None,
):
self.msg = msg
self._code = code
self._trailing_metadata = trailing_metadata

def code(self):
return self._code

def __str__(self):
return self.msg

def details(self):
return self.msg

def trailing_metadata(self):
return ()
return None if self._trailing_metadata is None else self._trailing_metadata.items()

class ResponseGenerator(Generator):
"""This class is used to generate values that are returned by the streaming
Expand Down Expand Up @@ -340,6 +357,44 @@ def check():

eventually(timeout=1, catch_assertions=True)(check)()

def test_restart_new_session(self):
def execute_unavailable():
raise TestException("Unavailable", grpc.StatusCode.UNAVAILABLE)

def not_found():
raise TestException(
"INVALID_HANDLE.OPERATION_NOT_FOUND",
grpc.StatusCode.UNAVAILABLE,
trailing_metadata={
"grpc-status-details-bin": status_pb2.Status(
code=14,
message="INVALID_HANDLE.OPERATION_NOT_FOUND",
details="",
).SerializeToString()
}
)

stub = self._stub_with(
[execute_unavailable, self.finished], [not_found]
)
ite = ExecutePlanResponseReattachableIterator(
self.request,
stub,
self.retrying,
# lambda: Retrying(NoRetry()),
[]
)

for b in ite:
pass

def checks():
self.assertEquals(1, stub.execute_calls)
self.assertEquals(1, stub.attach_calls)

eventually(timeout=1, catch_assertions=True)(checks)()



if __name__ == "__main__":
from pyspark.sql.tests.connect.client.test_client import * # noqa: F401
Expand Down