diff --git a/sdk/core/azure-core/azure/core/tracing/abstract_span.py b/sdk/core/azure-core/azure/core/tracing/abstract_span.py index 9b7ca6753c0a..e5597d48df5e 100644 --- a/sdk/core/azure-core/azure/core/tracing/abstract_span.py +++ b/sdk/core/azure-core/azure/core/tracing/abstract_span.py @@ -17,7 +17,7 @@ try: from typing_extensions import Protocol except ImportError: - Protocol = object + Protocol = object # type: ignore class AbstractSpan(Protocol): diff --git a/sdk/core/azure-core/azure/core/tracing/common.py b/sdk/core/azure-core/azure/core/tracing/common.py index b0e425a14d10..22cea6315e73 100644 --- a/sdk/core/azure-core/azure/core/tracing/common.py +++ b/sdk/core/azure-core/azure/core/tracing/common.py @@ -36,11 +36,30 @@ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any, Optional, Union, Callable, List + + +def get_function_and_class_name(func, *args): + # type: (Callable, List[Any]) -> str + """ + Given a function and its unamed arguments, returns class_name.function_name. It assumes the first argument + is `self`. If there are no arguments then it only returns the function name. + + :param func: the function passed in + :type func: `collections.abc.Callable` + :param args: List of arguments passed into the function + :type args: List[Any] + """ + try: + return func.__qualname__ + except AttributeError: + if args: + return "{}.{}".format(args[0].__class__.__name__, func.__name__) # pylint: disable=protected-access + return func.__name__ def set_span_contexts(wrapped_span, span_instance=None): - # type: (AbstractSpan, Optional[AbstractSpan]) -> None + # type: (Union[AbstractSpan, None], Optional[AbstractSpan]) -> None """ Set the sdk context and the implementation context. `span_instance` will be used to set the implementation context if passed in else will use `wrapped_span.span_instance`. @@ -58,7 +77,7 @@ def set_span_contexts(wrapped_span, span_instance=None): def get_parent_span(parent_span): - # type: (Any) -> Tuple(AbstractSpan, AbstractSpan, Any) + # type: (Any) -> Optional[AbstractSpan] """ Returns the current span so that the function's span will be its child. It will create a new span if there is no current span in any of the context. diff --git a/sdk/core/azure-core/azure/core/tracing/context.py b/sdk/core/azure-core/azure/core/tracing/context.py index f8761cd71502..ecf822bf6fd7 100644 --- a/sdk/core/azure-core/azure/core/tracing/context.py +++ b/sdk/core/azure-core/azure/core/tracing/context.py @@ -13,7 +13,7 @@ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Any, Callable + from typing import Any, Callable, Optional, Type, Union from typing_extensions import Protocol else: Protocol = object @@ -21,7 +21,7 @@ try: import contextvars except ImportError: - contextvars = None + pass class ContextProtocol(Protocol): @@ -29,8 +29,8 @@ class ContextProtocol(Protocol): Implements set and get variables in a thread safe way. """ - def __init__(self, name, default): - # type: (string, Any) -> None + def __init__(self, name, default): # pylint: disable=super-init-not-called + # type: (str, Any) -> None pass def clear(self): @@ -55,8 +55,9 @@ class _AsyncContext(object): """ def __init__(self, name, default): + # type: (str, Any) -> None self.name = name - self.contextvar = contextvars.ContextVar(name) + self.contextvar = contextvars.ContextVar(name) # type: contextvars.ContextVar self.default = default if callable(default) else (lambda: default) def clear(self): @@ -84,6 +85,7 @@ class _ThreadLocalContext(object): """ Uses thread local storage to set and get variables globally in a thread safe way. """ + _thread_local = threading.local() def __init__(self, name, default): @@ -115,8 +117,10 @@ def set(self, value): class TracingContext(object): def __init__(self): # type: () -> None - context_class = _AsyncContext if contextvars else _ThreadLocalContext - self.current_span = context_class("current_span", None) + try: + self.current_span = _AsyncContext("current_span", None) # type: Union[_AsyncContext, _ThreadLocalContext] + except NameError: + self.current_span = _ThreadLocalContext("current_span", None) def with_current_context(self, func): # type: (Callable[[Any], Any]) -> Any @@ -141,4 +145,5 @@ def call_with_current_context(*args, **kwargs): return call_with_current_context + tracing_context = TracingContext() diff --git a/sdk/core/azure-core/azure/core/tracing/decorator.py b/sdk/core/azure-core/azure/core/tracing/decorator.py index 3986dcaf31ed..7b6ebb26cf45 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator.py @@ -31,15 +31,23 @@ from azure.core.settings import settings from azure.core.tracing.context import tracing_context +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + from typing import Callable, Any + def distributed_trace(func=None, name_of_span=None): - # type: (Callable[[Any], Any], str) -> Callable[[Any], Any] + # type: (Callable, str) -> Callable[[Any], Any] if func is None: return functools.partial(distributed_trace, name_of_span=name_of_span) @functools.wraps(func) - def wrapper_use_tracer(self, *args, **kwargs): - # type: (Any) -> Any + def wrapper_use_tracer(*args, **kwargs): + # type: (Any, Any) -> Any passed_in_parent = kwargs.pop("parent_span", None) orig_wrapped_span = tracing_context.current_span.get() wrapper_class = settings.tracing_implementation() @@ -50,18 +58,18 @@ def wrapper_use_tracer(self, *args, **kwargs): ans = None if parent_span is not None and orig_wrapped_span is None: common.set_span_contexts(parent_span) - name = name_of_span or self.__class__.__name__ + "." + func.__name__ + name = name_of_span or common.get_function_and_class_name(func, *args) child = parent_span.span(name=name) child.start() common.set_span_contexts(child) - ans = func(self, *args, **kwargs) + ans = func(*args, **kwargs) child.finish() common.set_span_contexts(parent_span) if orig_wrapped_span is None and passed_in_parent is None and original_span_instance is None: parent_span.finish() common.set_span_contexts(orig_wrapped_span, span_instance=original_span_instance) else: - ans = func(self, *args, **kwargs) + ans = func(*args, **kwargs) return ans return wrapper_use_tracer diff --git a/sdk/core/azure-core/azure/core/tracing/decorator_async.py b/sdk/core/azure-core/azure/core/tracing/decorator_async.py index 4df65fee77ef..0719a9c016c3 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator_async.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator_async.py @@ -31,15 +31,23 @@ from azure.core.settings import settings from azure.core.tracing.context import tracing_context +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + from typing import Callable, Any + def distributed_trace_async(func=None, name_of_span=None): - # type: (Callable[[Any], Any], str) -> Callable[[Any], Any] + # type: (Callable, str) -> Callable[[Any], Any] if func is None: return functools.partial(distributed_trace_async, name_of_span=name_of_span) @functools.wraps(func) - async def wrapper_use_tracer(self, *args, **kwargs): - # type: (Any) -> Any + async def wrapper_use_tracer(*args, **kwargs): + # type: (Any, Any) -> Any passed_in_parent = kwargs.pop("parent_span", None) orig_wrapped_span = tracing_context.current_span.get() wrapper_class = settings.tracing_implementation() @@ -50,18 +58,18 @@ async def wrapper_use_tracer(self, *args, **kwargs): ans = None if parent_span is not None and orig_wrapped_span is None: common.set_span_contexts(parent_span) - name = name_of_span or self.__class__.__name__ + "." + func.__name__ + name = name_of_span or common.get_function_and_class_name(func, *args) child = parent_span.span(name=name) child.start() common.set_span_contexts(child) - ans = await func(self, *args, **kwargs) + ans = await func(*args, **kwargs) child.finish() common.set_span_contexts(parent_span) if orig_wrapped_span is None and passed_in_parent is None and original_span_instance is None: parent_span.finish() common.set_span_contexts(orig_wrapped_span, span_instance=original_span_instance) else: - ans = await func(self, *args, **kwargs) + ans = await func(*args, **kwargs) return ans return wrapper_use_tracer diff --git a/sdk/core/azure-core/azure/core/tracing/ext/opencensus_span.py b/sdk/core/azure-core/azure/core/tracing/ext/opencensus_span.py index 7aeb79c77e36..c1fd6eb6a4e0 100644 --- a/sdk/core/azure-core/azure/core/tracing/ext/opencensus_span.py +++ b/sdk/core/azure-core/azure/core/tracing/ext/opencensus_span.py @@ -5,6 +5,7 @@ """Implements azure.core.tracing.AbstractSpan to wrap opencensus spans.""" from opencensus.trace import Span, execution_context +from opencensus.trace.tracer import Tracer from opencensus.trace.span import SpanKind from opencensus.trace.link import Link from opencensus.trace.propagation import trace_context_http_header_format @@ -34,10 +35,8 @@ def __init__(self, span=None, name="span"): :param name: The name of the OpenCensus span to create if a new span is needed :type name: str """ - if not span: - tracer = self.get_current_tracer() - span = tracer.start_span(name=name) # type: Span - self._span_instance = span + tracer = self.get_current_tracer() + self._span_instance = span or tracer.start_span(name=name) self._span_component = "component" self._http_user_agent = "http.user_agent" self._http_method = "http.method" @@ -80,7 +79,7 @@ def to_header(self): :return: A key value pair dictionary """ tracer_from_context = self.get_current_tracer() - temp_headers = {} + temp_headers = {} # type: Dict[str, str] if tracer_from_context is not None: ctx = tracer_from_context.span_context try: @@ -146,7 +145,7 @@ def get_current_span(cls): @classmethod def get_current_tracer(cls): - # type: () -> tracer_module.Tracer + # type: () -> Tracer """ Get the current tracer from the execution context. Return None otherwise. """ @@ -165,7 +164,7 @@ def set_current_span(cls, span): @classmethod def set_current_tracer(cls, tracer): - # type: (tracer_module.Tracer) -> None + # type: (Tracer) -> None """ Set the given tracer as the current tracer in the execution context. :param tracer: The tracer to set the current tracer as diff --git a/sdk/core/azure-core/tests/test_tracing_context.py b/sdk/core/azure-core/tests/test_tracing_context.py index 984d2ad480d3..875c98105348 100644 --- a/sdk/core/azure-core/tests/test_tracing_context.py +++ b/sdk/core/azure-core/tests/test_tracing_context.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import unittest + try: from unittest import mock except ImportError: @@ -11,28 +12,10 @@ from azure.core.tracing.context import tracing_context from azure.core.tracing import AbstractSpan from azure.core.settings import settings +from tracing_common import ContextHelper import os -class ContextHelper(object): - def __init__(self, environ={}, tracer_to_use=None): - self.orig_sdk_context_span = tracing_context.current_span.get() - self.os_env = mock.patch.dict(os.environ, environ) - self.tracer_to_use = tracer_to_use - - def __enter__(self): - self.orig_sdk_context_span = tracing_context.current_span.get() - tracing_context.current_span.clear() - settings.tracing_implementation.set_value(self.tracer_to_use) - self.os_env.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - tracing_context.current_span.set(self.orig_sdk_context_span) - settings.tracing_implementation.unset_value() - self.os_env.stop() - - class TestContext(unittest.TestCase): def test_current_span(self): with ContextHelper(): diff --git a/sdk/core/azure-core/tests/test_tracing_decorator.py b/sdk/core/azure-core/tests/test_tracing_decorator.py index 764c09396255..5c5a1df7b9e6 100644 --- a/sdk/core/azure-core/tests/test_tracing_decorator.py +++ b/sdk/core/azure-core/tests/test_tracing_decorator.py @@ -68,7 +68,17 @@ def check_name_is_different(self): time.sleep(0.001) +def random_function(): + pass + + class TestCommon(object): + def test_get_function_and_class_name(self): + with ContextHelper(): + client = MockClient() + assert common.get_function_and_class_name(client.get_foo, client) == "MockClient.get_foo" + assert common.get_function_and_class_name(random_function) == "random_function" + def test_set_span_context(self): with ContextHelper(environ={"AZURE_SDK_TRACING_IMPLEMENTATION": "opencensus"}): wrapper = settings.tracing_implementation()