diff --git a/.gitignore b/.gitignore index ce9cb766b2c1..75f5dc5dab54 100644 --- a/.gitignore +++ b/.gitignore @@ -90,3 +90,4 @@ sdk/storage/azure-storage-blob/tests/settings_real.py sdk/storage/azure-storage-queue/tests/settings_real.py sdk/storage/azure-storage-file/tests/settings_real.py *.code-workspace +sdk/cosmos/azure-cosmos/test/test_config.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py b/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py index 19e73780aa50..07f3ca79fb93 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py @@ -37,6 +37,7 @@ from .permission import Permission from .scripts import Scripts from .user import User +from .version import VERSION __all__ = ( "Container", @@ -56,3 +57,4 @@ "TriggerOperation", "TriggerType", ) +__version__ = VERSION \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 903e8882b2f8..23c5c4c432d7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -24,9 +24,20 @@ """Document client class for the Azure Cosmos database service. """ -import requests +import platform +import requests import six +from azure.core import PipelineClient +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + HeadersPolicy, + UserAgentPolicy, + NetworkTraceLoggingPolicy, + CustomHookPolicy, + ProxyPolicy) +from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy + from . import _base as base from . import documents from . import _constants as constants @@ -40,6 +51,7 @@ from . import _session from . import _utils from .partition_key import _Undefined, _Empty +from .version import VERSION # pylint: disable=protected-access @@ -132,15 +144,28 @@ def __init__( self._useMultipleWriteLocations = False self._global_endpoint_manager = global_endpoint_manager._GlobalEndpointManager(self) - # creating a requests session used for connection pooling and re-used by all requests - self._requests_session = requests.Session() - + proxies = {} if self.connection_policy.ProxyConfiguration and self.connection_policy.ProxyConfiguration.Host: host = connection_policy.ProxyConfiguration.Host url = six.moves.urllib.parse.urlparse(host) proxy = host if url.port else host + ":" + str(connection_policy.ProxyConfiguration.Port) - proxyDict = {url.scheme: proxy} - self._requests_session.proxies.update(proxyDict) + proxies = {url.scheme : proxy} + user_agent = "azsdk-python-cosmos/{} Python/{} ({})".format( + VERSION, + platform.python_version(), + platform.platform()) + + policies = [ + HeadersPolicy(), + ProxyPolicy(proxies=proxies), + UserAgentPolicy(base_user_agent=user_agent), + ContentDecodePolicy(), + CustomHookPolicy(), + DistributedTracingPolicy(), + NetworkTraceLoggingPolicy(), + ] + + self.pipeline_client = PipelineClient(url_connection, "empty-config", policies=policies) # Query compatibility mode. # Allows to specify compatibility mode used by client when making query requests. Should be removed when @@ -1782,7 +1807,7 @@ def fetch_fn(options): return query_iterable.QueryIterable(self, query, options, fetch_fn) - def ReadMedia(self, media_link): + def ReadMedia(self, media_link, **kwargs): """Reads a media. When self.connection_policy.MediaReadMode == @@ -1806,11 +1831,11 @@ def ReadMedia(self, media_link): headers = base.GetHeaders(self, default_headers, "get", path, attachment_id, "media", {}) # ReadMedia will always use WriteEndpoint since it's not replicated in readable Geo regions - request = _request_object.RequestObject("media", documents._OperationType.Read) - result, self.last_response_headers = self.__Get(path, request, headers) + request_params = _request_object.RequestObject("media", documents._OperationType.Read) + result, self.last_response_headers = self.__Get(path, request_params, headers, **kwargs) return result - def UpdateMedia(self, media_link, readable_stream, options=None): + def UpdateMedia(self, media_link, readable_stream, options=None, **kwargs): """Updates a media and returns it. :param str media_link: @@ -1845,8 +1870,8 @@ def UpdateMedia(self, media_link, readable_stream, options=None): headers = base.GetHeaders(self, initial_headers, "put", path, attachment_id, "media", options) # UpdateMedia will use WriteEndpoint since it uses PUT operation - request = _request_object.RequestObject("media", documents._OperationType.Update) - result, self.last_response_headers = self.__Put(path, request, readable_stream, headers) + request_params = _request_object.RequestObject("media", documents._OperationType.Update) + result, self.last_response_headers = self.__Put(path, request_params, readable_stream, headers, **kwargs) self._UpdateSessionIfRequired(headers, result, self.last_response_headers) return result @@ -1995,7 +2020,7 @@ def DeleteUserDefinedFunction(self, udf_link, options=None): udf_id = base.GetResourceIdOrFullNameFromLink(udf_link) return self.DeleteResource(path, "udfs", udf_id, None, options) - def ExecuteStoredProcedure(self, sproc_link, params, options=None): + def ExecuteStoredProcedure(self, sproc_link, params, options=None, **kwargs): """Executes a store procedure. :param str sproc_link: @@ -2025,8 +2050,8 @@ def ExecuteStoredProcedure(self, sproc_link, params, options=None): headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs", options) # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation - request = _request_object.RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) - result, self.last_response_headers = self.__Post(path, request, params, headers) + request_params = _request_object.RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) + result, self.last_response_headers = self.__Post(path, request_params, params, headers, **kwargs) return result def ReplaceStoredProcedure(self, sproc_link, sproc, options=None): @@ -2175,7 +2200,7 @@ def fetch_fn(options): return query_iterable.QueryIterable(self, query, options, fetch_fn) - def GetDatabaseAccount(self, url_connection=None): + def GetDatabaseAccount(self, url_connection=None, **kwargs): """Gets database account info. :return: @@ -2190,8 +2215,8 @@ def GetDatabaseAccount(self, url_connection=None): initial_headers = dict(self.default_headers) headers = base.GetHeaders(self, initial_headers, "get", "", "", "", {}) # path # id # type - request = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) - result, self.last_response_headers = self.__Get("", request, headers) + request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + result, self.last_response_headers = self.__Get("", request_params, headers, **kwargs) database_account = documents.DatabaseAccount() database_account.DatabasesLink = "/dbs/" database_account.MediaLink = "/media/" @@ -2220,7 +2245,7 @@ def GetDatabaseAccount(self, url_connection=None): ) return database_account - def Create(self, body, path, typ, id, initial_headers, options=None): # pylint: disable=redefined-builtin + def Create(self, body, path, typ, id, initial_headers, options=None, **kwargs): # pylint: disable=redefined-builtin """Creates a Azure Cosmos resource and returns it. :param dict body: @@ -2244,14 +2269,14 @@ def Create(self, body, path, typ, id, initial_headers, options=None): # pylint: headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options) # Create will use WriteEndpoint since it uses POST operation - request = _request_object.RequestObject(typ, documents._OperationType.Create) - result, self.last_response_headers = self.__Post(path, request, body, headers) + request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + result, self.last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) # update session for write request self._UpdateSessionIfRequired(headers, result, self.last_response_headers) return result - def Upsert(self, body, path, typ, id, initial_headers, options=None): # pylint: disable=redefined-builtin + def Upsert(self, body, path, typ, id, initial_headers, options=None, **kwargs): # pylint: disable=redefined-builtin """Upserts a Azure Cosmos resource and returns it. :param dict body: @@ -2277,13 +2302,13 @@ def Upsert(self, body, path, typ, id, initial_headers, options=None): # pylint: headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request = _request_object.RequestObject(typ, documents._OperationType.Upsert) - result, self.last_response_headers = self.__Post(path, request, body, headers) + request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + result, self.last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) # update session for write request self._UpdateSessionIfRequired(headers, result, self.last_response_headers) return result - def Replace(self, resource, path, typ, id, initial_headers, options=None): # pylint: disable=redefined-builtin + def Replace(self, resource, path, typ, id, initial_headers, options=None, **kwargs): # pylint: disable=redefined-builtin """Replaces a Azure Cosmos resource and returns it. :param dict resource: @@ -2306,14 +2331,14 @@ def Replace(self, resource, path, typ, id, initial_headers, options=None): # py initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, options) # Replace will use WriteEndpoint since it uses PUT operation - request = _request_object.RequestObject(typ, documents._OperationType.Replace) - result, self.last_response_headers = self.__Put(path, request, resource, headers) + request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + result, self.last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) # update session for request mutates data on server side self._UpdateSessionIfRequired(headers, result, self.last_response_headers) return result - def Read(self, path, typ, id, initial_headers, options=None): # pylint: disable=redefined-builtin + def Read(self, path, typ, id, initial_headers, options=None, **kwargs): # pylint: disable=redefined-builtin """Reads a Azure Cosmos resource and returns it. :param str path: @@ -2335,11 +2360,11 @@ def Read(self, path, typ, id, initial_headers, options=None): # pylint: disable initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, options) # Read will use ReadEndpoint since it uses GET operation - request = _request_object.RequestObject(typ, documents._OperationType.Read) - result, self.last_response_headers = self.__Get(path, request, headers) + request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + result, self.last_response_headers = self.__Get(path, request_params, headers, **kwargs) return result - def DeleteResource(self, path, typ, id, initial_headers, options=None): # pylint: disable=redefined-builtin + def DeleteResource(self, path, typ, id, initial_headers, options=None, **kwargs): # pylint: disable=redefined-builtin """Deletes a Azure Cosmos resource and returns it. :param str path: @@ -2361,15 +2386,15 @@ def DeleteResource(self, path, typ, id, initial_headers, options=None): # pylin initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, options) # Delete will use WriteEndpoint since it uses DELETE operation - request = _request_object.RequestObject(typ, documents._OperationType.Delete) - result, self.last_response_headers = self.__Delete(path, request, headers) + request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + result, self.last_response_headers = self.__Delete(path, request_params, headers, **kwargs) # update session for request mutates data on server side self._UpdateSessionIfRequired(headers, result, self.last_response_headers) return result - def __Get(self, path, request, headers): + def __Get(self, path, request_params, headers, **kwargs): """Azure Cosmos 'GET' http request. :params str url: @@ -2382,20 +2407,19 @@ def __Get(self, path, request, headers): tuple of (dict, dict) """ + request = self.pipeline_client.get(url=path, headers=headers) return synchronized_request.SynchronizedRequest( - self, - request, - self._global_endpoint_manager, - self.connection_policy, - self._requests_session, - "GET", - path, - None, - None, - headers, + client=self, + request_params=request_params, + global_endpoint_manager=self._global_endpoint_manager, + connection_policy=self.connection_policy, + pipeline_client=self.pipeline_client, + request=request, + request_data=None, + **kwargs ) - def __Post(self, path, request, body, headers): + def __Post(self, path, request_params, body, headers, **kwargs): """Azure Cosmos 'POST' http request. :params str url: @@ -2409,20 +2433,19 @@ def __Post(self, path, request, body, headers): tuple of (dict, dict) """ + request = self.pipeline_client.post(url=path, headers=headers) return synchronized_request.SynchronizedRequest( - self, - request, - self._global_endpoint_manager, - self.connection_policy, - self._requests_session, - "POST", - path, - body, - query_params=None, - headers=headers, + client=self, + request_params=request_params, + global_endpoint_manager=self._global_endpoint_manager, + connection_policy=self.connection_policy, + pipeline_client=self.pipeline_client, + request=request, + request_data=body, + **kwargs ) - def __Put(self, path, request, body, headers): + def __Put(self, path, request_params, body, headers, **kwargs): """Azure Cosmos 'PUT' http request. :params str url: @@ -2436,20 +2459,19 @@ def __Put(self, path, request, body, headers): tuple of (dict, dict) """ + request = self.pipeline_client.put(url=path, headers=headers) return synchronized_request.SynchronizedRequest( - self, - request, - self._global_endpoint_manager, - self.connection_policy, - self._requests_session, - "PUT", - path, - body, - query_params=None, - headers=headers, + client=self, + request_params=request_params, + global_endpoint_manager=self._global_endpoint_manager, + connection_policy=self.connection_policy, + pipeline_client=self.pipeline_client, + request=request, + request_data=body, + **kwargs ) - def __Delete(self, path, request, headers): + def __Delete(self, path, request_params, headers, **kwargs): """Azure Cosmos 'DELETE' http request. :params str url: @@ -2462,17 +2484,16 @@ def __Delete(self, path, request, headers): tuple of (dict, dict) """ + request = self.pipeline_client.delete(url=path, headers=headers) return synchronized_request.SynchronizedRequest( - self, - request, - self._global_endpoint_manager, - self.connection_policy, - self._requests_session, - "DELETE", - path, + client=self, + request_params=request_params, + global_endpoint_manager=self._global_endpoint_manager, + connection_policy=self.connection_policy, + pipeline_client=self.pipeline_client, + request=request, request_data=None, - query_params=None, - headers=headers, + **kwargs ) def QueryFeed(self, path, collection_id, query, options, partition_key_range_id=None): @@ -2506,7 +2527,17 @@ def QueryFeed(self, path, collection_id, query, options, partition_key_range_id= ) def __QueryFeed( - self, path, typ, id_, result_fn, create_fn, query, options=None, partition_key_range_id=None, response_hook=None + self, + path, + typ, + id_, + result_fn, + create_fn, + query, + options=None, + partition_key_range_id=None, + response_hook=None, + **kwargs ): """Query for more than one Azure Cosmos resources. @@ -2545,9 +2576,9 @@ def __GetBodiesFromQueryResult(result): # Copy to make sure that default_headers won't be changed. if query is None: # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - request = _request_object.RequestObject(typ, documents._OperationType.ReadFeed) + request_params = _request_object.RequestObject(typ, documents._OperationType.ReadFeed) headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, options, partition_key_range_id) - result, self.last_response_headers = self.__Get(path, request, headers) + result, self.last_response_headers = self.__Get(path, request_params, headers, **kwargs) if response_hook: response_hook(self.last_response_headers, result) return __GetBodiesFromQueryResult(result) @@ -2566,9 +2597,9 @@ def __GetBodiesFromQueryResult(result): raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) + request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, options, partition_key_range_id) - result, self.last_response_headers = self.__Post(path, request, query, headers) + result, self.last_response_headers = self.__Post(path, request_params, query, headers, **kwargs) if response_hook: response_hook(self.last_response_headers, result) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py index a1eb0f51eab0..2e07955ab0ea 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py @@ -57,7 +57,7 @@ def __init__(self, *args): def needsRetry(self, error_code): if error_code in DefaultRetryPolicy.CONNECTION_ERROR_CODES: if self.args: - if (self.args[4]["method"] == "GET") or (http_constants.HttpHeaders.IsQuery in self.args[4]["headers"]): + if (self.args[3].method == "GET") or (http_constants.HttpHeaders.IsQuery in self.args[3].headers): return True return False return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 09970fe93bc5..618541feb6b0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -33,19 +33,17 @@ from . import _retry_utility -def _IsReadableStream(obj): +def _is_readable_stream(obj): """Checks whether obj is a file-like readable stream. - :rtype: - boolean - + :rtype: boolean """ if hasattr(obj, "read") and callable(getattr(obj, "read")): return True return False -def _RequestBodyFromData(data): +def _request_body_from_data(data): """Gets request body from data. When `data` is dict and list into unicode string; otherwise return `data` @@ -57,7 +55,7 @@ def _RequestBodyFromData(data): str, unicode, file-like stream object, or None """ - if isinstance(data, six.string_types) or _IsReadableStream(data): + if data is None or isinstance(data, six.string_types) or _is_readable_stream(data): return data if isinstance(data, (dict, list, tuple)): @@ -66,27 +64,21 @@ def _RequestBodyFromData(data): if six.PY2: return json_dumped.decode("utf-8") return json_dumped - return None -def _Request( - global_endpoint_manager, request, connection_policy, requests_session, path, request_options, request_body -): +def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): """Makes one http request using the requests module. :param _GlobalEndpointManager global_endpoint_manager: - :param dict request: + :param dict request_params: contains the resourceType, operationType, endpointOverride, useWriteEndpoint, useAlternateWriteEndpoint information :param documents.ConnectionPolicy connection_policy: - :param requests.Session requests_session: - Session object in requests module - :param str resource_url: - The url for the resource - :param dict request_options: - :param str request_body: - Unicode or None + :param azure.core.PipelineClient pipeline_client: + Pipeline client to process the resquest + :param azure.core.HttpRequest request: + The request object to send through the pipeline :return: tuple of (result, headers) @@ -94,29 +86,27 @@ def _Request( tuple of (dict, dict) """ - is_media = request_options["path"].find("media") > -1 + is_media = request.url.find("media") > -1 is_media_stream = is_media and connection_policy.MediaReadMode == documents.MediaReadMode.Streamed connection_timeout = connection_policy.MediaRequestTimeout if is_media else connection_policy.RequestTimeout + connection_timeout = kwargs.pop("connection_timeout", connection_timeout / 1000.0) # Every request tries to perform a refresh global_endpoint_manager.refresh_endpoint_list(None) - if request.endpoint_override: - base_url = request.endpoint_override + if request_params.endpoint_override: + base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request) + base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + if base_url != pipeline_client._base_url: + request.url = request.url.replace(pipeline_client._base_url, base_url) - if path: - resource_url = base_url + path - else: - resource_url = base_url - - parse_result = urlparse(resource_url) + parse_result = urlparse(request.url) # The requests library now expects header values to be strings only starting 2.11, # and will raise an error on validation if they are not, so casting all header values to strings. - request_options["headers"] = {header: str(value) for header, value in request_options["headers"].items()} + request.headers.update({header: str(value) for header, value in request.headers.items()}) # We are disabling the SSL verification for local emulator(localhost/127.0.0.1) or if the user # has explicitly specified to disable SSL verification. @@ -126,40 +116,35 @@ def _Request( and not connection_policy.DisableSSLVerification ) - if connection_policy.SSLConfiguration: + if connection_policy.SSLConfiguration or "connection_cert" in kwargs: ca_certs = connection_policy.SSLConfiguration.SSLCaCerts cert_files = (connection_policy.SSLConfiguration.SSLCertFile, connection_policy.SSLConfiguration.SSLKeyFile) - - response = requests_session.request( - request_options["method"], - resource_url, - data=request_body, - headers=request_options["headers"], - timeout=connection_timeout / 1000.0, + response = pipeline_client._pipeline.run( + request, stream=is_media_stream, - verify=ca_certs, - cert=cert_files, + connection_timeout=connection_timeout, + connection_verify=kwargs.pop("connection_verify", ca_certs), + connection_cert=kwargs.pop("connection_cert", cert_files), + ) else: - response = requests_session.request( - request_options["method"], - resource_url, - data=request_body, - headers=request_options["headers"], - timeout=connection_timeout / 1000.0, + response = pipeline_client._pipeline.run( + request, stream=is_media_stream, + connection_timeout=connection_timeout, # If SSL is disabled, verify = false - verify=is_ssl_enabled, + connection_verify=kwargs.pop("connection_verify", is_ssl_enabled) ) + response = response.http_response headers = dict(response.headers) # In case of media stream response, return the response to the user and the user # will need to handle reading the response. if is_media_stream: - return (response.raw, headers) + return (response.stream_download(pipeline_client._pipeline), headers) - data = response.content + data = response.body() if not six.PY2: # python 3 compatible: convert data from byte to unicode string data = data.decode("utf-8") @@ -182,25 +167,23 @@ def _Request( def SynchronizedRequest( client, - request, + request_params, global_endpoint_manager, connection_policy, - requests_session, - method, - path, + pipeline_client, + request, request_data, - query_params, - headers, + **kwargs ): """Performs one synchronized http request according to the parameters. :param object client: Document client instance - :param dict request: - :param _GlobalEndpointManager global_endpoint_manager: + :param dict request_params: + :param _GlobalEndpointManager global_endpoint_manager: :param documents.ConnectionPolicy connection_policy: - :param requests.Session requests_session: - Session object in requests module + :param azure.core.PipelineClient pipeline_client: + PipelineClient to process the request. :param str method: :param str path: :param (str, unicode, file-like stream object, dict, list or None) request_data: @@ -213,33 +196,20 @@ def SynchronizedRequest( tuple of (dict dict) """ - request_body = None - if request_data: - request_body = _RequestBodyFromData(request_data) - if not request_body: - raise errors.UnexpectedDataType("parameter data must be a JSON object, string or" + " readable stream.") - - request_options = {} - request_options["path"] = path - request_options["method"] = method - if query_params: - request_options["path"] += "?" + urlencode(query_params) - - request_options["headers"] = headers - if request_body and isinstance(request_body, (str, six.text_type)): - request_options["headers"][http_constants.HttpHeaders.ContentLength] = len(request_body) - elif request_body is None: - request_options["headers"][http_constants.HttpHeaders.ContentLength] = 0 + request.data = _request_body_from_data(request_data) + if request.data and isinstance(request.data, six.string_types): + request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data) + elif request.data is None: + request.headers[http_constants.HttpHeaders.ContentLength] = 0 # Pass _Request function with it's parameters to retry_utility's Execute method that wraps the call with retries return _retry_utility.Execute( client, global_endpoint_manager, _Request, - request, + request_params, connection_policy, - requests_session, - path, - request_options, - request_body, + pipeline_client, + request, + **kwargs ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 3abe9b3cea93..db929854e529 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -25,6 +25,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import six +from azure.core.tracing.decorator import distributed_trace from ._cosmos_client_connection import CosmosClientConnection from .errors import HTTPFailure @@ -96,6 +97,7 @@ def _get_conflict_link(self, conflict_or_link): return u"{}/conflicts/{}".format(self.container_link, conflict_or_link) return conflict_or_link["_self"] + @distributed_trace def read( self, session_token=None, @@ -105,6 +107,7 @@ def read( populate_quota_info=None, request_options=None, response_hook=None, + **kwargs ): # type: (str, Dict[str, str], bool, bool, bool, Dict[str, Any], Optional[Callable]) -> Container """ Read the container properties @@ -136,13 +139,14 @@ def read( request_options["populateQuotaInfo"] = populate_quota_info collection_link = self.container_link - self._properties = self.client_connection.ReadContainer(collection_link, options=request_options) + self._properties = self.client_connection.ReadContainer(collection_link, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, self._properties) return self._properties + @distributed_trace def read_item( self, item, # type: Union[str, Dict[str, Any]] @@ -153,6 +157,7 @@ def read_item( post_trigger_include=None, # type: str request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> Dict[str, str] """ @@ -193,11 +198,12 @@ def read_item( if post_trigger_include: request_options["postTriggerInclude"] = post_trigger_include - result = self.client_connection.ReadItem(document_link=doc_link, options=request_options) + result = self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def read_all_items( self, max_item_count=None, @@ -206,6 +212,7 @@ def read_all_items( populate_query_metrics=None, feed_options=None, response_hook=None, + **kwargs ): # type: (int, str, Dict[str, str], bool, Dict[str, Any], Optional[Callable]) -> QueryIterable """ List all items in the container. @@ -233,12 +240,13 @@ def read_all_items( response_hook.clear() items = self.client_connection.ReadItems( - collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook + collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook, **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, items) return items + @distributed_trace def query_items_change_feed( self, partition_key_range_id=None, @@ -247,6 +255,7 @@ def query_items_change_feed( max_item_count=None, feed_options=None, response_hook=None, + **kwargs ): """ Get a sorted list of items that were changed, in the order in which they were modified. @@ -277,12 +286,13 @@ def query_items_change_feed( response_hook.clear() result = self.client_connection.QueryItemsChangeFeed( - self.container_link, options=feed_options, response_hook=response_hook + self.container_link, options=feed_options, response_hook=response_hook, **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def query_items( self, query, # type: str @@ -296,6 +306,7 @@ def query_items( populate_query_metrics=None, # type: bool feed_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> QueryIterable """Return all results matching the given `query`. @@ -363,11 +374,13 @@ def query_items( options=feed_options, partition_key=partition_key, response_hook=response_hook, + **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, items) return items + @distributed_trace def replace_item( self, item, # type: Union[str, Dict[str, Any]] @@ -380,6 +393,7 @@ def replace_item( post_trigger_include=None, # type: str request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> Dict[str, str] """ Replaces the specified item if it exists in the container. @@ -415,11 +429,14 @@ def replace_item( if post_trigger_include: request_options["postTriggerInclude"] = post_trigger_include - result = self.client_connection.ReplaceItem(document_link=item_link, new_document=body, options=request_options) + result = self.client_connection.ReplaceItem( + document_link=item_link, new_document=body, options=request_options, **kwargs + ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def upsert_item( self, body, # type: Dict[str, Any] @@ -431,6 +448,7 @@ def upsert_item( post_trigger_include=None, # type: str request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> Dict[str, str] """ Insert or update the specified item. @@ -466,11 +484,13 @@ def upsert_item( if post_trigger_include: request_options["postTriggerInclude"] = post_trigger_include - result = self.client_connection.UpsertItem(database_or_Container_link=self.container_link, document=body) + result = self.client_connection.UpsertItem( + database_or_Container_link=self.container_link, document=body, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def create_item( self, body, # type: Dict[str, Any] @@ -483,6 +503,7 @@ def create_item( indexing_directive=None, # type: Any request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> Dict[str, str] """ Create an item in the container. @@ -523,12 +544,13 @@ def create_item( request_options["indexingDirective"] = indexing_directive result = self.client_connection.CreateItem( - database_or_Container_link=self.container_link, document=body, options=request_options + database_or_Container_link=self.container_link, document=body, options=request_options, **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def delete_item( self, item, # type: Union[Dict[str, Any], str] @@ -541,6 +563,7 @@ def delete_item( post_trigger_include=None, # type: str request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> None """ Delete the specified item from the container. @@ -577,11 +600,12 @@ def delete_item( request_options["postTriggerInclude"] = post_trigger_include document_link = self._get_document_link(item) - result = self.client_connection.DeleteItem(document_link=document_link, options=request_options) + result = self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, result) - def read_offer(self, response_hook=None): + @distributed_trace + def read_offer(self, response_hook=None, **kwargs): # type: (Optional[Callable]) -> Offer """ Read the Offer object for this container. @@ -596,7 +620,7 @@ def read_offer(self, response_hook=None): "query": "SELECT * FROM root r WHERE r.resource=@link", "parameters": [{"name": "@link", "value": link}], } - offers = list(self.client_connection.QueryOffers(query_spec)) + offers = list(self.client_connection.QueryOffers(query_spec, **kwargs)) if not offers: raise HTTPFailure(StatusCodes.NOT_FOUND, "Could not find Offer for container " + self.container_link) @@ -605,7 +629,8 @@ def read_offer(self, response_hook=None): return Offer(offer_throughput=offers[0]["content"]["offerThroughput"], properties=offers[0]) - def replace_throughput(self, throughput, response_hook=None): + @distributed_trace + def replace_throughput(self, throughput, response_hook=None, **kwargs): # type: (int, Optional[Callable]) -> Offer """ Replace the container's throughput @@ -621,19 +646,20 @@ def replace_throughput(self, throughput, response_hook=None): "query": "SELECT * FROM root r WHERE r.resource=@link", "parameters": [{"name": "@link", "value": link}], } - offers = list(self.client_connection.QueryOffers(query_spec)) + offers = list(self.client_connection.QueryOffers(query_spec, **kwargs)) if not offers: raise HTTPFailure(StatusCodes.NOT_FOUND, "Could not find Offer for container " + self.container_link) new_offer = offers[0].copy() new_offer["content"]["offerThroughput"] = throughput - data = self.client_connection.ReplaceOffer(offer_link=offers[0]["_self"], offer=offers[0]) + data = self.client_connection.ReplaceOffer(offer_link=offers[0]["_self"], offer=offers[0], **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, data) return Offer(offer_throughput=data["content"]["offerThroughput"], properties=data) - def read_all_conflicts(self, max_item_count=None, feed_options=None, response_hook=None): + @distributed_trace + def read_all_conflicts(self, max_item_count=None, feed_options=None, response_hook=None, **kwargs): # type: (int, Dict[str, Any], Optional[Callable]) -> QueryIterable """ List all conflicts in the container. @@ -648,11 +674,14 @@ def read_all_conflicts(self, max_item_count=None, feed_options=None, response_ho if max_item_count is not None: feed_options["maxItemCount"] = max_item_count - result = self.client_connection.ReadConflicts(collection_link=self.container_link, feed_options=feed_options) + result = self.client_connection.ReadConflicts( + collection_link=self.container_link, feed_options=feed_options, **kwargs + ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def query_conflicts( self, query, @@ -662,6 +691,7 @@ def query_conflicts( max_item_count=None, feed_options=None, response_hook=None, + **kwargs ): # type: (str, List, bool, Any, int, Dict[str, Any], Optional[Callable]) -> QueryIterable """Return all conflicts matching the given `query`. @@ -691,12 +721,14 @@ def query_conflicts( collection_link=self.container_link, query=query if parameters is None else dict(query=query, parameters=parameters), options=feed_options, + **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result - def get_conflict(self, conflict, partition_key, request_options=None, response_hook=None): + @distributed_trace + def get_conflict(self, conflict, partition_key, request_options=None, response_hook=None, **kwargs): # type: (Union[str, Dict[str, Any]], Any, Dict[str, Any], Optional[Callable]) -> Dict[str, str] """ Get the conflict identified by `id`. @@ -714,13 +746,14 @@ def get_conflict(self, conflict, partition_key, request_options=None, response_h request_options["partitionKey"] = self._set_partition_key(partition_key) result = self.client_connection.ReadConflict( - conflict_link=self._get_conflict_link(conflict), options=request_options + conflict_link=self._get_conflict_link(conflict), options=request_options, **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result - def delete_conflict(self, conflict, partition_key, request_options=None, response_hook=None): + @distributed_trace + def delete_conflict(self, conflict, partition_key, request_options=None, response_hook=None, **kwargs): # type: (Union[str, Dict[str, Any]], Any, Dict[str, Any], Optional[Callable]) -> None """ Delete the specified conflict from the container. @@ -738,7 +771,7 @@ def delete_conflict(self, conflict, partition_key, request_options=None, respons request_options["partitionKey"] = self._set_partition_key(partition_key) result = self.client_connection.DeleteConflict( - conflict_link=self._get_conflict_link(conflict), options=request_options + conflict_link=self._get_conflict_link(conflict), options=request_options, **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, result) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py index 6d62de891fb8..9b0aa1836831 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py @@ -25,6 +25,7 @@ from typing import Any, Callable, Dict, Mapping, Optional, Union, cast import six +from azure.core.tracing.decorator import distributed_trace from ._cosmos_client_connection import CosmosClientConnection from .database import Database @@ -80,6 +81,7 @@ def _get_database_link(database_or_id): database_id = cast("Dict[str, str]", database_or_id)["id"] return "dbs/{}".format(database_id) + @distributed_trace def create_database( self, id, # pylint: disable=redefined-builtin @@ -90,6 +92,7 @@ def create_database( offer_throughput=None, request_options=None, response_hook=None, + **kwargs ): # type: (str, str, Dict[str, str], Dict[str, str], bool, int, Dict[str, Any], Optional[Callable]) -> Database """Create a new database with the given ID (name). @@ -128,7 +131,7 @@ def create_database( if offer_throughput is not None: request_options["offerThroughput"] = offer_throughput - result = self.client_connection.CreateDatabase(database=dict(id=id), options=request_options) + result = self.client_connection.CreateDatabase(database=dict(id=id), options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers) return Database(self.client_connection, id=result["id"], properties=result) @@ -152,6 +155,7 @@ def get_database_client(self, database): return Database(self.client_connection, id_value) + @distributed_trace def read_all_databases( self, max_item_count=None, @@ -160,6 +164,7 @@ def read_all_databases( populate_query_metrics=None, feed_options=None, response_hook=None, + **kwargs ): # type: (int, str, Dict[str, str], bool, Dict[str, Any], Optional[Callable]) -> QueryIterable """ @@ -185,11 +190,12 @@ def read_all_databases( if populate_query_metrics is not None: feed_options["populateQueryMetrics"] = populate_query_metrics - result = self.client_connection.ReadDatabases(options=feed_options) + result = self.client_connection.ReadDatabases(options=feed_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers) return result + @distributed_trace def query_databases( self, query=None, # type: str @@ -201,6 +207,7 @@ def query_databases( populate_query_metrics=None, # type: bool feed_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> QueryIterable @@ -239,15 +246,15 @@ def query_databases( # (just returning a generator did not initiate the first network call, so # the headers were misleading) # This needs to change for "real" implementation - result = self.client_connection.QueryDatabases( - query=query if parameters is None else dict(query=query, parameters=parameters), options=feed_options - ) + query = query if parameters is None else dict(query=query, parameters=parameters) + result = self.client_connection.QueryDatabases(query=query, options=feed_options, **kwargs) else: - result = self.client_connection.ReadDatabases(options=feed_options) + result = self.client_connection.ReadDatabases(options=feed_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers) return result + @distributed_trace def delete_database( self, database, # type: Union[str, Database, Dict[str, Any]] @@ -257,6 +264,7 @@ def delete_database( populate_query_metrics=None, # type: bool request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> None """ @@ -285,11 +293,12 @@ def delete_database( request_options["populateQueryMetrics"] = populate_query_metrics database_link = self._get_database_link(database) - self.client_connection.DeleteDatabase(database_link, options=request_options) + self.client_connection.DeleteDatabase(database_link, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers) - def get_database_account(self, response_hook=None): + @distributed_trace + def get_database_account(self, response_hook=None, **kwargs): # type: (Optional[Callable]) -> DatabaseAccount """ Retrieve the database account information. @@ -298,7 +307,7 @@ def get_database_account(self, response_hook=None): :returns: A :class:`DatabaseAccount` instance representing the Cosmos DB Database Account. """ - result = self.client_connection.GetDatabaseAccount() + result = self.client_connection.GetDatabaseAccount(**kwargs) if response_hook: response_hook(self.client_connection.last_response_headers) return result diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/database.py b/sdk/cosmos/azure-cosmos/azure/cosmos/database.py index 6d2335986ec0..0df8589b70c2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/database.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/database.py @@ -25,6 +25,7 @@ from typing import Any, List, Dict, Mapping, Union, cast import six +from azure.core.tracing.decorator import distributed_trace from ._cosmos_client_connection import CosmosClientConnection from .container import Container @@ -103,6 +104,7 @@ def _get_properties(self): self.read() return self._properties + @distributed_trace def read( self, session_token=None, @@ -110,6 +112,7 @@ def read( populate_query_metrics=None, request_options=None, response_hook=None, + **kwargs ): # type: (str, Dict[str, str], bool, Dict[str, Any], Optional[Callable]) -> Dict[str, Any] """ @@ -139,13 +142,14 @@ def read( if populate_query_metrics is not None: request_options["populateQueryMetrics"] = populate_query_metrics - self._properties = self.client_connection.ReadDatabase(database_link, options=request_options) + self._properties = self.client_connection.ReadDatabase(database_link, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, self._properties) return self._properties + @distributed_trace def create_container( self, id, # type: str # pylint: disable=redefined-builtin @@ -161,6 +165,7 @@ def create_container( conflict_resolution_policy=None, # type: Dict[str, Any] request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> Container """ @@ -228,7 +233,7 @@ def create_container( request_options["offerThroughput"] = offer_throughput data = self.client_connection.CreateContainer( - database_link=self.database_link, collection=definition, options=request_options + database_link=self.database_link, collection=definition, options=request_options, **kwargs ) if response_hook: @@ -236,6 +241,7 @@ def create_container( return Container(self.client_connection, self.database_link, data["id"], properties=data) + @distributed_trace def delete_container( self, container, # type: Union[str, Container, Dict[str, Any]] @@ -245,6 +251,7 @@ def delete_container( populate_query_metrics=None, # type: bool request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> None """ Delete the container @@ -273,7 +280,7 @@ def delete_container( request_options["populateQueryMetrics"] = populate_query_metrics collection_link = self._get_container_link(container) - result = self.client_connection.DeleteContainer(collection_link, options=request_options) + result = self.client_connection.DeleteContainer(collection_link, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, result) @@ -302,6 +309,7 @@ def get_container_client(self, container): return Container(self.client_connection, self.database_link, id_value) + @distributed_trace def read_all_containers( self, max_item_count=None, @@ -310,6 +318,7 @@ def read_all_containers( populate_query_metrics=None, feed_options=None, response_hook=None, + **kwargs ): # type: (int, str, Dict[str, str], bool, Dict[str, Any], Optional[Callable]) -> QueryIterable """ List the containers in the database. @@ -342,11 +351,14 @@ def read_all_containers( if populate_query_metrics is not None: feed_options["populateQueryMetrics"] = populate_query_metrics - result = self.client_connection.ReadContainers(database_link=self.database_link, options=feed_options) + result = self.client_connection.ReadContainers( + database_link=self.database_link, options=feed_options, **kwargs + ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def query_containers( self, query=None, @@ -357,6 +369,7 @@ def query_containers( populate_query_metrics=None, feed_options=None, response_hook=None, + **kwargs ): # type: (str, List, int, str, Dict[str, str], bool, Dict[str, Any], Optional[Callable]) -> QueryIterable """List properties for containers in the current database @@ -387,11 +400,13 @@ def query_containers( database_link=self.database_link, query=query if parameters is None else dict(query=query, parameters=parameters), options=feed_options, + **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result + @distributed_trace def replace_container( self, container, # type: Union[str, Container, Dict[str, Any]] @@ -405,6 +420,7 @@ def replace_container( populate_query_metrics=None, # type: bool request_options=None, # type: Dict[str, Any] response_hook=None, # type: Optional[Callable] + **kwargs ): # type: (...) -> Container """ Reset the properties of the container. Property changes are persisted immediately. @@ -462,7 +478,7 @@ def replace_container( } container_properties = self.client_connection.ReplaceContainer( - container_link, collection=parameters, options=request_options + container_link, collection=parameters, options=request_options, **kwargs ) if response_hook: @@ -472,7 +488,8 @@ def replace_container( self.client_connection, self.database_link, container_properties["id"], properties=container_properties ) - def read_all_users(self, max_item_count=None, feed_options=None, response_hook=None): + @distributed_trace + def read_all_users(self, max_item_count=None, feed_options=None, response_hook=None, **kwargs): # type: (int, Dict[str, Any], Optional[Callable]) -> QueryIterable """ List all users in the container. @@ -487,12 +504,15 @@ def read_all_users(self, max_item_count=None, feed_options=None, response_hook=N if max_item_count is not None: feed_options["maxItemCount"] = max_item_count - result = self.client_connection.ReadUsers(database_link=self.database_link, options=feed_options) + result = self.client_connection.ReadUsers( + database_link=self.database_link, options=feed_options, **kwargs + ) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result - def query_users(self, query, parameters=None, max_item_count=None, feed_options=None, response_hook=None): + @distributed_trace + def query_users(self, query, parameters=None, max_item_count=None, feed_options=None, response_hook=None, **kwargs): # type: (str, List, int, Dict[str, Any], Optional[Callable]) -> QueryIterable """Return all users matching the given `query`. @@ -513,6 +533,7 @@ def query_users(self, query, parameters=None, max_item_count=None, feed_options= database_link=self.database_link, query=query if parameters is None else dict(query=query, parameters=parameters), options=feed_options, + **kwargs ) if response_hook: response_hook(self.client_connection.last_response_headers, result) @@ -538,7 +559,8 @@ def get_user_client(self, user): return User(client_connection=self.client_connection, id=id_value, database_link=self.database_link) - def create_user(self, body, request_options=None, response_hook=None): + @distributed_trace + def create_user(self, body, request_options=None, response_hook=None, **kwargs): # type: (Dict[str, Any], Dict[str, Any], Optional[Callable]) -> User """ Create a user in the container. @@ -563,7 +585,8 @@ def create_user(self, body, request_options=None, response_hook=None): if not request_options: request_options = {} # type: Dict[str, Any] - user = self.client_connection.CreateUser(database_link=self.database_link, user=body, options=request_options) + user = self.client_connection.CreateUser( + database_link=self.database_link, user=body, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, user) @@ -572,7 +595,8 @@ def create_user(self, body, request_options=None, response_hook=None): client_connection=self.client_connection, id=user["id"], database_link=self.database_link, properties=user ) - def upsert_user(self, body, request_options=None, response_hook=None): + @distributed_trace + def upsert_user(self, body, request_options=None, response_hook=None, **kwargs): # type: (Dict[str, Any], Dict[str, Any], Optional[Callable]) -> User """ Insert or update the specified user. @@ -588,7 +612,9 @@ def upsert_user(self, body, request_options=None, response_hook=None): if not request_options: request_options = {} # type: Dict[str, Any] - user = self.client_connection.UpsertUser(database_link=self.database_link, user=body, options=request_options) + user = self.client_connection.UpsertUser( + database_link=self.database_link, user=body, options=request_options, **kwargs + ) if response_hook: response_hook(self.client_connection.last_response_headers, user) @@ -597,7 +623,8 @@ def upsert_user(self, body, request_options=None, response_hook=None): client_connection=self.client_connection, id=user["id"], database_link=self.database_link, properties=user ) - def replace_user(self, user, body, request_options=None, response_hook=None): + @distributed_trace + def replace_user(self, user, body, request_options=None, response_hook=None, **kwargs): # type: (Union[str, User, Dict[str, Any]], Dict[str, Any], Dict[str, Any], Optional[Callable]) -> User """ Replaces the specified user if it exists in the container. @@ -614,7 +641,7 @@ def replace_user(self, user, body, request_options=None, response_hook=None): request_options = {} # type: Dict[str, Any] user = self.client_connection.ReplaceUser( - user_link=self._get_user_link(user), user=body, options=request_options + user_link=self._get_user_link(user), user=body, options=request_options, **kwargs ) if response_hook: @@ -624,7 +651,8 @@ def replace_user(self, user, body, request_options=None, response_hook=None): client_connection=self.client_connection, id=user["id"], database_link=self.database_link, properties=user ) - def delete_user(self, user, request_options=None, response_hook=None): + @distributed_trace + def delete_user(self, user, request_options=None, response_hook=None, **kwargs): # type: (Union[str, User, Dict[str, Any]], Dict[str, Any], Optional[Callable]) -> None """ Delete the specified user from the container. @@ -639,11 +667,14 @@ def delete_user(self, user, request_options=None, response_hook=None): if not request_options: request_options = {} # type: Dict[str, Any] - result = self.client_connection.DeleteUser(user_link=self._get_user_link(user), options=request_options) + result = self.client_connection.DeleteUser( + user_link=self._get_user_link(user), options=request_options, **kwargs + ) if response_hook: response_hook(self.client_connection.last_response_headers, result) - def read_offer(self, response_hook=None): + @distributed_trace + def read_offer(self, response_hook=None, **kwargs): # type: (Optional[Callable]) -> Offer """ Read the Offer object for this database. @@ -658,7 +689,7 @@ def read_offer(self, response_hook=None): "query": "SELECT * FROM root r WHERE r.resource=@link", "parameters": [{"name": "@link", "value": link}], } - offers = list(self.client_connection.QueryOffers(query_spec)) + offers = list(self.client_connection.QueryOffers(query_spec, **kwargs)) if not offers: raise HTTPFailure(StatusCodes.NOT_FOUND, "Could not find Offer for database " + self.database_link) @@ -667,7 +698,8 @@ def read_offer(self, response_hook=None): return Offer(offer_throughput=offers[0]["content"]["offerThroughput"], properties=offers[0]) - def replace_throughput(self, throughput, response_hook=None): + @distributed_trace + def replace_throughput(self, throughput, response_hook=None, **kwargs): # type: (int, Optional[Callable]) -> Offer """ Replace the database level throughput. @@ -688,7 +720,7 @@ def replace_throughput(self, throughput, response_hook=None): raise HTTPFailure(StatusCodes.NOT_FOUND, "Could not find Offer for collection " + self.database_link) new_offer = offers[0].copy() new_offer["content"]["offerThroughput"] = throughput - data = self.client_connection.ReplaceOffer(offer_link=offers[0]["_self"], offer=offers[0]) + data = self.client_connection.ReplaceOffer(offer_link=offers[0]["_self"], offer=offers[0], **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, data) return Offer(offer_throughput=data["content"]["offerThroughput"], properties=data) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/user.py b/sdk/cosmos/azure-cosmos/azure/cosmos/user.py index ab6a59a2277c..b33a36fc81eb 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/user.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/user.py @@ -25,6 +25,7 @@ from typing import Any, List, Dict, Union, cast import six +from azure.core.tracing.decorator import distributed_trace from ._cosmos_client_connection import CosmosClientConnection from .permission import Permission @@ -54,7 +55,8 @@ def _get_properties(self): self.read() return self._properties - def read(self, request_options=None, response_hook=None): + @distributed_trace + def read(self, request_options=None, response_hook=None, **kwargs): # type: (Dict[str, Any], Optional[Callable]) -> User """ Read user propertes. @@ -68,14 +70,15 @@ def read(self, request_options=None, response_hook=None): if not request_options: request_options = {} # type: Dict[str, Any] - self._properties = self.client_connection.ReadUser(user_link=self.user_link, options=request_options) + self._properties = self.client_connection.ReadUser(user_link=self.user_link, options=request_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, self._properties) return self._properties - def read_all_permissions(self, max_item_count=None, feed_options=None, response_hook=None): + @distributed_trace + def read_all_permissions(self, max_item_count=None, feed_options=None, response_hook=None, **kwargs): # type: (int, Dict[str, Any], Optional[Callable]) -> QueryIterable """ List all permission for the user. @@ -90,14 +93,23 @@ def read_all_permissions(self, max_item_count=None, feed_options=None, response_ if max_item_count is not None: feed_options["maxItemCount"] = max_item_count - result = self.client_connection.ReadPermissions(user_link=self.user_link, options=feed_options) + result = self.client_connection.ReadPermissions(user_link=self.user_link, options=feed_options, **kwargs) if response_hook: response_hook(self.client_connection.last_response_headers, result) return result - def query_permissions(self, query, parameters=None, max_item_count=None, feed_options=None, response_hook=None): + @distributed_trace + def query_permissions( + self, + query, + parameters=None, + max_item_count=None, + feed_options=None, + response_hook=None, + **kwargs + ): # type: (str, List, int, Dict[str, Any], Optional[Callable]) -> QueryIterable """Return all permissions matching the given `query`. @@ -118,6 +130,7 @@ def query_permissions(self, query, parameters=None, max_item_count=None, feed_op user_link=self.user_link, query=query if parameters is None else dict(query=query, parameters=parameters), options=feed_options, + **kwargs ) if response_hook: @@ -125,7 +138,8 @@ def query_permissions(self, query, parameters=None, max_item_count=None, feed_op return result - def get_permission(self, permission, request_options=None, response_hook=None): + @distributed_trace + def get_permission(self, permission, request_options=None, response_hook=None, **kwargs): # type: (str, Dict[str, Any], Optional[Callable]) -> Permission """ Get the permission identified by `id`. @@ -142,7 +156,7 @@ def get_permission(self, permission, request_options=None, response_hook=None): request_options = {} # type: Dict[str, Any] permission = self.client_connection.ReadPermission( - permission_link=self._get_permission_link(permission), options=request_options + permission_link=self._get_permission_link(permission), options=request_options, **kwargs ) if response_hook: @@ -156,7 +170,8 @@ def get_permission(self, permission, request_options=None, response_hook=None): properties=permission, ) - def create_permission(self, body, request_options=None, response_hook=None): + @distributed_trace + def create_permission(self, body, request_options=None, response_hook=None, **kwargs): # type: (Dict[str, Any], Dict[str, Any], Optional[Callable]) -> Permission """ Create a permission for the user. @@ -173,7 +188,7 @@ def create_permission(self, body, request_options=None, response_hook=None): request_options = {} # type: Dict[str, Any] permission = self.client_connection.CreatePermission( - user_link=self.user_link, permission=body, options=request_options + user_link=self.user_link, permission=body, options=request_options, **kwargs ) if response_hook: @@ -187,7 +202,8 @@ def create_permission(self, body, request_options=None, response_hook=None): properties=permission, ) - def upsert_permission(self, body, request_options=None, response_hook=None): + @distributed_trace + def upsert_permission(self, body, request_options=None, response_hook=None, **kwargs): # type: (Dict[str, Any], Dict[str, Any], Optional[Callable]) -> Permission """ Insert or update the specified permission. @@ -204,7 +220,7 @@ def upsert_permission(self, body, request_options=None, response_hook=None): request_options = {} # type: Dict[str, Any] permission = self.client_connection.UpsertPermission( - user_link=self.user_link, permission=body, options=request_options + user_link=self.user_link, permission=body, options=request_options, **kwargs ) if response_hook: @@ -218,7 +234,8 @@ def upsert_permission(self, body, request_options=None, response_hook=None): properties=permission, ) - def replace_permission(self, permission, body, request_options=None, response_hook=None): + @distributed_trace + def replace_permission(self, permission, body, request_options=None, response_hook=None, **kwargs): # type: (str, Dict[str, Any], Dict[str, Any], Optional[Callable]) -> Permission """ Replaces the specified permission if it exists for the user. @@ -235,7 +252,7 @@ def replace_permission(self, permission, body, request_options=None, response_ho request_options = {} # type: Dict[str, Any] permission = self.client_connection.ReplacePermission( - permission_link=self._get_permission_link(permission), permission=body, options=request_options + permission_link=self._get_permission_link(permission), permission=body, options=request_options, **kwargs ) if response_hook: @@ -249,7 +266,8 @@ def replace_permission(self, permission, body, request_options=None, response_ho properties=permission, ) - def delete_permission(self, permission, request_options=None, response_hook=None): + @distributed_trace + def delete_permission(self, permission, request_options=None, response_hook=None, **kwargs): # type: (str, Dict[str, Any], Optional[Callable]) -> None """ Delete the specified permission from the user. @@ -266,7 +284,7 @@ def delete_permission(self, permission, request_options=None, response_hook=None request_options = {} # type: Dict[str, Any] result = self.client_connection.DeletePermission( - permission_link=self._get_permission_link(permission), options=request_options + permission_link=self._get_permission_link(permission), options=request_options, **kwargs ) if response_hook: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/version.py new file mode 100644 index 000000000000..ea688752d60b --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/version.py @@ -0,0 +1,22 @@ +# The MIT License (MIT) +# Copyright (c) 2014 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +VERSION = "4.0.0b2" diff --git a/sdk/cosmos/azure-cosmos/setup.py b/sdk/cosmos/azure-cosmos/setup.py index 1f4f8c94d778..a2bc049eaf9d 100644 --- a/sdk/cosmos/azure-cosmos/setup.py +++ b/sdk/cosmos/azure-cosmos/setup.py @@ -7,7 +7,7 @@ # pylint:disable=missing-docstring import re -import os.path +import os from io import open from setuptools import find_packages, setup @@ -20,6 +20,10 @@ # a-b-c => a.b.c NAMESPACE_NAME = PACKAGE_NAME.replace("-", ".") +# Version extraction inspired from 'requests' +with open(os.path.join(PACKAGE_FOLDER_PATH, 'version.py'), 'r') as fd: + version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', + fd.read(), re.MULTILINE).group(1) with open("README.md", encoding="utf-8") as f: README = f.read() @@ -28,7 +32,7 @@ setup( name=PACKAGE_NAME, - version='4.0.0b1', + version=version, description="Microsoft Azure {} Client Library for Python".format(PACKAGE_PPRINT_NAME), long_description=README + "\n\n" + HISTORY, long_description_content_type="text/markdown", @@ -66,7 +70,7 @@ ), install_requires=[ 'six >=1.6', - 'requests>=2.18.4' + 'azure-core<2.0.0,>=1.0.0b2' ], extras_require={ ":python_version<'3.0'": ["azure-nspkg"], diff --git a/sdk/cosmos/azure-cosmos/test/crud_tests.py b/sdk/cosmos/azure-cosmos/test/crud_tests.py index 43a1d2096d24..0f093271d7c3 100644 --- a/sdk/cosmos/azure-cosmos/test/crud_tests.py +++ b/sdk/cosmos/azure-cosmos/test/crud_tests.py @@ -1946,7 +1946,7 @@ def __get_first(array): root_included_path = __get_first([included_path for included_path in indexing_policy['includedPaths'] if included_path['path'] == '/*']) - self.assertFalse('indexes' in root_included_path) + self.assertFalse(root_included_path.get('indexes')) def test_client_request_timeout(self): connection_policy = documents.ConnectionPolicy() @@ -2565,8 +2565,8 @@ def test_get_resource_with_dictionary_and_object(self): self.assertEquals(read_permission.id, created_permission.id) def _MockExecuteFunction(self, function, *args, **kwargs): - self.last_headers.append(args[5]['headers'][HttpHeaders.PartitionKey] - if HttpHeaders.PartitionKey in args[5]['headers'] else '') + self.last_headers.append(args[4].headers[HttpHeaders.PartitionKey] + if HttpHeaders.PartitionKey in args[4].headers else '') return self.OriginalExecuteFunction(function, *args, **kwargs) if __name__ == '__main__': diff --git a/sdk/cosmos/azure-cosmos/test/multimaster_tests.py b/sdk/cosmos/azure-cosmos/test/multimaster_tests.py index 11c740244dcc..e58dff33a4ae 100644 --- a/sdk/cosmos/azure-cosmos/test/multimaster_tests.py +++ b/sdk/cosmos/azure-cosmos/test/multimaster_tests.py @@ -123,8 +123,8 @@ def _MockExecuteFunction(self, function, *args, **kwargs): return {constants._Constants.EnableMultipleWritableLocations: self.EnableMultipleWritableLocations}, {} else: if len(args) > 0: - self.last_headers.append(HttpHeaders.AllowTentativeWrites in args[5]['headers'] - and args[5]['headers'][HttpHeaders.AllowTentativeWrites] == 'true') + self.last_headers.append(HttpHeaders.AllowTentativeWrites in args[4].headers + and args[4].headers[HttpHeaders.AllowTentativeWrites] == 'true') return self.OriginalExecuteFunction(function, *args, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/test/proxy_tests.py b/sdk/cosmos/azure-cosmos/test/proxy_tests.py index 552f76ee301c..09857ebf84f5 100644 --- a/sdk/cosmos/azure-cosmos/test/proxy_tests.py +++ b/sdk/cosmos/azure-cosmos/test/proxy_tests.py @@ -30,7 +30,7 @@ else: from http.server import BaseHTTPRequestHandler, HTTPServer from threading import Thread -from requests.exceptions import ProxyError +from azure.core.exceptions import ServiceRequestError pytestmark = pytest.mark.cosmosEmulator @@ -104,7 +104,7 @@ def test_failure_with_wrong_proxy(self): client = cosmos_client_connection.CosmosClientConnection(self.host, {'masterKey': self.masterKey}, connection_policy) self.fail("Client instantiation is not expected") except Exception as e: - self.assertTrue(type(e) is ProxyError, msg="Error is not a ProxyError") + self.assertTrue(type(e) is ServiceRequestError, msg="Error is not a ServiceRequestError") if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.testName'] diff --git a/sdk/cosmos/azure-cosmos/test/session_tests.py b/sdk/cosmos/azure-cosmos/test/session_tests.py index cdd0799dafd5..f3e7e3e9fe93 100644 --- a/sdk/cosmos/azure-cosmos/test/session_tests.py +++ b/sdk/cosmos/azure-cosmos/test/session_tests.py @@ -37,12 +37,12 @@ def setUpClass(cls): cls.created_db = test_config._test_config.create_database_if_not_exist(cls.client) cls.created_collection = test_config._test_config.create_multi_partition_collection_with_custom_pk_if_not_exist(cls.client) - def _MockRequest(self, global_endpoint_manager, request, connection_policy, requests_session, path, request_options, request_body): - if HttpHeaders.SessionToken in request_options['headers']: - self.last_session_token_sent = request_options['headers'][HttpHeaders.SessionToken] + def _MockRequest(self, global_endpoint_manager, request_params, connection_policy, pipeline_client, request): + if HttpHeaders.SessionToken in request.headers: + self.last_session_token_sent = request.headers[HttpHeaders.SessionToken] else: self.last_session_token_sent = None - return self._OriginalRequest(global_endpoint_manager, request, connection_policy, requests_session, path, request_options, request_body) + return self._OriginalRequest(global_endpoint_manager, request_params, connection_policy, pipeline_client, request) def test_session_token_not_sent_for_master_resource_ops (self): self._OriginalRequest = synchronized_request._Request