Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test coverage
  • Loading branch information
cdkrot committed Dec 6, 2023
commit 8b707fd736cd09906ecefa806c6f1e55c9c6c027
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections.abc import Generator
from typing import Optional, Any

from pyspark.sql.connect.client.core import ForbidRecursion
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import eventually

Expand Down Expand Up @@ -133,6 +134,26 @@ def test_channel_builder_with_session(self):
client = SparkConnectClient(chan)
self.assertEqual(client._session_id, chan.session_id)

def test_forbid_recursion(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test does not test directly the scenario, we're talking about. Ideally you can just use the mock tests we have to fail any query and see that the recursion guard works.

Copy link
Contributor Author

@cdkrot cdkrot Dec 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried, but it seems hard to make the mock test for this because it needs to pass through this pieces of code:

status = rpc_status.from_call(cast(grpc.Call, rpc_error))

This seems hard to create a mock exception which would pass this without poking grpc's internals significantly. Alternatively we could introduce some testing clutches here, i.e. check if exception is from testing code, but that's not great either.

@grundprinzip

guard = ForbidRecursion()
max_depth = 0

def g(n):
nonlocal max_depth
with guard:
max_depth = n
g(n + 1)

with self.assertRaises(RecursionError):
g(1)
self.assertEqual(max_depth, 1)

# Do the same test again to check that guard resets.
max_depth = 0
with self.assertRaises(RecursionError):
g(1)
self.assertEqual(max_depth, 1)


class TestPolicy(DefaultPolicy):
def __init__(self):
Expand Down