Skip to content
Prev Previous commit
Next Next commit
Added trace decorators
  • Loading branch information
annatisch committed Aug 28, 2019
commit abc90303b0572dcec462594dd7ffda6b8ee69bfc
71 changes: 52 additions & 19 deletions sdk/cosmos/azure-cosmos/azure/cosmos/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import Any, Callable, Dict, List, Optional, Union

import six
from azure.core.tracing.decorator import distributed_trace

from ._cosmos_client_connection import CosmosClientConnection
from .errors import HTTPFailure
Expand Down Expand Up @@ -96,6 +97,7 @@ def _get_conflict_link(self, conflict_or_link):
return u"{}/conflicts/{}".format(self.container_link, conflict_or_link)
return conflict_or_link["_self"]

@distributed_trace
def read(
self,
session_token=None,
Expand All @@ -105,6 +107,7 @@ def read(
populate_quota_info=None,
request_options=None,
response_hook=None,
**kwargs
):
# type: (str, Dict[str, str], bool, bool, bool, Dict[str, Any], Optional[Callable]) -> Container
""" Read the container properties
Expand Down Expand Up @@ -136,13 +139,14 @@ def read(
request_options["populateQuotaInfo"] = populate_quota_info

collection_link = self.container_link
self._properties = self.client_connection.ReadContainer(collection_link, options=request_options)
self._properties = self.client_connection.ReadContainer(collection_link, options=request_options, **kwargs)

if response_hook:
response_hook(self.client_connection.last_response_headers, self._properties)

return self._properties

@distributed_trace
def read_item(
self,
item, # type: Union[str, Dict[str, Any]]
Expand All @@ -153,6 +157,7 @@ def read_item(
post_trigger_include=None, # type: str
request_options=None, # type: Dict[str, Any]
response_hook=None, # type: Optional[Callable]
**kwargs
):
# type: (...) -> Dict[str, str]
"""
Expand Down Expand Up @@ -193,11 +198,12 @@ def read_item(
if post_trigger_include:
request_options["postTriggerInclude"] = post_trigger_include

result = self.client_connection.ReadItem(document_link=doc_link, options=request_options)
result = self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace
def read_all_items(
self,
max_item_count=None,
Expand All @@ -206,6 +212,7 @@ def read_all_items(
populate_query_metrics=None,
feed_options=None,
response_hook=None,
**kwargs
):
# type: (int, str, Dict[str, str], bool, Dict[str, Any], Optional[Callable]) -> QueryIterable
""" List all items in the container.
Expand Down Expand Up @@ -233,12 +240,13 @@ def read_all_items(
response_hook.clear()

items = self.client_connection.ReadItems(
collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook
collection_link=self.container_link, feed_options=feed_options, response_hook=response_hook, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, items)
return items

@distributed_trace
def query_items_change_feed(
self,
partition_key_range_id=None,
Expand All @@ -247,6 +255,7 @@ def query_items_change_feed(
max_item_count=None,
feed_options=None,
response_hook=None,
**kwargs
):
""" Get a sorted list of items that were changed, in the order in which they were modified.

Expand Down Expand Up @@ -277,12 +286,13 @@ def query_items_change_feed(
response_hook.clear()

result = self.client_connection.QueryItemsChangeFeed(
self.container_link, options=feed_options, response_hook=response_hook
self.container_link, options=feed_options, response_hook=response_hook, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace
def query_items(
self,
query, # type: str
Expand All @@ -296,6 +306,7 @@ def query_items(
populate_query_metrics=None, # type: bool
feed_options=None, # type: Dict[str, Any]
response_hook=None, # type: Optional[Callable]
**kwargs
):
# type: (...) -> QueryIterable
"""Return all results matching the given `query`.
Expand Down Expand Up @@ -363,11 +374,13 @@ def query_items(
options=feed_options,
partition_key=partition_key,
response_hook=response_hook,
**kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, items)
return items

@distributed_trace
def replace_item(
self,
item, # type: Union[str, Dict[str, Any]]
Expand All @@ -380,6 +393,7 @@ def replace_item(
post_trigger_include=None, # type: str
request_options=None, # type: Dict[str, Any]
response_hook=None, # type: Optional[Callable]
**kwargs
):
# type: (...) -> Dict[str, str]
""" Replaces the specified item if it exists in the container.
Expand Down Expand Up @@ -415,11 +429,14 @@ def replace_item(
if post_trigger_include:
request_options["postTriggerInclude"] = post_trigger_include

result = self.client_connection.ReplaceItem(document_link=item_link, new_document=body, options=request_options)
result = self.client_connection.ReplaceItem(
document_link=item_link, new_document=body, options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace
def upsert_item(
self,
body, # type: Dict[str, Any]
Expand All @@ -431,6 +448,7 @@ def upsert_item(
post_trigger_include=None, # type: str
request_options=None, # type: Dict[str, Any]
response_hook=None, # type: Optional[Callable]
**kwargs
):
# type: (...) -> Dict[str, str]
""" Insert or update the specified item.
Expand Down Expand Up @@ -466,11 +484,13 @@ def upsert_item(
if post_trigger_include:
request_options["postTriggerInclude"] = post_trigger_include

result = self.client_connection.UpsertItem(database_or_Container_link=self.container_link, document=body)
result = self.client_connection.UpsertItem(
database_or_Container_link=self.container_link, document=body, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace
def create_item(
self,
body, # type: Dict[str, Any]
Expand All @@ -483,6 +503,7 @@ def create_item(
indexing_directive=None, # type: Any
request_options=None, # type: Dict[str, Any]
response_hook=None, # type: Optional[Callable]
**kwargs
):
# type: (...) -> Dict[str, str]
""" Create an item in the container.
Expand Down Expand Up @@ -523,12 +544,13 @@ def create_item(
request_options["indexingDirective"] = indexing_directive

result = self.client_connection.CreateItem(
database_or_Container_link=self.container_link, document=body, options=request_options
database_or_Container_link=self.container_link, document=body, options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace
def delete_item(
self,
item, # type: Union[Dict[str, Any], str]
Expand All @@ -541,6 +563,7 @@ def delete_item(
post_trigger_include=None, # type: str
request_options=None, # type: Dict[str, Any]
response_hook=None, # type: Optional[Callable]
**kwargs
):
# type: (...) -> None
""" Delete the specified item from the container.
Expand Down Expand Up @@ -577,11 +600,12 @@ def delete_item(
request_options["postTriggerInclude"] = post_trigger_include

document_link = self._get_document_link(item)
result = self.client_connection.DeleteItem(document_link=document_link, options=request_options)
result = self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)

def read_offer(self, response_hook=None):
@distributed_trace
def read_offer(self, response_hook=None, **kwargs):
# type: (Optional[Callable]) -> Offer
""" Read the Offer object for this container.

Expand All @@ -596,7 +620,7 @@ def read_offer(self, response_hook=None):
"query": "SELECT * FROM root r WHERE r.resource=@link",
"parameters": [{"name": "@link", "value": link}],
}
offers = list(self.client_connection.QueryOffers(query_spec))
offers = list(self.client_connection.QueryOffers(query_spec, **kwargs))
if not offers:
raise HTTPFailure(StatusCodes.NOT_FOUND, "Could not find Offer for container " + self.container_link)

Expand All @@ -605,7 +629,8 @@ def read_offer(self, response_hook=None):

return Offer(offer_throughput=offers[0]["content"]["offerThroughput"], properties=offers[0])

def replace_throughput(self, throughput, response_hook=None):
@distributed_trace
def replace_throughput(self, throughput, response_hook=None, **kwargs):
# type: (int, Optional[Callable]) -> Offer
""" Replace the container's throughput

Expand All @@ -621,19 +646,20 @@ def replace_throughput(self, throughput, response_hook=None):
"query": "SELECT * FROM root r WHERE r.resource=@link",
"parameters": [{"name": "@link", "value": link}],
}
offers = list(self.client_connection.QueryOffers(query_spec))
offers = list(self.client_connection.QueryOffers(query_spec, **kwargs))
if not offers:
raise HTTPFailure(StatusCodes.NOT_FOUND, "Could not find Offer for container " + self.container_link)
new_offer = offers[0].copy()
new_offer["content"]["offerThroughput"] = throughput
data = self.client_connection.ReplaceOffer(offer_link=offers[0]["_self"], offer=offers[0])
data = self.client_connection.ReplaceOffer(offer_link=offers[0]["_self"], offer=offers[0], **kwargs)

if response_hook:
response_hook(self.client_connection.last_response_headers, data)

return Offer(offer_throughput=data["content"]["offerThroughput"], properties=data)

def read_all_conflicts(self, max_item_count=None, feed_options=None, response_hook=None):
@distributed_trace
def read_all_conflicts(self, max_item_count=None, feed_options=None, response_hook=None, **kwargs):
# type: (int, Dict[str, Any], Optional[Callable]) -> QueryIterable
""" List all conflicts in the container.

Expand All @@ -648,11 +674,14 @@ def read_all_conflicts(self, max_item_count=None, feed_options=None, response_ho
if max_item_count is not None:
feed_options["maxItemCount"] = max_item_count

result = self.client_connection.ReadConflicts(collection_link=self.container_link, feed_options=feed_options)
result = self.client_connection.ReadConflicts(
collection_link=self.container_link, feed_options=feed_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace
def query_conflicts(
self,
query,
Expand All @@ -662,6 +691,7 @@ def query_conflicts(
max_item_count=None,
feed_options=None,
response_hook=None,
**kwargs
):
# type: (str, List, bool, Any, int, Dict[str, Any], Optional[Callable]) -> QueryIterable
"""Return all conflicts matching the given `query`.
Expand Down Expand Up @@ -691,12 +721,14 @@ def query_conflicts(
collection_link=self.container_link,
query=query if parameters is None else dict(query=query, parameters=parameters),
options=feed_options,
**kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

def get_conflict(self, conflict, partition_key, request_options=None, response_hook=None):
@distributed_trace
def get_conflict(self, conflict, partition_key, request_options=None, response_hook=None, **kwargs):
# type: (Union[str, Dict[str, Any]], Any, Dict[str, Any], Optional[Callable]) -> Dict[str, str]
""" Get the conflict identified by `id`.

Expand All @@ -714,13 +746,14 @@ def get_conflict(self, conflict, partition_key, request_options=None, response_h
request_options["partitionKey"] = self._set_partition_key(partition_key)

result = self.client_connection.ReadConflict(
conflict_link=self._get_conflict_link(conflict), options=request_options
conflict_link=self._get_conflict_link(conflict), options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

def delete_conflict(self, conflict, partition_key, request_options=None, response_hook=None):
@distributed_trace
def delete_conflict(self, conflict, partition_key, request_options=None, response_hook=None, **kwargs):
# type: (Union[str, Dict[str, Any]], Any, Dict[str, Any], Optional[Callable]) -> None
""" Delete the specified conflict from the container.

Expand All @@ -738,7 +771,7 @@ def delete_conflict(self, conflict, partition_key, request_options=None, respons
request_options["partitionKey"] = self._set_partition_key(partition_key)

result = self.client_connection.DeleteConflict(
conflict_link=self._get_conflict_link(conflict), options=request_options
conflict_link=self._get_conflict_link(conflict), options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
Expand Down
Loading