diff --git a/src/anthropic/_utils/_sync.py b/src/anthropic/_utils/_sync.py index 8b3aaf2..ad7ec71 100644 --- a/src/anthropic/_utils/_sync.py +++ b/src/anthropic/_utils/_sync.py @@ -7,16 +7,20 @@ from typing import Any, TypeVar, Callable, Awaitable from typing_extensions import ParamSpec +import anyio +import sniffio +import anyio.to_thread + T_Retval = TypeVar("T_Retval") T_ParamSpec = ParamSpec("T_ParamSpec") if sys.version_info >= (3, 9): - to_thread = asyncio.to_thread + _asyncio_to_thread = asyncio.to_thread else: # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread # for Python 3.8 support - async def to_thread( + async def _asyncio_to_thread( func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs ) -> Any: """Asynchronously run function *func* in a separate thread. @@ -34,6 +38,17 @@ async def to_thread( return await loop.run_in_executor(None, func_call) +async def to_thread( + func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs +) -> T_Retval: + if sniffio.current_async_library() == "asyncio": + return await _asyncio_to_thread(func, *args, **kwargs) + + return await anyio.to_thread.run_sync( + functools.partial(func, *args, **kwargs), + ) + + # inspired by `asyncer`, https://github.com/tiangolo/asyncer def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ diff --git a/tests/test_client.py b/tests/test_client.py index f60e2fb..a29e65d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -23,6 +23,7 @@ from anthropic import Anthropic, AsyncAnthropic, APIResponseValidationError from anthropic._types import Omit +from anthropic._utils import maybe_transform from anthropic._models import BaseModel, FinalRequestOptions from anthropic._constants import RAW_RESPONSE_HEADER from anthropic._streaming import Stream, AsyncStream @@ -33,6 +34,7 @@ BaseClient, make_request_options, ) +from anthropic.types.message_create_params import MessageCreateParamsNonStreaming from .utils import update_env @@ -738,15 +740,18 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No "/v1/messages", body=cast( object, - dict( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], - model="claude-3-5-sonnet-latest", + maybe_transform( + dict( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-5-sonnet-latest", + ), + MessageCreateParamsNonStreaming, ), ), cast_to=httpx.Response, @@ -765,15 +770,18 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non "/v1/messages", body=cast( object, - dict( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], - model="claude-3-5-sonnet-latest", + maybe_transform( + dict( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-5-sonnet-latest", + ), + MessageCreateParamsNonStreaming, ), ), cast_to=httpx.Response, @@ -1618,15 +1626,18 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) "/v1/messages", body=cast( object, - dict( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], - model="claude-3-5-sonnet-latest", + maybe_transform( + dict( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-5-sonnet-latest", + ), + MessageCreateParamsNonStreaming, ), ), cast_to=httpx.Response, @@ -1645,15 +1656,18 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) "/v1/messages", body=cast( object, - dict( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Hello, Claude", - } - ], - model="claude-3-5-sonnet-latest", + maybe_transform( + dict( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Hello, Claude", + } + ], + model="claude-3-5-sonnet-latest", + ), + MessageCreateParamsNonStreaming, ), ), cast_to=httpx.Response,