Skip to content
Prev Previous commit
Next Next commit
mypy fixes
  • Loading branch information
SuyogSoti committed Jul 31, 2019
commit 1b688260be06d39ff3593b615cd780297e35f87a
20 changes: 18 additions & 2 deletions sdk/core/azure-core/azure/core/tracing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,27 @@
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any, Optional
from typing import Any, Optional, Union, Callable


def get_function_and_class_name(func, *args):
# type: (Callable, List[Any]) -> str
"""
Given a function and its unamed arguments, returns class_name.function_name. It assumes the first argument
is `self`. If there are no arguments then it only returns the function name.
:param func: the function passed in
:type func: `collections.abc.Callable`
:param args: List of arguments passed into the function
:type args: List[Any]
"""
if args:
return "{}.{}".format(args[0].__class_.__name__, func.__name__) # pylint: disable=protected-access
return func.__name__


def set_span_contexts(wrapped_span, span_instance=None):
# type: (AbstractSpan, Optional[AbstractSpan]) -> None
# type: (Union[AbstractSpan, None], Optional[AbstractSpan]) -> None
"""
Set the sdk context and the implementation context. `span_instance` will be used to set the implementation context
if passed in else will use `wrapped_span.span_instance`.
Expand Down
11 changes: 7 additions & 4 deletions sdk/core/azure-core/azure/core/tracing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class ContextProtocol(Protocol):
Implements set and get variables in a thread safe way.
"""

def __init__(self, name, default):
# type: (string, Any) -> None
def __init__(self, name, default): # pylint: disable=super-init-not-called
# type: (str, Any) -> None
pass

def clear(self):
Expand All @@ -55,8 +55,9 @@ class _AsyncContext(object):
"""

def __init__(self, name, default):
# type: (str, Any) -> None
self.name = name
self.contextvar = contextvars.ContextVar(name)
self.contextvar = contextvars.ContextVar(name) # type: contextvars.ContextVar
self.default = default if callable(default) else (lambda: default)

def clear(self):
Expand Down Expand Up @@ -84,6 +85,7 @@ class _ThreadLocalContext(object):
"""
Uses thread local storage to set and get variables globally in a thread safe way.
"""

_thread_local = threading.local()

def __init__(self, name, default):
Expand Down Expand Up @@ -115,7 +117,7 @@ def set(self, value):
class TracingContext(object):
def __init__(self):
# type: () -> None
context_class = _AsyncContext if contextvars else _ThreadLocalContext
context_class = _AsyncContext if contextvars else _ThreadLocalContext # type: ContextProtocol
self.current_span = context_class("current_span", None)

def with_current_context(self, func):
Expand All @@ -141,4 +143,5 @@ def call_with_current_context(*args, **kwargs):

return call_with_current_context


tracing_context = TracingContext()
22 changes: 16 additions & 6 deletions sdk/core/azure-core/azure/core/tracing/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,23 @@
from azure.core.settings import settings
from azure.core.tracing.context import tracing_context

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Callable, Any


def distributed_trace(func=None, name_of_span=None):
# type: (Callable[[Any], Any], str) -> Callable[[Any], Any]
# type: (Callable, str) -> Callable[[Any], Any]
if func is None:
return functools.partial(distributed_trace, name_of_span=name_of_span)

@functools.wraps(func)
def wrapper_use_tracer(self, *args, **kwargs):
# type: (Any) -> Any
def wrapper_use_tracer(*args, **kwargs):
# type: (Any, Any) -> Any
passed_in_parent = kwargs.pop("parent_span", None)
orig_wrapped_span = tracing_context.current_span.get()
wrapper_class = settings.tracing_implementation()
Expand All @@ -50,18 +58,20 @@ def wrapper_use_tracer(self, *args, **kwargs):
ans = None
if parent_span is not None and orig_wrapped_span is None:
common.set_span_contexts(parent_span)
name = name_of_span or self.__class__.__name__ + "." + func.__name__
name = (
name_of_span or getattr(func, "__qualname__", None) or common.get_function_and_class_name(func, *args)
)
child = parent_span.span(name=name)
child.start()
common.set_span_contexts(child)
ans = func(self, *args, **kwargs)
ans = func(*args, **kwargs)
child.finish()
common.set_span_contexts(parent_span)
if orig_wrapped_span is None and passed_in_parent is None and original_span_instance is None:
parent_span.finish()
common.set_span_contexts(orig_wrapped_span, span_instance=original_span_instance)
else:
ans = func(self, *args, **kwargs)
ans = func(*args, **kwargs)
return ans

return wrapper_use_tracer
22 changes: 16 additions & 6 deletions sdk/core/azure-core/azure/core/tracing/decorator_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,23 @@
from azure.core.settings import settings
from azure.core.tracing.context import tracing_context

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Callable, Any


def distributed_trace_async(func=None, name_of_span=None):
# type: (Callable[[Any], Any], str) -> Callable[[Any], Any]
# type: (Callable, str) -> Callable[[Any], Any]
if func is None:
return functools.partial(distributed_trace_async, name_of_span=name_of_span)

@functools.wraps(func)
async def wrapper_use_tracer(self, *args, **kwargs):
# type: (Any) -> Any
async def wrapper_use_tracer(*args, **kwargs):
# type: (Any, Any) -> Any
passed_in_parent = kwargs.pop("parent_span", None)
orig_wrapped_span = tracing_context.current_span.get()
wrapper_class = settings.tracing_implementation()
Expand All @@ -50,18 +58,20 @@ async def wrapper_use_tracer(self, *args, **kwargs):
ans = None
if parent_span is not None and orig_wrapped_span is None:
common.set_span_contexts(parent_span)
name = name_of_span or self.__class__.__name__ + "." + func.__name__
name = (
name_of_span or getattr(func, "__qualname__", None) or common.get_function_and_class_name(func, *args)
)
child = parent_span.span(name=name)
child.start()
common.set_span_contexts(child)
ans = await func(self, *args, **kwargs)
ans = await func(*args, **kwargs)
child.finish()
common.set_span_contexts(parent_span)
if orig_wrapped_span is None and passed_in_parent is None and original_span_instance is None:
parent_span.finish()
common.set_span_contexts(orig_wrapped_span, span_instance=original_span_instance)
else:
ans = await func(self, *args, **kwargs)
ans = await func(*args, **kwargs)
return ans

return wrapper_use_tracer
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def to_header(self):
:return: A key value pair dictionary
"""
tracer_from_context = self.get_current_tracer()
temp_headers = {}
temp_headers = {} # type: Dict[str, str]
if tracer_from_context is not None:
ctx = tracer_from_context.span_context
try:
Expand Down