Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support client-side timeout
  • Loading branch information
annatisch committed Oct 2, 2019
commit e8be3bf146f67fbf67057cbd6b38ed6e40b80399
2 changes: 2 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from ._retry_utility import ConnectionRetryPolicy
from .container import ContainerProxy
from .cosmos_client import CosmosClient
from .database import DatabaseProxy
Expand Down Expand Up @@ -56,5 +57,6 @@
"SSLConfiguration",
"TriggerOperation",
"TriggerType",
"ConnectionRetryPolicy",
)
__version__ = VERSION
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
UserAgentPolicy,
NetworkTraceLoggingPolicy,
CustomHookPolicy,
RetryPolicy,
ProxyPolicy)
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy # type: ignore

Expand All @@ -51,6 +50,7 @@
from . import _synchronized_request as synchronized_request
from . import _global_endpoint_manager as global_endpoint_manager
from ._routing import routing_map_provider
from ._retry_utility import ConnectionRetryPolicy
from . import _session
from . import _utils
from .partition_key import _Undefined, _Empty
Expand Down Expand Up @@ -155,10 +155,10 @@ def __init__(
if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy):
retry_policy = self.connection_policy.ConnectionRetryConfiguration
elif isinstance(self.connection_policy.ConnectionRetryConfiguration, int):
retry_policy = RetryPolicy(total=self.connection_policy.ConnectionRetryConfiguration)
retry_policy = ConnectionRetryPolicy(total=self.connection_policy.ConnectionRetryConfiguration)
elif isinstance(self.connection_policy.ConnectionRetryConfiguration, Retry):
# Convert a urllib3 retry policy to a Pipeline policy
retry_policy = RetryPolicy(
retry_policy = ConnectionRetryPolicy(
retry_total=self.connection_policy.ConnectionRetryConfiguration.total,
retry_connect=self.connection_policy.ConnectionRetryConfiguration.connect,
retry_read=self.connection_policy.ConnectionRetryConfiguration.read,
Expand All @@ -168,7 +168,7 @@ def __init__(
retry_backoff_factor=self.connection_policy.ConnectionRetryConfiguration.backoff_factor
)
else:
TypeError("Unsupported retry policy. Must be an azure.core.RetryPolicy, integer, or urllib3.Retry")
TypeError("Unsupported retry policy. Must be an azure.cosmos.ConnectionRetryPolicy, int, or urllib3.Retry")

proxies = kwargs.pop('proxies', {})
if self.connection_policy.ProxyConfiguration and self.connection_policy.ProxyConfiguration.Host:
Expand Down Expand Up @@ -199,7 +199,7 @@ def __init__(
# Routing map provider
self._routing_map_provider = routing_map_provider.SmartRoutingMapProvider(self)

database_account = self._global_endpoint_manager._GetDatabaseAccount()
database_account = self._global_endpoint_manager._GetDatabaseAccount(**kwargs)
self._global_endpoint_manager.force_refresh(database_account)

@property
Expand Down
18 changes: 9 additions & 9 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ def force_refresh(self, database_account):
self.refresh_needed = True
self.refresh_endpoint_list(database_account)

def refresh_endpoint_list(self, database_account):
def refresh_endpoint_list(self, database_account, **kwargs):
with self.refresh_lock:
# if refresh is not needed or refresh is already taking place, return
if not self.refresh_needed:
return
try:
self._refresh_endpoint_list_private(database_account)
self._refresh_endpoint_list_private(database_account, **kwargs)
except Exception as e:
raise e

def _refresh_endpoint_list_private(self, database_account=None):
def _refresh_endpoint_list_private(self, database_account=None, **kwargs):
if database_account:
self.location_cache.perform_on_database_account_read(database_account)
self.refresh_needed = False
Expand All @@ -107,18 +107,18 @@ def _refresh_endpoint_list_private(self, database_account=None):
and self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms
):
if not database_account:
database_account = self._GetDatabaseAccount()
database_account = self._GetDatabaseAccount(**kwargs)
self.location_cache.perform_on_database_account_read(database_account)
self.last_refresh_time = self.location_cache.current_time_millis()
self.refresh_needed = False

def _GetDatabaseAccount(self):
def _GetDatabaseAccount(self, **kwargs):
"""Gets the database account first by using the default endpoint, and if that doesn't returns
use the endpoints for the preferred locations in the order they are specified to get
the database account.
"""
try:
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint)
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint, **kwargs)
return database_account
# If for any reason(non-globaldb related), we are not able to get the database
# account from the above call to GetDatabaseAccount, we would try to get this
Expand All @@ -130,18 +130,18 @@ def _GetDatabaseAccount(self):
for location_name in self.PreferredLocations:
locational_endpoint = _GlobalEndpointManager.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
try:
database_account = self._GetDatabaseAccountStub(locational_endpoint)
database_account = self._GetDatabaseAccountStub(locational_endpoint, **kwargs)
return database_account
except errors.CosmosHttpResponseError:
pass

return None

def _GetDatabaseAccountStub(self, endpoint):
def _GetDatabaseAccountStub(self, endpoint, **kwargs):
"""Stub for getting database account from the client
which can be used for mocking purposes as well.
"""
return self.Client.GetDatabaseAccount(endpoint)
return self.Client.GetDatabaseAccount(endpoint, **kwargs)

@staticmethod
def GetLocationalEndpoint(default_endpoint, location_name):
Expand Down
78 changes: 78 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

import time

from azure.core.exceptions import AzureError, ClientAuthenticationError
from azure.core.pipeline.policies import RetryPolicy

from . import errors
from . import _endpoint_discovery_retry_policy
from . import _resource_throttle_retry_policy
Expand Down Expand Up @@ -119,3 +122,78 @@ def ExecuteFunction(function, *args, **kwargs):
""" Stub method so that it can be used for mocking purposes as well.
"""
return function(*args, **kwargs)


def _configure_timeout(request, absolute, per_request):
# type: (azure.core.pipeline.PipelineRequest, Optional[int], int) -> Optional[AzureError]
if absolute is not None and absolute < per_request:
if absolute <= 0:
raise errors.ClientTimeoutError()
request.context.options['connection_timeout'] = absolute
elif per_request:
request.context.options['connection_timeout'] = per_request


class ConnectionRetryPolicy(RetryPolicy):

def __init__(self, **kwargs):
clean_kwargs = {k: v for k, v in kwargs.items() if v is not None}
super(ConnectionRetryPolicy, self).__init__(**clean_kwargs)



def send(self, request):
"""Sends the PipelineRequest object to the next policy. Uses retry settings if necessary.
Also enforces an absolute client-side timeout that spans multiple retry attempts.

:param request: The PipelineRequest object
:type request: ~azure.core.pipeline.PipelineRequest
:return: Returns the PipelineResponse or raises error if maximum retries exceeded.
:rtype: ~azure.core.pipeline.PipelineResponse
:raises: ~azure.core.exceptions.AzureError if maximum retries exceeded.
:raises: ~azure.cosmos.ClientTimeoutError if specified timeout exceeded.
:raises: ~azure.core.exceptions.ClientAuthenticationError if authentication
"""
absolute_timeout = request.context.options.pop('timeout', None)
per_request_timeout = request.context.options.pop('connection_timeout', 0)

retry_error = None
retry_active = True
response = None
retry_settings = self.configure_retries(request.context.options)
while retry_active:
try:
start_time = time.time()
_configure_timeout(request, absolute_timeout, per_request_timeout)

response = self.next.send(request)
if self.is_retry(retry_settings, response):
retry_active = self.increment(retry_settings, response=response)
if retry_active:
self.sleep(retry_settings, request.context.transport, response=response)
continue
break
except ClientAuthenticationError: # pylint:disable=try-except-raise
# the authentication policy failed such that the client's request can't
# succeed--we'll never have a response to it, so propagate the exception
raise
except errors.ClientTimeoutError as timeout_error:
timeout_error.inner_exception = retry_error
timeout_error.response = response
timeout_error.history = retry_settings['history']
raise
except AzureError as err:
retry_error = err
if self._is_method_retryable(retry_settings, request.http_request):
retry_active = self.increment(retry_settings, response=request, error=err)
if retry_active:
self.sleep(retry_settings, request.context.transport)
continue
raise err
finally:
end_time = time.time()
if absolute_timeout:
absolute_timeout -= (end_time - start_time)

self.update_context(response.context, retry_settings)
return response
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""

import json
import time

from six.moves.urllib.parse import urlparse
import six
Expand Down Expand Up @@ -96,7 +97,13 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
connection_timeout = kwargs.pop("connection_timeout", connection_timeout / 1000.0)

# Every request tries to perform a refresh
global_endpoint_manager.refresh_endpoint_list(None)
client_timeout = kwargs.get('timeout')
start_time = time.time()
global_endpoint_manager.refresh_endpoint_list(None, **kwargs)
if client_timeout is not None:
kwargs['timeout'] = client_timeout - (time.time() - start_time)
if kwargs['timeout'] <= 0:
raise errors.ClientTimeoutError()

if request_params.endpoint_override:
base_url = request_params.endpoint_override
Expand Down
4 changes: 2 additions & 2 deletions sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@

import six
from azure.core.tracing.decorator import distributed_trace # type: ignore
from azure.core.pipeline.policies import RetryPolicy

from ._cosmos_client_connection import CosmosClientConnection
from ._base import build_options
from ._retry_utility import ConnectionRetryPolicy
from .database import DatabaseProxy
from .documents import ConnectionPolicy, DatabaseAccount
from .errors import CosmosResourceNotFoundError
Expand Down Expand Up @@ -106,7 +106,7 @@ def _build_connection_policy(kwargs):
policy.RetryOptions = retry
connection_retry = kwargs.pop('connection_retry_policy', None) or policy.ConnectionRetryConfiguration
if not connection_retry:
connection_retry = RetryPolicy(
connection_retry = ConnectionRetryPolicy(
retry_total=total_retries,
retry_connect=kwargs.pop('retry_connect', None),
retry_read=kwargs.pop('retry_read', None),
Expand Down
2 changes: 1 addition & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ class ConnectionPolicy(object): # pylint: disable=too-many-instance-attributes
:ivar ConnectionRetryConfiguration:
Retry Configuration to be used for connection retries.
:vartype ConnectionRetryConfiguration:
int or requests.packages.urllib3.util.retry or azure.core.pipeline.policies.HTTPPolicy
int or azure.cosmos.ConnectionRetryPolicy or requests.packages.urllib3.util.retry
"""

__defaultRequestTimeout = 60000 # milliseconds
Expand Down
10 changes: 10 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,13 @@ class CosmosResourceExistsError(ResourceExistsError, CosmosHttpResponseError):

class CosmosAccessConditionFailedError(CosmosHttpResponseError):
"""An error response with status code 412."""


class ClientTimeoutError(AzureError):
"""An operation failed to complete within the specified timeout."""

def __init__(self, **kwargs):
message = "Client operation failed to complete within specified timeout."
self.response = None
self.history = None
super(ClientTimeoutError, self).__init__(message, **kwargs)
9 changes: 9 additions & 0 deletions sdk/cosmos/azure-cosmos/test/crud_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2043,6 +2043,15 @@ def initialize_client_with_connection_core_retry_config(self, retries):
end_time = time.time()
return end_time - start_time

def test_absolute_client_timeout(self):
with self.assertRaises(errors.ClientTimeoutError):
cosmos_client.CosmosClient(
"https://localhost:9999",
CRUDTests.masterKey,
"Session",
retry_total=3,
timeout=1)

def test_query_iterable_functionality(self):
def __create_resources(client):
"""Creates resources for this test.
Expand Down