Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Release History

## 1.12.1 (Unreleased)
## 1.13.0 (Unreleased)

### Features

- Supported adding custom policies #16519


## 1.12.0 (2021-03-08)
Expand Down
43 changes: 34 additions & 9 deletions sdk/core/azure-core/azure/core/_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@
# --------------------------------------------------------------------------

import logging
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
from .configuration import Configuration
from .pipeline import Pipeline
from .pipeline.transport._base import PipelineClientBase
from .pipeline.policies import (
ContentDecodePolicy, DistributedTracingPolicy, HttpLoggingPolicy, RequestIdPolicy
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
RequestIdPolicy,
)
from .pipeline.transport import RequestsTransport

Expand Down Expand Up @@ -64,6 +71,10 @@ class PipelineClient(PipelineClientBase):
:keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used.
:keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned.
:keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
:paramtype per_call_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]]
:keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
:paramtype per_retry_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]]
:keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport.
:return: A pipeline object.
:rtype: ~azure.core.pipeline.Pipeline
Expand Down Expand Up @@ -102,20 +113,34 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
policies = kwargs.get('policies')

if policies is None: # [] is a valid policy list
per_call_policies = kwargs.get('per_call_policies', [])
per_retry_policies = kwargs.get('per_retry_policies', [])
policies = [
RequestIdPolicy(**kwargs),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
ContentDecodePolicy(**kwargs),
config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or HttpLoggingPolicy(**kwargs)
ContentDecodePolicy(**kwargs)
]
if isinstance(per_call_policies, Iterable):
for policy in per_call_policies:
policies.append(policy)
Comment on lines +125 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to explicitly iterate over per_call_policies. list.extend(iterable) does the same thing.

else:
policies.append(per_call_policies)

policies = policies + [config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy]
Comment on lines +131 to +134
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes a new list object to be created and assigned to policies. list.extend(iterable) is a better choice as it modifies policies directly.

if isinstance(per_retry_policies, Iterable):
for policy in per_retry_policies:
policies.append(policy)
else:
policies.append(per_retry_policies)

policies = policies + [config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or HttpLoggingPolicy(**kwargs)]

if not transport:
transport = RequestsTransport(**kwargs)
Expand Down
38 changes: 33 additions & 5 deletions sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@
# --------------------------------------------------------------------------

import logging
from collections.abc import Iterable
from .configuration import Configuration
from .pipeline import AsyncPipeline
from .pipeline.transport._base import PipelineClientBase
from .pipeline.policies import (
ContentDecodePolicy, DistributedTracingPolicy, HttpLoggingPolicy, RequestIdPolicy
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
RequestIdPolicy,
)

try:
Expand Down Expand Up @@ -62,8 +66,14 @@ class AsyncPipelineClient(PipelineClientBase):
:param str base_url: URL for the request.
:keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used.
:keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned.
:keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport.
:keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
:paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
:keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
:paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
:keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for synchronous transport.
:return: An async pipeline object.
:rtype: ~azure.core.pipeline.AsyncPipeline

Expand Down Expand Up @@ -101,16 +111,34 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
policies = kwargs.get('policies')

if policies is None: # [] is a valid policy list
per_call_policies = kwargs.get('per_call_policies', [])
per_retry_policies = kwargs.get('per_retry_policies', [])
policies = [
RequestIdPolicy(**kwargs),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
ContentDecodePolicy(**kwargs),
ContentDecodePolicy(**kwargs)
]
if isinstance(per_call_policies, Iterable):
for policy in per_call_policies:
policies.append(policy)
else:
policies.append(per_call_policies)

policies = policies + [
config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy,
config.custom_hook_policy
]
if isinstance(per_retry_policies, Iterable):
for policy in per_retry_policies:
policies.append(policy)
else:
policies.append(per_retry_policies)

policies = policies + [
config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or HttpLoggingPolicy(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.12.1"
VERSION = "1.13.0"
40 changes: 39 additions & 1 deletion sdk/core/azure-core/tests/async_tests/test_pipeline_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,42 @@ def send(*args):
policies = [AsyncRetryPolicy(), NaughtyPolicy()]
pipeline = AsyncPipeline(policies=policies, transport=None)
with pytest.raises(AzureError):
await pipeline.run(HttpRequest('GET', url='https://foo.bar'))
await pipeline.run(HttpRequest('GET', url='https://foo.bar'))

@pytest.mark.asyncio
async def test_add_custom_policy():
class BooPolicy(AsyncHTTPPolicy):
def send(*args):
raise AzureError('boo')

class FooPolicy(AsyncHTTPPolicy):
def send(*args):
raise AzureError('boo')

boo_policy = BooPolicy()
foo_policy = FooPolicy()
client = AsyncPipelineClient(base_url="test", per_call_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = AsyncPipelineClient(base_url="test", per_call_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = AsyncPipelineClient(base_url="test", per_retry_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
client = AsyncPipelineClient(base_url="test", per_retry_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = AsyncPipelineClient(base_url="test", per_call_policies=boo_policy, per_retry_policies=foo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies

client = AsyncPipelineClient(base_url="test", per_call_policies=[boo_policy],
per_retry_policies=[foo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies
41 changes: 40 additions & 1 deletion sdk/core/azure-core/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
SansIOHTTPPolicy,
UserAgentPolicy,
RedirectPolicy,
HttpLoggingPolicy
HttpLoggingPolicy,
HTTPPolicy,
SansIOHTTPPolicy
)
from azure.core.pipeline.transport._base import PipelineClientBase
from azure.core.pipeline.transport import (
Expand Down Expand Up @@ -332,6 +334,43 @@ def test_repr(self):
request = HttpRequest("GET", "hello.com")
assert repr(request) == "<HttpRequest [GET], url: 'hello.com'>"

def test_add_custom_policy(self):
class BooPolicy(HTTPPolicy):
def send(*args):
raise AzureError('boo')

class FooPolicy(HTTPPolicy):
def send(*args):
raise AzureError('boo')

boo_policy = BooPolicy()
foo_policy = FooPolicy()
client = PipelineClient(base_url="test", per_call_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = PipelineClient(base_url="test", per_call_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = PipelineClient(base_url="test", per_retry_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
client = PipelineClient(base_url="test", per_retry_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = PipelineClient(base_url="test", per_call_policies=boo_policy, per_retry_policies=foo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies

client = PipelineClient(base_url="test", per_call_policies=[boo_policy],
per_retry_policies=[foo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies


if __name__ == "__main__":
unittest.main()