Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix mypy errors
  • Loading branch information
swathipil committed Feb 20, 2024
commit d0fc0d03c50c2f95e0d15c37bafa95b326d6fd13
8 changes: 7 additions & 1 deletion sdk/core/corehttp/corehttp/runtime/pipeline/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
:type policy: ~corehttp.runtime.pipeline.policies.SansIOHTTPPolicy
"""

next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]) -> None:
super(_SansIOHTTPPolicyRunner, self).__init__()
self._policy = policy
Expand Down Expand Up @@ -78,6 +81,9 @@ class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
:type sender: ~corehttp.transport.HttpTransport
"""

next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, sender: HttpTransport[HTTPRequestType, HTTPResponseType]) -> None:
super(_TransportRunner, self).__init__()
self._sender = sender
Expand Down Expand Up @@ -123,7 +129,7 @@ def __init__(
self._transport = transport

for policy in policies or []:
if hasattr(policy, "send"):
if isinstance(policy, HTTPPolicy):
self._impl_policies.append(policy)
elif isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class _SansIOAsyncHTTPPolicyRunner(
:type policy: ~corehttp.runtime.pipeline.policies.SansIOHTTPPolicy
"""

next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]) -> None:
super(_SansIOAsyncHTTPPolicyRunner, self).__init__()
self._policy = policy
Expand Down Expand Up @@ -79,6 +82,9 @@ class _AsyncTransportRunner(
:type sender: ~corehttp.transport.AsyncHttpTransport
"""

next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def __init__(self, sender: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType]) -> None:
super(_AsyncTransportRunner, self).__init__()
self._sender = sender
Expand Down Expand Up @@ -127,7 +133,7 @@ def __init__(
self._transport = transport

for policy in policies or []:
if hasattr(policy, "send"):
if isinstance(policy, AsyncHTTPPolicy):
self._impl_policies.append(policy)
elif isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy))
Expand Down
13 changes: 6 additions & 7 deletions sdk/core/corehttp/corehttp/runtime/policies/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#
# --------------------------------------------------------------------------
from __future__ import annotations
import abc
import copy
import logging

Expand All @@ -40,18 +39,18 @@
_LOGGER = logging.getLogger(__name__)


class HTTPPolicy(abc.ABC, Generic[HTTPRequestType, HTTPResponseType]):
"""An HTTP policy ABC.
@runtime_checkable
class HTTPPolicy(Generic[HTTPRequestType, HTTPResponseType], Protocol):
"""An HTTP policy protocol.

Use with a synchronous pipeline.
"""

next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

@abc.abstractmethod
def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
"""Abstract send method for a synchronous pipeline. Mutates the request.
"""Send method for a synchronous pipeline. Mutates the request.

Context content is dependent on the HttpTransport.

Expand All @@ -66,10 +65,10 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT
class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType], Protocol):
"""Represents a sans I/O policy.

SansIOHTTPPolicy is a base class for policies that only modify or
SansIOHTTPPolicy is a protocol for policies that only modify or
mutate a request based on the HTTP specification, and do not depend
on the specifics of any particular transport. SansIOHTTPPolicy
subclasses will function in either a Pipeline or an AsyncPipeline,
subtype classes will function in either a Pipeline or an AsyncPipeline,
and can act either before the request is done, or after.
You can optionally make these methods coroutines (or return awaitable objects)
but they will then be tied to AsyncPipeline usage.
Expand Down
9 changes: 5 additions & 4 deletions sdk/core/corehttp/corehttp/runtime/policies/_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import abc

from typing import Generic, TypeVar, TYPE_CHECKING
from typing_extensions import Protocol, runtime_checkable

if TYPE_CHECKING:
from ...runtime.pipeline import PipelineRequest, PipelineResponse
Expand All @@ -36,20 +37,20 @@
HTTPRequestType = TypeVar("HTTPRequestType")


class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
"""An async HTTP policy ABC.
@runtime_checkable
class AsyncHTTPPolicy(Generic[HTTPRequestType, AsyncHTTPResponseType], Protocol):
"""An async HTTP policy protocol.

Use with an asynchronous pipeline.
"""

next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

@abc.abstractmethod
async def send(
self, request: PipelineRequest[HTTPRequestType]
) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
"""Abstract send method for a asynchronous pipeline. Mutates the request.
"""Send method for a asynchronous pipeline. Mutates the request.

Context content is dependent on the HttpTransport.

Expand Down
5 changes: 5 additions & 0 deletions sdk/core/corehttp/corehttp/runtime/policies/_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@

AllHttpResponseType = TypeVar("AllHttpResponseType", HttpResponse, AsyncHttpResponse)
ClsRetryPolicy = TypeVar("ClsRetryPolicy", bound="RetryPolicyBase")
HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -419,6 +421,9 @@ class RetryPolicy(RetryPolicyBase, HTTPPolicy[HttpRequest, HttpResponse]):
:keyword int timeout: Timeout setting for the operation in seconds, default is 604800s (7 days).
"""

next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

def _sleep_for_retry(
self,
response: PipelineResponse[HttpRequest, HttpResponse],
Expand Down
7 changes: 6 additions & 1 deletion sdk/core/corehttp/corehttp/runtime/policies/_retry_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
This module is the requests implementation of Pipeline ABC
"""
from __future__ import annotations
from typing import Dict, Any, Optional, cast, TYPE_CHECKING
from typing import Dict, Any, Optional, cast, TypeVar, TYPE_CHECKING
import logging
import time

Expand All @@ -46,6 +46,8 @@

_LOGGER = logging.getLogger(__name__)

AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")

class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpResponse]):
"""Async flavor of the retry policy.
Expand Down Expand Up @@ -74,6 +76,9 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy[HttpRequest, AsyncHttpRe
:keyword int retry_backoff_max: The maximum back off time. Default value is 120 seconds (2 minutes).
"""

next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""

async def _sleep_for_retry(
self,
response: PipelineResponse[HttpRequest, AsyncHttpResponse],
Expand Down
43 changes: 42 additions & 1 deletion sdk/core/corehttp/corehttp/runtime/policies/_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import xml.etree.ElementTree as ET
import types
import re
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Mapping, TYPE_CHECKING
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Mapping, Awaitable, TYPE_CHECKING

from ... import __version__ as core_version
from ...exceptions import DecodeError
Expand Down Expand Up @@ -98,6 +98,19 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
if additional_headers:
request.http_request.headers.update(additional_headers)

def on_response( # pylint: disable=unused-argument
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
) -> Union[None, Awaitable[None]]:
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~corehttp.runtime.pipeline.PipelineResponse
"""
return None

class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
"""User-Agent Policy. Allows custom values to be added to the User-Agent header.
Expand Down Expand Up @@ -170,6 +183,20 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
elif self.overwrite or self._USERAGENT not in http_request.headers:
http_request.headers[self._USERAGENT] = self.user_agent

def on_response( # pylint: disable=unused-argument
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
) -> Union[None, Awaitable[None]]:
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~corehttp.runtime.pipeline.PipelineResponse
"""
return None


class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):

Expand Down Expand Up @@ -462,3 +489,17 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
ctxt = request.context.options
if self.proxies and "proxies" not in ctxt:
ctxt["proxies"] = self.proxies

def on_response( # pylint: disable=unused-argument
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
) -> Union[None, Awaitable[None]]:
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~corehttp.runtime.pipeline.PipelineResponse
"""
return None