Skip to content

Commit 962391c

Browse files
authored
[Cosmos] Reconfigure retry policy (#7544)
* Reconfigure retry policy * Review feedback * Fix pylint * Updated tests * Support client-side timeout * Updated timeout logic * Renamed client error * Updated tests * Patch azure-core Needed pending PR 7542 * Fixed status retry tests * Using dev core
1 parent 680aaf7 commit 962391c

File tree

11 files changed

+273
-47
lines changed

11 files changed

+273
-47
lines changed

sdk/core/azure-core/azure/core/pipeline/transport/requests_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def send(self, request, **kwargs): # type: ignore
246246
allow_redirects=False,
247247
**kwargs)
248248

249-
except urllib3.exceptions.NewConnectionError as err:
249+
except (urllib3.exceptions.NewConnectionError, urllib3.exceptions.ConnectTimeoutError) as err:
250250
error = ServiceRequestError(err, error=err)
251251
except requests.exceptions.ReadTimeout as err:
252252
error = ServiceResponseError(err, error=err)

sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2020
# SOFTWARE.
2121

22+
from ._retry_utility import ConnectionRetryPolicy
2223
from .container import ContainerProxy
2324
from .cosmos_client import CosmosClient
2425
from .database import DatabaseProxy
@@ -56,5 +57,6 @@
5657
"SSLConfiguration",
5758
"TriggerOperation",
5859
"TriggerType",
60+
"ConnectionRetryPolicy",
5961
)
6062
__version__ = VERSION

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@
2626
"""
2727
from typing import Dict, Any, Optional
2828
import six
29-
import requests
30-
from requests.adapters import HTTPAdapter
29+
from requests.packages.urllib3.util.retry import Retry # pylint: disable=import-error
3130
from azure.core.paging import ItemPaged # type: ignore
3231
from azure.core import PipelineClient # type: ignore
33-
from azure.core.pipeline.transport import RequestsTransport
3432
from azure.core.pipeline.policies import ( # type: ignore
33+
HTTPPolicy,
3534
ContentDecodePolicy,
3635
HeadersPolicy,
3736
UserAgentPolicy,
@@ -51,6 +50,7 @@
5150
from . import _synchronized_request as synchronized_request
5251
from . import _global_endpoint_manager as global_endpoint_manager
5352
from ._routing import routing_map_provider
53+
from ._retry_utility import ConnectionRetryPolicy
5454
from . import _session
5555
from . import _utils
5656
from .partition_key import _Undefined, _Empty
@@ -151,15 +151,24 @@ def __init__(
151151
self._useMultipleWriteLocations = False
152152
self._global_endpoint_manager = global_endpoint_manager._GlobalEndpointManager(self)
153153

154-
# creating a requests session used for connection pooling and re-used by all requests
155-
requests_session = requests.Session()
156-
157-
transport = None
158-
if self.connection_policy.ConnectionRetryConfiguration is not None:
159-
adapter = HTTPAdapter(max_retries=self.connection_policy.ConnectionRetryConfiguration)
160-
requests_session.mount('http://', adapter)
161-
requests_session.mount('https://', adapter)
162-
transport = RequestsTransport(session=requests_session)
154+
retry_policy = None
155+
if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy):
156+
retry_policy = self.connection_policy.ConnectionRetryConfiguration
157+
elif isinstance(self.connection_policy.ConnectionRetryConfiguration, int):
158+
retry_policy = ConnectionRetryPolicy(total=self.connection_policy.ConnectionRetryConfiguration)
159+
elif isinstance(self.connection_policy.ConnectionRetryConfiguration, Retry):
160+
# Convert a urllib3 retry policy to a Pipeline policy
161+
retry_policy = ConnectionRetryPolicy(
162+
retry_total=self.connection_policy.ConnectionRetryConfiguration.total,
163+
retry_connect=self.connection_policy.ConnectionRetryConfiguration.connect,
164+
retry_read=self.connection_policy.ConnectionRetryConfiguration.read,
165+
retry_status=self.connection_policy.ConnectionRetryConfiguration.status,
166+
retry_backoff_max=self.connection_policy.ConnectionRetryConfiguration.BACKOFF_MAX,
167+
retry_on_status_codes=list(self.connection_policy.ConnectionRetryConfiguration.status_forcelist),
168+
retry_backoff_factor=self.connection_policy.ConnectionRetryConfiguration.backoff_factor
169+
)
170+
else:
171+
TypeError("Unsupported retry policy. Must be an azure.cosmos.ConnectionRetryPolicy, int, or urllib3.Retry")
163172

164173
proxies = kwargs.pop('proxies', {})
165174
if self.connection_policy.ProxyConfiguration and self.connection_policy.ProxyConfiguration.Host:
@@ -173,11 +182,13 @@ def __init__(
173182
ProxyPolicy(proxies=proxies),
174183
UserAgentPolicy(base_user_agent=_utils.get_user_agent(), **kwargs),
175184
ContentDecodePolicy(),
185+
retry_policy,
176186
CustomHookPolicy(**kwargs),
177187
DistributedTracingPolicy(),
178188
NetworkTraceLoggingPolicy(**kwargs),
179189
]
180190

191+
transport = kwargs.pop("transport", None)
181192
self.pipeline_client = PipelineClient(url_connection, "empty-config", transport=transport, policies=policies)
182193

183194
# Query compatibility mode.
@@ -188,7 +199,7 @@ def __init__(
188199
# Routing map provider
189200
self._routing_map_provider = routing_map_provider.SmartRoutingMapProvider(self)
190201

191-
database_account = self._global_endpoint_manager._GetDatabaseAccount()
202+
database_account = self._global_endpoint_manager._GetDatabaseAccount(**kwargs)
192203
self._global_endpoint_manager.force_refresh(database_account)
193204

194205
@property

sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,17 @@ def force_refresh(self, database_account):
8787
self.refresh_needed = True
8888
self.refresh_endpoint_list(database_account)
8989

90-
def refresh_endpoint_list(self, database_account):
90+
def refresh_endpoint_list(self, database_account, **kwargs):
9191
with self.refresh_lock:
9292
# if refresh is not needed or refresh is already taking place, return
9393
if not self.refresh_needed:
9494
return
9595
try:
96-
self._refresh_endpoint_list_private(database_account)
96+
self._refresh_endpoint_list_private(database_account, **kwargs)
9797
except Exception as e:
9898
raise e
9999

100-
def _refresh_endpoint_list_private(self, database_account=None):
100+
def _refresh_endpoint_list_private(self, database_account=None, **kwargs):
101101
if database_account:
102102
self.location_cache.perform_on_database_account_read(database_account)
103103
self.refresh_needed = False
@@ -107,18 +107,18 @@ def _refresh_endpoint_list_private(self, database_account=None):
107107
and self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms
108108
):
109109
if not database_account:
110-
database_account = self._GetDatabaseAccount()
110+
database_account = self._GetDatabaseAccount(**kwargs)
111111
self.location_cache.perform_on_database_account_read(database_account)
112112
self.last_refresh_time = self.location_cache.current_time_millis()
113113
self.refresh_needed = False
114114

115-
def _GetDatabaseAccount(self):
115+
def _GetDatabaseAccount(self, **kwargs):
116116
"""Gets the database account first by using the default endpoint, and if that doesn't returns
117117
use the endpoints for the preferred locations in the order they are specified to get
118118
the database account.
119119
"""
120120
try:
121-
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint)
121+
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint, **kwargs)
122122
return database_account
123123
# If for any reason(non-globaldb related), we are not able to get the database
124124
# account from the above call to GetDatabaseAccount, we would try to get this
@@ -130,18 +130,18 @@ def _GetDatabaseAccount(self):
130130
for location_name in self.PreferredLocations:
131131
locational_endpoint = _GlobalEndpointManager.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
132132
try:
133-
database_account = self._GetDatabaseAccountStub(locational_endpoint)
133+
database_account = self._GetDatabaseAccountStub(locational_endpoint, **kwargs)
134134
return database_account
135135
except errors.CosmosHttpResponseError:
136136
pass
137137

138138
return None
139139

140-
def _GetDatabaseAccountStub(self, endpoint):
140+
def _GetDatabaseAccountStub(self, endpoint, **kwargs):
141141
"""Stub for getting database account from the client
142142
which can be used for mocking purposes as well.
143143
"""
144-
return self.Client.GetDatabaseAccount(endpoint)
144+
return self.Client.GetDatabaseAccount(endpoint, **kwargs)
145145

146146
@staticmethod
147147
def GetLocationalEndpoint(default_endpoint, location_name):

sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424

2525
import time
2626

27+
from azure.core.exceptions import AzureError, ClientAuthenticationError
28+
from azure.core.pipeline.policies import RetryPolicy
29+
2730
from . import errors
2831
from . import _endpoint_discovery_retry_policy
2932
from . import _resource_throttle_retry_policy
@@ -64,6 +67,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
6467
)
6568
while True:
6669
try:
70+
client_timeout = kwargs.get('timeout')
71+
start_time = time.time()
6772
if args:
6873
result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs)
6974
else:
@@ -113,9 +118,92 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
113118

114119
# Wait for retry_after_in_milliseconds time before the next retry
115120
time.sleep(retry_policy.retry_after_in_milliseconds / 1000.0)
121+
if client_timeout:
122+
kwargs['timeout'] = client_timeout - (time.time() - start_time)
123+
if kwargs['timeout'] <= 0:
124+
raise errors.CosmosClientTimeoutError()
116125

117126

118127
def ExecuteFunction(function, *args, **kwargs):
119128
""" Stub method so that it can be used for mocking purposes as well.
120129
"""
121130
return function(*args, **kwargs)
131+
132+
133+
def _configure_timeout(request, absolute, per_request):
134+
# type: (azure.core.pipeline.PipelineRequest, Optional[int], int) -> Optional[AzureError]
135+
if absolute is not None:
136+
if absolute <= 0:
137+
raise errors.CosmosClientTimeoutError()
138+
if per_request:
139+
# Both socket timeout and client timeout have been provided - use the shortest value.
140+
request.context.options['connection_timeout'] = min(per_request, absolute)
141+
else:
142+
# Only client timeout provided.
143+
request.context.options['connection_timeout'] = absolute
144+
elif per_request:
145+
# Only socket timeout provided.
146+
request.context.options['connection_timeout'] = per_request
147+
148+
149+
class ConnectionRetryPolicy(RetryPolicy):
150+
151+
def __init__(self, **kwargs):
152+
clean_kwargs = {k: v for k, v in kwargs.items() if v is not None}
153+
super(ConnectionRetryPolicy, self).__init__(**clean_kwargs)
154+
155+
def send(self, request):
156+
"""Sends the PipelineRequest object to the next policy. Uses retry settings if necessary.
157+
Also enforces an absolute client-side timeout that spans multiple retry attempts.
158+
159+
:param request: The PipelineRequest object
160+
:type request: ~azure.core.pipeline.PipelineRequest
161+
:return: Returns the PipelineResponse or raises error if maximum retries exceeded.
162+
:rtype: ~azure.core.pipeline.PipelineResponse
163+
:raises: ~azure.core.exceptions.AzureError if maximum retries exceeded.
164+
:raises: ~azure.cosmos.CosmosClientTimeoutError if specified timeout exceeded.
165+
:raises: ~azure.core.exceptions.ClientAuthenticationError if authentication
166+
"""
167+
absolute_timeout = request.context.options.pop('timeout', None)
168+
per_request_timeout = request.context.options.pop('connection_timeout', 0)
169+
170+
retry_error = None
171+
retry_active = True
172+
response = None
173+
retry_settings = self.configure_retries(request.context.options)
174+
while retry_active:
175+
try:
176+
start_time = time.time()
177+
_configure_timeout(request, absolute_timeout, per_request_timeout)
178+
179+
response = self.next.send(request)
180+
if self.is_retry(retry_settings, response):
181+
retry_active = self.increment(retry_settings, response=response)
182+
if retry_active:
183+
self.sleep(retry_settings, request.context.transport, response=response)
184+
continue
185+
break
186+
except ClientAuthenticationError: # pylint:disable=try-except-raise
187+
# the authentication policy failed such that the client's request can't
188+
# succeed--we'll never have a response to it, so propagate the exception
189+
raise
190+
except errors.CosmosClientTimeoutError as timeout_error:
191+
timeout_error.inner_exception = retry_error
192+
timeout_error.response = response
193+
timeout_error.history = retry_settings['history']
194+
raise
195+
except AzureError as err:
196+
retry_error = err
197+
if self._is_method_retryable(retry_settings, request.http_request):
198+
retry_active = self.increment(retry_settings, response=request, error=err)
199+
if retry_active:
200+
self.sleep(retry_settings, request.context.transport)
201+
continue
202+
raise err
203+
finally:
204+
end_time = time.time()
205+
if absolute_timeout:
206+
absolute_timeout -= (end_time - start_time)
207+
208+
self.update_context(response.context, retry_settings)
209+
return response

sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"""
2424

2525
import json
26+
import time
2627

2728
from six.moves.urllib.parse import urlparse
2829
import six
@@ -96,7 +97,13 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
9697
connection_timeout = kwargs.pop("connection_timeout", connection_timeout / 1000.0)
9798

9899
# Every request tries to perform a refresh
99-
global_endpoint_manager.refresh_endpoint_list(None)
100+
client_timeout = kwargs.get('timeout')
101+
start_time = time.time()
102+
global_endpoint_manager.refresh_endpoint_list(None, **kwargs)
103+
if client_timeout is not None:
104+
kwargs['timeout'] = client_timeout - (time.time() - start_time)
105+
if kwargs['timeout'] <= 0:
106+
raise errors.CosmosClientTimeoutError()
100107

101108
if request_params.endpoint_override:
102109
base_url = request_params.endpoint_override
@@ -149,7 +156,7 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
149156
return (response.stream_download(pipeline_client._pipeline), headers)
150157

151158
data = response.body()
152-
if not six.PY2:
159+
if data and not six.PY2:
153160
# python 3 compatible: convert data from byte to unicode string
154161
data = data.decode("utf-8")
155162

sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
"""Create, read, and delete databases in the Azure Cosmos DB SQL API service.
2323
"""
2424

25-
from typing import Any, Dict, Mapping, Optional, Union, cast, Iterable, List
25+
from typing import Any, Dict, Mapping, Optional, Union, cast, Iterable, List # pylint: disable=unused-import
2626

2727
import six
2828
from azure.core.tracing.decorator import distributed_trace # type: ignore
2929

3030
from ._cosmos_client_connection import CosmosClientConnection
3131
from ._base import build_options
32+
from ._retry_utility import ConnectionRetryPolicy
3233
from .database import DatabaseProxy
3334
from .documents import ConnectionPolicy, DatabaseAccount
3435
from .errors import CosmosResourceNotFoundError
@@ -96,11 +97,25 @@ def _build_connection_policy(kwargs):
9697

9798
# Retry config
9899
retry = kwargs.pop('retry_options', None) or policy.RetryOptions
99-
retry._max_retry_attempt_count = kwargs.pop('retry_total', None) or retry._max_retry_attempt_count
100+
total_retries = kwargs.pop('retry_total', None)
101+
retry._max_retry_attempt_count = total_retries or retry._max_retry_attempt_count
100102
retry._fixed_retry_interval_in_milliseconds = kwargs.pop('retry_fixed_interval', None) or \
101103
retry._fixed_retry_interval_in_milliseconds
102-
retry._max_wait_time_in_seconds = kwargs.pop('retry_backoff_max', None) or retry._max_wait_time_in_seconds
104+
max_backoff = kwargs.pop('retry_backoff_max', None)
105+
retry._max_wait_time_in_seconds = max_backoff or retry._max_wait_time_in_seconds
103106
policy.RetryOptions = retry
107+
connection_retry = kwargs.pop('connection_retry_policy', None) or policy.ConnectionRetryConfiguration
108+
if not connection_retry:
109+
connection_retry = ConnectionRetryPolicy(
110+
retry_total=total_retries,
111+
retry_connect=kwargs.pop('retry_connect', None),
112+
retry_read=kwargs.pop('retry_read', None),
113+
retry_status=kwargs.pop('retry_status', None),
114+
retry_backoff_max=max_backoff,
115+
retry_on_status_codes=kwargs.pop('retry_on_status_codes', []),
116+
retry_backoff_factor=kwargs.pop('retry_backoff_factor', 0.8),
117+
)
118+
policy.ConnectionRetryConfiguration = connection_retry
104119

105120
return policy
106121

@@ -130,6 +145,11 @@ class CosmosClient(object):
130145
*retry_total* - Maximum retry attempts.
131146
*retry_backoff_max* - Maximum retry wait time in seconds.
132147
*retry_fixed_interval* - Fixed retry interval in milliseconds.
148+
*retry_read* - Maximum number of socket read retry attempts.
149+
*retry_connect* - Maximum number of connection error retry attempts.
150+
*retry_status* - Maximum number of retry attempts on error status codes.
151+
*retry_on_status_codes* - A list of specific status codes to retry on.
152+
*retry_backoff_factor* - Factor to calculate wait time between retry attempts.
133153
*enable_endpoint_discovery* - Enable endpoint discovery for geo-replicated database accounts. Default is True.
134154
*preferred_locations* - The preferred locations for geo-replicated database accounts.
135155
When `enable_endpoint_discovery` is true and `preferred_locations` is non-empty,

sdk/cosmos/azure-cosmos/azure/cosmos/documents.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,10 @@ class ConnectionPolicy(object): # pylint: disable=too-many-instance-attributes
372372
:ivar boolean UseMultipleWriteLocations:
373373
Flag to enable writes on any locations (regions) for geo-replicated database accounts
374374
in the azure Cosmos service.
375-
:ivar (int or requests.packages.urllib3.util.retry) ConnectionRetryConfiguration:
376-
Retry Configuration to be used for urllib3 connection retries.
375+
:ivar ConnectionRetryConfiguration:
376+
Retry Configuration to be used for connection retries.
377+
:vartype ConnectionRetryConfiguration:
378+
int or azure.cosmos.ConnectionRetryPolicy or requests.packages.urllib3.util.retry
377379
"""
378380

379381
__defaultRequestTimeout = 60000 # milliseconds

sdk/cosmos/azure-cosmos/azure/cosmos/errors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,13 @@ class CosmosResourceExistsError(ResourceExistsError, CosmosHttpResponseError):
6363

6464
class CosmosAccessConditionFailedError(CosmosHttpResponseError):
6565
"""An error response with status code 412."""
66+
67+
68+
class CosmosClientTimeoutError(AzureError):
69+
"""An operation failed to complete within the specified timeout."""
70+
71+
def __init__(self, **kwargs):
72+
message = "Client operation failed to complete within specified timeout."
73+
self.response = None
74+
self.history = None
75+
super(CosmosClientTimeoutError, self).__init__(message, **kwargs)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
-e ../../../tools/azure-sdk-tools
2+
-e ../../core/azure-core

0 commit comments

Comments
 (0)