diff --git a/sdk/core/corehttp/corehttp/runtime/pipeline/_base.py b/sdk/core/corehttp/corehttp/runtime/pipeline/_base.py index db40018a6b3e..a30f855ab112 100644 --- a/sdk/core/corehttp/corehttp/runtime/pipeline/_base.py +++ b/sdk/core/corehttp/corehttp/runtime/pipeline/_base.py @@ -26,6 +26,7 @@ from __future__ import annotations import logging from typing import Generic, TypeVar, Union, Any, List, Optional, Iterable, ContextManager +from typing_extensions import TypeGuard from . import ( PipelineRequest, @@ -42,6 +43,18 @@ _LOGGER = logging.getLogger(__name__) +def is_http_policy(policy: object) -> TypeGuard[HTTPPolicy]: + if hasattr(policy, "send"): + return True + return False + + +def is_sansio_http_policy(policy: object) -> TypeGuard[SansIOHTTPPolicy]: + if hasattr(policy, "on_request") and hasattr(policy, "on_response"): + return True + return False + + class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]): """Sync implementation of the SansIO policy. @@ -123,10 +136,14 @@ def __init__( self._transport = transport for policy in policies or []: - if isinstance(policy, SansIOHTTPPolicy): + if is_http_policy(policy): + self._impl_policies.append(policy) + elif is_sansio_http_policy(policy): self._impl_policies.append(_SansIOHTTPPolicyRunner(policy)) elif policy: - self._impl_policies.append(policy) + raise AttributeError( + f"'{type(policy)}' object has no attribute 'send' or both 'on_request' and 'on_response'." + ) for index in range(len(self._impl_policies) - 1): self._impl_policies[index].next = self._impl_policies[index + 1] if self._impl_policies: diff --git a/sdk/core/corehttp/corehttp/runtime/pipeline/_base_async.py b/sdk/core/corehttp/corehttp/runtime/pipeline/_base_async.py index 7e51ed468b62..00c9b3c17d45 100644 --- a/sdk/core/corehttp/corehttp/runtime/pipeline/_base_async.py +++ b/sdk/core/corehttp/corehttp/runtime/pipeline/_base_async.py @@ -24,12 +24,14 @@ # # -------------------------------------------------------------------------- from __future__ import annotations +import inspect from types import TracebackType from typing import Any, Union, Generic, TypeVar, List, Optional, Iterable, Type -from typing_extensions import AsyncContextManager +from typing_extensions import AsyncContextManager, TypeGuard from . import PipelineRequest, PipelineResponse, PipelineContext from ..policies import AsyncHTTPPolicy, SansIOHTTPPolicy +from ..pipeline._base import is_sansio_http_policy from ._tools_async import await_result as _await_result from ...transport import AsyncHttpTransport @@ -37,6 +39,12 @@ HTTPRequestType = TypeVar("HTTPRequestType") +def is_async_http_policy(policy: object) -> TypeGuard[AsyncHTTPPolicy]: + if hasattr(policy, "send") and inspect.iscoroutinefunction(policy.send): + return True + return False + + class _SansIOAsyncHTTPPolicyRunner( AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType] ): # pylint: disable=unsubscriptable-object @@ -127,10 +135,14 @@ def __init__( self._transport = transport for policy in policies or []: - if isinstance(policy, SansIOHTTPPolicy): + if is_async_http_policy(policy): + self._impl_policies.append(policy) + elif is_sansio_http_policy(policy): self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy)) elif policy: - self._impl_policies.append(policy) + raise AttributeError( + f"'{type(policy)}' object has no attribute 'send' or both 'on_request' and 'on_response'." + ) for index in range(len(self._impl_policies) - 1): self._impl_policies[index].next = self._impl_policies[index + 1] if self._impl_policies: diff --git a/sdk/core/corehttp/tests/async_tests/test_authentication_async.py b/sdk/core/corehttp/tests/async_tests/test_authentication_async.py index e441fefa48a3..8d81f6b269d2 100644 --- a/sdk/core/corehttp/tests/async_tests/test_authentication_async.py +++ b/sdk/core/corehttp/tests/async_tests/test_authentication_async.py @@ -16,6 +16,7 @@ SansIOHTTPPolicy, ) from corehttp.rest import HttpRequest +from azure.core.pipeline.policies import AzureKeyCredentialPolicy import pytest pytestmark = pytest.mark.asyncio @@ -93,12 +94,15 @@ async def get_token(*_, **__): get_token_calls += 1 return expected_token + async def send_mock(_): + return Mock(http_response=Mock(status_code=200)) + credential = Mock(get_token=get_token) policies = [ AsyncBearerTokenCredentialPolicy(credential, "scope"), - Mock(send=Mock(return_value=get_completed_future(Mock()))), + Mock(send=send_mock), ] - pipeline = AsyncPipeline(transport=Mock, policies=policies) + pipeline = AsyncPipeline(transport=Mock(), policies=policies) await pipeline.run(HttpRequest("GET", "https://spam.eggs")) assert get_token_calls == 1 # policy has no token at first request -> it should call get_token @@ -111,7 +115,7 @@ async def get_token(*_, **__): expected_token = expired_token policies = [ AsyncBearerTokenCredentialPolicy(credential, "scope"), - Mock(send=lambda _: get_completed_future(Mock())), + Mock(send=send_mock), ] pipeline = AsyncPipeline(transport=Mock(), policies=policies) @@ -238,6 +242,27 @@ async def fake_send(*args, **kwargs): policy.on_exception.assert_called_once_with(policy.request) +async def test_azure_core_sans_io_policy(): + """Tests to see that we can use an azure.core SansIOHTTPPolicy with the corehttp Pipeline""" + + class TestPolicy(AzureKeyCredentialPolicy): + def __init__(self, *args, **kwargs): + super(TestPolicy, self).__init__(*args, **kwargs) + self.on_exception = Mock(return_value=False) + self.on_request = Mock() + + credential = Mock( + get_token=Mock(return_value=get_completed_future(AccessToken("***", int(time.time()) + 3600))), key="key" + ) + policy = TestPolicy(credential, "scope") + transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200)))) + + pipeline = AsyncPipeline(transport=transport, policies=[policy]) + await pipeline.run(HttpRequest("GET", "https://localhost")) + + policy.on_request.assert_called_once() + + def get_completed_future(result=None): fut = asyncio.Future() fut.set_result(result) diff --git a/sdk/core/corehttp/tests/async_tests/test_pipeline_async.py b/sdk/core/corehttp/tests/async_tests/test_pipeline_async.py index 427449cd6ee9..007ef094d667 100644 --- a/sdk/core/corehttp/tests/async_tests/test_pipeline_async.py +++ b/sdk/core/corehttp/tests/async_tests/test_pipeline_async.py @@ -4,7 +4,7 @@ # license information. # ------------------------------------------------------------------------- from typing import cast -from unittest.mock import AsyncMock, PropertyMock +from unittest.mock import AsyncMock, PropertyMock, Mock from corehttp.rest import HttpRequest from corehttp.runtime import AsyncPipelineClient @@ -49,6 +49,39 @@ async def __aexit__(self, exc_type, exc_value, traceback): await pipeline.run(req) +def test_invalid_policy_error(): + # non-HTTPPolicy/non-SansIOHTTPPolicy should raise an error + class FooPolicy: + pass + + # sync send method should raise an error + class SyncSendPolicy: + def send(self, request): + pass + + # only on_request should raise an error + class OnlyOnRequestPolicy: + def on_request(self, request): + pass + + # only on_response should raise an error + class OnlyOnResponsePolicy: + def on_response(self, request, response): + pass + + with pytest.raises(AttributeError): + pipeline = AsyncPipeline(transport=Mock(), policies=[FooPolicy()]) + + with pytest.raises(AttributeError): + pipeline = AsyncPipeline(transport=Mock(), policies=[SyncSendPolicy()]) + + with pytest.raises(AttributeError): + pipeline = AsyncPipeline(transport=Mock(), policies=[OnlyOnRequestPolicy()]) + + with pytest.raises(AttributeError): + pipeline = AsyncPipeline(transport=Mock(), policies=[OnlyOnResponsePolicy()]) + + @pytest.mark.asyncio @pytest.mark.parametrize("transport", ASYNC_TRANSPORTS) async def test_transport_socket_timeout(transport): @@ -95,7 +128,7 @@ async def test_basic_aiohttp_separate_session(port): @pytest.mark.asyncio async def test_retry_without_http_response(): class NaughtyPolicy(AsyncHTTPPolicy): - def send(*args): + async def send(*args): raise BaseError("boo") policies = [AsyncRetryPolicy(), NaughtyPolicy()] @@ -107,11 +140,11 @@ def send(*args): @pytest.mark.asyncio async def test_add_custom_policy(): class BooPolicy(AsyncHTTPPolicy): - def send(*args): + async def send(*args): raise BaseError("boo") class FooPolicy(AsyncHTTPPolicy): - def send(*args): + async def send(*args): raise BaseError("boo") retry_policy = AsyncRetryPolicy() diff --git a/sdk/core/corehttp/tests/test_authentication.py b/sdk/core/corehttp/tests/test_authentication.py index 5e6e49b14119..e0781bc769dc 100644 --- a/sdk/core/corehttp/tests/test_authentication.py +++ b/sdk/core/corehttp/tests/test_authentication.py @@ -15,6 +15,7 @@ ServiceKeyCredentialPolicy, ) from corehttp.rest import HttpRequest +from azure.core.pipeline.policies import AzureKeyCredentialPolicy import pytest @@ -251,6 +252,25 @@ def raise_the_second_time(*args, **kwargs): policy.on_exception.assert_called_once_with(policy.request) +def test_azure_core_sans_io_policy(): + """Tests to see that we can use an azure.core SansIOHTTPPolicy with the corehttp Pipeline""" + + class TestPolicy(AzureKeyCredentialPolicy): + def __init__(self, *args, **kwargs): + super(TestPolicy, self).__init__(*args, **kwargs) + self.on_exception = Mock(return_value=False) + self.on_request = Mock() + + credential = Mock(get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600)), key="key") + policy = TestPolicy(credential, "scope") + transport = Mock(send=Mock(return_value=Mock(status_code=200))) + + pipeline = Pipeline(transport=transport, policies=[policy]) + pipeline.run(HttpRequest("GET", "https://localhost")) + + policy.on_request.assert_called_once() + + def test_service_key_credential_policy(): """Tests to see if we can create an ServiceKeyCredentialPolicy""" diff --git a/sdk/core/corehttp/tests/test_pipeline.py b/sdk/core/corehttp/tests/test_pipeline.py index 1df7305b4c53..dd79e4bb1cf2 100644 --- a/sdk/core/corehttp/tests/test_pipeline.py +++ b/sdk/core/corehttp/tests/test_pipeline.py @@ -4,6 +4,7 @@ # license information. # ------------------------------------------------------------------------- +from unittest.mock import Mock import json from io import BytesIO import xml.etree.ElementTree as ET @@ -52,6 +53,31 @@ def __exit__(self, exc_type, exc_value, traceback): pipeline.run(req) +def test_invalid_policy_error(): + # non-HTTPPolicy/non-SansIOHTTPPolicy should raise an error + class FooPolicy: + pass + + # only on_request should raise an error + class OnlyOnRequestPolicy: + def on_request(self, request): + pass + + # only on_response should raise an error + class OnlyOnResponsePolicy: + def on_response(self, request, response): + pass + + with pytest.raises(AttributeError): + pipeline = Pipeline(transport=Mock(), policies=[FooPolicy()]) + + with pytest.raises(AttributeError): + pipeline = Pipeline(transport=Mock(), policies=[OnlyOnRequestPolicy()]) + + with pytest.raises(AttributeError): + pipeline = Pipeline(transport=Mock(), policies=[OnlyOnResponsePolicy()]) + + @pytest.mark.parametrize("transport", SYNC_TRANSPORTS) def test_transport_socket_timeout(transport): request = HttpRequest("GET", "https://bing.com")