Skip to content
Prev Previous commit
Next Next commit
more mypy fixes
  • Loading branch information
SuyogSoti committed Aug 2, 2019
commit 87e72261174284ccfdb19774c335a9a7ab5f8b24
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/tracing/abstract_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
try:
from typing_extensions import Protocol
except ImportError:
Protocol = object
Protocol = object # type: ignore


class AbstractSpan(Protocol):
Expand Down
4 changes: 2 additions & 2 deletions sdk/core/azure-core/azure/core/tracing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any, Optional, Union, Callable
from typing import Any, Optional, Union, Callable, List


def get_function_and_class_name(func, *args):
Expand Down Expand Up @@ -77,7 +77,7 @@ def set_span_contexts(wrapped_span, span_instance=None):


def get_parent_span(parent_span):
# type: (Any) -> Tuple(AbstractSpan, AbstractSpan, Any)
# type: (Any) -> Optional[AbstractSpan]
"""
Returns the current span so that the function's span will be its child. It will create a new span if there is
no current span in any of the context.
Expand Down
10 changes: 6 additions & 4 deletions sdk/core/azure-core/azure/core/tracing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any, Callable
from typing import Any, Callable, Optional, Type, Union
from typing_extensions import Protocol
else:
Protocol = object

try:
import contextvars
except ImportError:
contextvars = None
pass


class ContextProtocol(Protocol):
Expand Down Expand Up @@ -117,8 +117,10 @@ def set(self, value):
class TracingContext(object):
def __init__(self):
# type: () -> None
context_class = _AsyncContext if contextvars else _ThreadLocalContext # type: ContextProtocol
self.current_span = context_class("current_span", None)
try:
self.current_span = _AsyncContext("current_span", None) # type: Union[_AsyncContext, _ThreadLocalContext]
except NameError:
self.current_span = _ThreadLocalContext("current_span", None)

def with_current_context(self, func):
# type: (Callable[[Any], Any]) -> Any
Expand Down
11 changes: 5 additions & 6 deletions sdk/core/azure-core/azure/core/tracing/ext/opencensus_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Implements azure.core.tracing.AbstractSpan to wrap opencensus spans."""

from opencensus.trace import Span, execution_context
from opencensus.trace.tracer import Tracer
from opencensus.trace.span import SpanKind
from opencensus.trace.link import Link
from opencensus.trace.propagation import trace_context_http_header_format
Expand Down Expand Up @@ -34,10 +35,8 @@ def __init__(self, span=None, name="span"):
:param name: The name of the OpenCensus span to create if a new span is needed
:type name: str
"""
if not span:
tracer = self.get_current_tracer()
span = tracer.start_span(name=name) # type: Span
self._span_instance = span
tracer = self.get_current_tracer()
self._span_instance = span or tracer.start_span(name=name)
self._span_component = "component"
self._http_user_agent = "http.user_agent"
self._http_method = "http.method"
Expand Down Expand Up @@ -146,7 +145,7 @@ def get_current_span(cls):

@classmethod
def get_current_tracer(cls):
# type: () -> tracer_module.Tracer
# type: () -> Tracer
"""
Get the current tracer from the execution context. Return None otherwise.
"""
Expand All @@ -165,7 +164,7 @@ def set_current_span(cls, span):

@classmethod
def set_current_tracer(cls, tracer):
# type: (tracer_module.Tracer) -> None
# type: (Tracer) -> None
"""
Set the given tracer as the current tracer in the execution context.
:param tracer: The tracer to set the current tracer as
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/tests/test_tracing_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class TestCommon(object):
def test_get_function_and_class_name(self):
with ContextHelper():
client = MockClient()
assert common.get_function_and_class_name(client.get_foo) == "MockClient.get_foo"
assert common.get_function_and_class_name(client.get_foo, client) == "MockClient.get_foo"
assert common.get_function_and_class_name(random_function) == "random_function"

def test_set_span_context(self):
Expand Down