diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py index 2e71532e72be..289e82ea70a0 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py @@ -23,12 +23,8 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +from collections.abc import Iterable from azure.core import AsyncPipelineClient -from azure.core.pipeline.policies import ( - ContentDecodePolicy, - DistributedTracingPolicy, - RequestIdPolicy, -) from .policies import AsyncARMAutoResourceProviderRegistrationPolicy, ARMHttpLoggingPolicy @@ -37,8 +33,14 @@ class AsyncARMPipelineClient(AsyncPipelineClient): :param str base_url: URL for the request. :keyword AsyncPipeline pipeline: If omitted, a Pipeline object is created and returned. - :keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. - :keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport. + :keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. + :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy + :paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, + list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] + :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy + :paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, + list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] + :keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for asynchronous transport. """ def __init__(self, base_url, **kwargs): @@ -47,23 +49,15 @@ def __init__(self, base_url, **kwargs): raise ValueError( "Current implementation requires to pass 'config' if you don't pass 'policies'" ) - kwargs["policies"] = self._default_policies(**kwargs) + per_call_policies = kwargs.get('per_call_policies', []) + if isinstance(per_call_policies, Iterable): + per_call_policies.append(AsyncARMAutoResourceProviderRegistrationPolicy()) + else: + per_call_policies = [per_call_policies, + AsyncARMAutoResourceProviderRegistrationPolicy()] + kwargs["per_call_policies"] = per_call_policies + config = kwargs.get('config') + if not config.http_logging_policy: + config.http_logging_policy = kwargs.get('http_logging_policy', ARMHttpLoggingPolicy(**kwargs)) + kwargs["config"] = config super(AsyncARMPipelineClient, self).__init__(base_url, **kwargs) - - @staticmethod - def _default_policies(config, **kwargs): - return [ - RequestIdPolicy(**kwargs), - AsyncARMAutoResourceProviderRegistrationPolicy(), - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(**kwargs), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.custom_hook_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - config.http_logging_policy or ARMHttpLoggingPolicy(**kwargs), - ] diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py index 0280e2408ac1..b6570a2d2d02 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py @@ -23,12 +23,11 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable from azure.core import PipelineClient -from azure.core.pipeline.policies import ( - ContentDecodePolicy, - DistributedTracingPolicy, - RequestIdPolicy, -) from .policies import ARMAutoResourceProviderRegistrationPolicy, ARMHttpLoggingPolicy @@ -38,6 +37,10 @@ class ARMPipelineClient(PipelineClient): :param str base_url: URL for the request. :keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned. :keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. + :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy + :paramtype per_call_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]] + :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy + :paramtype per_retry_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]] :keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport. """ @@ -47,23 +50,15 @@ def __init__(self, base_url, **kwargs): raise ValueError( "Current implementation requires to pass 'config' if you don't pass 'policies'" ) - kwargs["policies"] = self._default_policies(**kwargs) + per_call_policies = kwargs.get('per_call_policies', []) + if isinstance(per_call_policies, Iterable): + per_call_policies.append(ARMAutoResourceProviderRegistrationPolicy()) + else: + per_call_policies = [per_call_policies, + ARMAutoResourceProviderRegistrationPolicy()] + kwargs["per_call_policies"] = per_call_policies + config = kwargs.get('config') + if not config.http_logging_policy: + config.http_logging_policy = kwargs.get('http_logging_policy', ARMHttpLoggingPolicy(**kwargs)) + kwargs["config"] = config super(ARMPipelineClient, self).__init__(base_url, **kwargs) - - @staticmethod - def _default_policies(config, **kwargs): - return [ - RequestIdPolicy(**kwargs), - ARMAutoResourceProviderRegistrationPolicy(), - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(**kwargs), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.custom_hook_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - config.http_logging_policy or ARMHttpLoggingPolicy(**kwargs), - ] diff --git a/sdk/core/azure-mgmt-core/setup.py b/sdk/core/azure-mgmt-core/setup.py index d24cdea8c8a3..7d55cdfdc70d 100644 --- a/sdk/core/azure-mgmt-core/setup.py +++ b/sdk/core/azure-mgmt-core/setup.py @@ -68,7 +68,7 @@ 'pytyped': ['py.typed'], }, install_requires=[ - "azure-core<2.0.0,>=1.9.0", + "azure-core<2.0.0,>=1.13.0", ], extras_require={ ":python_version<'3.0'": ['azure-mgmt-nspkg'], diff --git a/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py b/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py index f34c92cd514d..d922594dbf1a 100644 --- a/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py +++ b/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py @@ -30,7 +30,7 @@ def test_default_http_logging_policy(): config = Configuration() pipeline_client = AsyncARMPipelineClient(base_url="test", config=config) - http_logging_policy = pipeline_client._default_policies(config=config)[-1] + http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST def test_pass_in_http_logging_policy(): @@ -42,5 +42,5 @@ def test_pass_in_http_logging_policy(): config.http_logging_policy = http_logging_policy pipeline_client = AsyncARMPipelineClient(base_url="test", config=config) - http_logging_policy = pipeline_client._default_policies(config=config)[-1] + http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) \ No newline at end of file diff --git a/sdk/core/azure-mgmt-core/tests/test_policies.py b/sdk/core/azure-mgmt-core/tests/test_policies.py index 19bc2fbb93cd..7153e8e50486 100644 --- a/sdk/core/azure-mgmt-core/tests/test_policies.py +++ b/sdk/core/azure-mgmt-core/tests/test_policies.py @@ -171,7 +171,7 @@ def test_register_failed_policy(): def test_default_http_logging_policy(): config = Configuration() pipeline_client = ARMPipelineClient(base_url="test", config=config) - http_logging_policy = pipeline_client._default_policies(config=config)[-1] + http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST def test_pass_in_http_logging_policy(): @@ -183,5 +183,5 @@ def test_pass_in_http_logging_policy(): config.http_logging_policy = http_logging_policy pipeline_client = ARMPipelineClient(base_url="test", config=config) - http_logging_policy = pipeline_client._default_policies(config=config)[-1] + http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) \ No newline at end of file diff --git a/shared_requirements.txt b/shared_requirements.txt index 1d205e9c3f63..cc5eedb6ba12 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -122,7 +122,7 @@ six>=1.11.0 isodate>=0.6.0 avro<2.0.0,>=1.10.0 #override azure azure-keyvault~=1.0 -#override azure-mgmt-core azure-core<2.0.0,>=1.9.0 +#override azure-mgmt-core azure-core<2.0.0,>=1.13.0 #override azure-containerregistry azure-core>=1.4.0,<2.0.0 #override azure-core-tracing-opencensus azure-core<2.0.0,>=1.0.0 #override azure-core-tracing-opentelemetry azure-core<2.0.0,>=1.13.0