Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sdk/core/corehttp/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Release History

## 1.0.0b3 (2024-02-01)

### Breaking Changes

- The `SansIOHTTPPolicy` and `HTTPPolicy` under `corehttp.runtime.policies` are now `typing.Protocols`. [#34296](https://github.com/Azure/azure-sdk-for-python/pull/34296)


## 1.0.0b3 (2024-02-01)

### Features Added
Expand Down
14 changes: 11 additions & 3 deletions sdk/core/corehttp/corehttp/runtime/pipeline/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
:type policy: ~corehttp.runtime.pipeline.policies.SansIOHTTPPolicy
"""

next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]) -> None:
super(_SansIOHTTPPolicyRunner, self).__init__()
self._policy = policy
Expand Down Expand Up @@ -78,6 +81,9 @@ class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
:type sender: ~corehttp.transport.HttpTransport
"""

next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, sender: HttpTransport[HTTPRequestType, HTTPResponseType]) -> None:
super(_TransportRunner, self).__init__()
self._sender = sender
Expand Down Expand Up @@ -123,10 +129,12 @@ def __init__(
self._transport = transport

for policy in policies or []:
if isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
elif policy:
if isinstance(policy, HTTPPolicy):
self._impl_policies.append(policy)
elif isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
else:
raise TypeError("Unsupported policy type: {}".format(type(policy)))
for index in range(len(self._impl_policies) - 1):
self._impl_policies[index].next = self._impl_policies[index + 1]
if self._impl_policies:
Expand Down
14 changes: 11 additions & 3 deletions sdk/core/corehttp/corehttp/runtime/pipeline/_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class _SansIOAsyncHTTPPolicyRunner(
:type policy: ~corehttp.runtime.pipeline.policies.SansIOHTTPPolicy
"""

next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]) -> None:
super(_SansIOAsyncHTTPPolicyRunner, self).__init__()
self._policy = policy
Expand Down Expand Up @@ -79,6 +82,9 @@ class _AsyncTransportRunner(
:type sender: ~corehttp.transport.AsyncHttpTransport
"""

next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, sender: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType]) -> None:
super(_AsyncTransportRunner, self).__init__()
self._sender = sender
Expand Down Expand Up @@ -127,10 +133,12 @@ def __init__(
self._transport = transport

for policy in policies or []:
if isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy))
elif policy:
if isinstance(policy, AsyncHTTPPolicy):
self._impl_policies.append(policy)
elif isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy))
else:
raise TypeError("Unsupported policy type: {}".format(type(policy)))
for index in range(len(self._impl_policies) - 1):
self._impl_policies[index].next = self._impl_policies[index + 1]
if self._impl_policies:
Expand Down
28 changes: 17 additions & 11 deletions sdk/core/corehttp/corehttp/runtime/policies/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,35 @@
#
# --------------------------------------------------------------------------
from __future__ import annotations
import abc
import copy
import logging

from typing import Generic, TypeVar, Union, Any, Optional, Awaitable, Dict, TYPE_CHECKING
from typing_extensions import Protocol, runtime_checkable

if TYPE_CHECKING:
from ...runtime.pipeline import PipelineRequest, PipelineResponse

HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")
SansIOHTTPResponseType_contra = TypeVar("SansIOHTTPResponseType_contra", contravariant=True)
SansIOHTTPRequestType_contra = TypeVar("SansIOHTTPRequestType_contra", contravariant=True)

_LOGGER = logging.getLogger(__name__)


class HTTPPolicy(abc.ABC, Generic[HTTPRequestType, HTTPResponseType]):
"""An HTTP policy ABC.
@runtime_checkable
class HTTPPolicy(Generic[HTTPRequestType, HTTPResponseType], Protocol):
"""An HTTP policy protocol.

Use with a synchronous pipeline.
"""

next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

@abc.abstractmethod
def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
"""Abstract send method for a synchronous pipeline. Mutates the request.
"""Send method for a synchronous pipeline. Mutates the request.

Context content is dependent on the HttpTransport.

Expand All @@ -59,31 +61,34 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT
:return: The pipeline response object.
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
"""
...


class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
@runtime_checkable
class SansIOHTTPPolicy(Generic[SansIOHTTPRequestType_contra, SansIOHTTPResponseType_contra], Protocol):
"""Represents a sans I/O policy.

SansIOHTTPPolicy is a base class for policies that only modify or
SansIOHTTPPolicy is a protocol for policies that only modify or
mutate a request based on the HTTP specification, and do not depend
on the specifics of any particular transport. SansIOHTTPPolicy
subclasses will function in either a Pipeline or an AsyncPipeline,
subtype classes will function in either a Pipeline or an AsyncPipeline,
and can act either before the request is done, or after.
You can optionally make these methods coroutines (or return awaitable objects)
but they will then be tied to AsyncPipeline usage.
"""

def on_request(self, request: PipelineRequest[HTTPRequestType]) -> Union[None, Awaitable[None]]:
def on_request(self, request: PipelineRequest[SansIOHTTPRequestType_contra]) -> Union[None, Awaitable[None]]:
"""Is executed before sending the request from next policy.

:param request: Request to be modified before sent from next policy.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
"""
...

def on_response(
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
request: PipelineRequest[SansIOHTTPRequestType_contra],
response: PipelineResponse[SansIOHTTPRequestType_contra, SansIOHTTPResponseType_contra],
) -> Union[None, Awaitable[None]]:
"""Is executed after the request comes back from the policy.

Expand All @@ -92,6 +97,7 @@ def on_response(
:param response: Pipeline response object
:type response: ~corehttp.runtime.pipeline.PipelineResponse
"""
...


class RequestHistory(Generic[HTTPRequestType, HTTPResponseType]):
Expand Down
12 changes: 6 additions & 6 deletions sdk/core/corehttp/corehttp/runtime/policies/_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,31 @@
#
# --------------------------------------------------------------------------
from __future__ import annotations
import abc

from typing import Generic, TypeVar, TYPE_CHECKING
from typing_extensions import Protocol, runtime_checkable

if TYPE_CHECKING:
from ...runtime.pipeline import PipelineRequest, PipelineResponse

AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")


class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
"""An async HTTP policy ABC.
@runtime_checkable
class AsyncHTTPPolicy(Generic[HTTPRequestType, AsyncHTTPResponseType], Protocol):
"""An async HTTP policy protocol.

Use with an asynchronous pipeline.
"""

next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

@abc.abstractmethod
async def send(
self, request: PipelineRequest[HTTPRequestType]
) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
"""Abstract send method for a asynchronous pipeline. Mutates the request.
"""Send method for a asynchronous pipeline. Mutates the request.

Context content is dependent on the HttpTransport.

Expand All @@ -58,3 +57,4 @@ async def send(
:return: The pipeline response object.
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
"""
...
3 changes: 3 additions & 0 deletions sdk/core/corehttp/corehttp/runtime/policies/_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
:keyword int timeout: Timeout setting for the operation in seconds, default is 604800s (7 days).
"""

next: "HTTPPolicy[HttpRequest, HttpResponse]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def _sleep_for_retry(
self,
response: PipelineResponse[HttpRequest, HttpResponse],
Expand Down
3 changes: 3 additions & 0 deletions sdk/core/corehttp/corehttp/runtime/policies/_retry_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
:keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).
"""

next: "AsyncHTTPPolicy[HttpRequest, AsyncHttpResponse]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

async def _sleep_for_retry(
self,
response: PipelineResponse[HttpRequest, AsyncHttpResponse],
Expand Down
44 changes: 43 additions & 1 deletion sdk/core/corehttp/corehttp/runtime/policies/_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import xml.etree.ElementTree as ET
import types
import re
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Mapping, TYPE_CHECKING
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Mapping, Awaitable, TYPE_CHECKING

from ... import __version__ as core_version
from ...exceptions import DecodeError
Expand Down Expand Up @@ -98,6 +98,20 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
if additional_headers:
request.http_request.headers.update(additional_headers)

def on_response( # pylint: disable=unused-argument
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
) -> Union[None, Awaitable[None]]:
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~corehttp.runtime.pipeline.PipelineResponse
"""
return None


class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
"""User-Agent Policy. Allows custom values to be added to the User-Agent header.
Expand Down Expand Up @@ -170,6 +184,20 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
elif self.overwrite or self._USERAGENT not in http_request.headers:
http_request.headers[self._USERAGENT] = self.user_agent

def on_response( # pylint: disable=unused-argument
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
) -> Union[None, Awaitable[None]]:
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~corehttp.runtime.pipeline.PipelineResponse
"""
return None


class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):

Expand Down Expand Up @@ -462,3 +490,17 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
ctxt = request.context.options
if self.proxies and "proxies" not in ctxt:
ctxt["proxies"] = self.proxies

def on_response( # pylint: disable=unused-argument
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
) -> Union[None, Awaitable[None]]:
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~corehttp.runtime.pipeline.PipelineResponse
"""
return None
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@

@pytest.mark.asyncio
async def test_sans_io_exception():
class SansIOHTTPPolicyImpl(SansIOHTTPPolicy):
def on_request(self, request):
pass

def on_response(self, request, response):
pass

class BrokenSender(AsyncHttpTransport):
async def send(self, request, **config):
raise ValueError("Broken")
Expand All @@ -42,7 +49,7 @@ async def __aexit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return self.close()

pipeline = AsyncPipeline(BrokenSender(), [SansIOHTTPPolicy()])
pipeline = AsyncPipeline(BrokenSender(), [SansIOHTTPPolicyImpl()])

req = HttpRequest("GET", "/")
with pytest.raises(ValueError):
Expand Down
9 changes: 8 additions & 1 deletion sdk/core/corehttp/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@


def test_sans_io_exception():
class SansIOHTTPPolicyImpl(SansIOHTTPPolicy):
def on_request(self, request):
pass

def on_response(self, request, response):
pass

class BrokenSender(HttpTransport):
def send(self, request, **config):
raise ValueError("Broken")
Expand All @@ -45,7 +52,7 @@ def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return self.close()

pipeline = Pipeline(BrokenSender(), [SansIOHTTPPolicy()])
pipeline = Pipeline(BrokenSender(), [SansIOHTTPPolicyImpl()])

req = HttpRequest("GET", "/")
with pytest.raises(ValueError):
Expand Down