diff --git a/sdk/core/azure-core/azure/core/tracing/decorator.py b/sdk/core/azure-core/azure/core/tracing/decorator.py index 1555b7bd58d0..9f647f3333c4 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator.py @@ -32,8 +32,11 @@ from azure.core.tracing.context import tracing_context -def distributed_trace(func): - # type: (Callable[[Any], Any]) -> Callable[[Any], Any] +def distributed_trace(func=None, name_of_span=None): + # type: (Callable[[Any], Any], str) -> Callable[[Any], Any] + if func is None: + return functools.partial(distributed_trace, name_of_span=name_of_span) + @functools.wraps(func) def wrapper_use_tracer(self, *args, **kwargs): # type: (Any) -> Any @@ -47,7 +50,7 @@ def wrapper_use_tracer(self, *args, **kwargs): ans = None if common.should_use_trace(parent_span): common.set_span_contexts(parent_span) - name = self.__class__.__name__ + "." + func.__name__ + name = name_of_span or self.__class__.__name__ + "." + func.__name__ child = parent_span.span(name=name) child.start() common.set_span_contexts(child) diff --git a/sdk/core/azure-core/azure/core/tracing/decorator_async.py b/sdk/core/azure-core/azure/core/tracing/decorator_async.py index 56d3f0121497..1440f86525e2 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator_async.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator_async.py @@ -32,8 +32,11 @@ from azure.core.tracing.context import tracing_context -def distributed_trace_async(func): - # type: (Callable[[Any], Any]) -> Callable[[Any], Any] +def distributed_trace_async(func=None, name_of_span=None): + # type: (Callable[[Any], Any], str) -> Callable[[Any], Any] + if func is None: + return functools.partial(distributed_trace_async, name_of_span=name_of_span) + @functools.wraps(func) async def wrapper_use_tracer(self, *args, **kwargs): # type: (Any) -> Any @@ -47,7 +50,7 @@ async def wrapper_use_tracer(self, *args, **kwargs): ans = None if common.should_use_trace(parent_span): common.set_span_contexts(parent_span) - name = self.__class__.__name__ + "." + func.__name__ + name = name_of_span or self.__class__.__name__ + "." + func.__name__ child = parent_span.span(name=name) child.start() common.set_span_contexts(child) diff --git a/sdk/core/azure-core/tests/azure_core_asynctests/test_tracing_decorator_async.py b/sdk/core/azure-core/tests/azure_core_asynctests/test_tracing_decorator_async.py index 52ee23df2774..7fb903037d80 100644 --- a/sdk/core/azure-core/tests/azure_core_asynctests/test_tracing_decorator_async.py +++ b/sdk/core/azure-core/tests/azure_core_asynctests/test_tracing_decorator_async.py @@ -61,6 +61,26 @@ async def get_foo(self): time.sleep(0.001) return 5 + @distributed_trace_async(name_of_span="different name") + async def check_name_is_different(self): + time.sleep(0.001) + + +@pytest.mark.asyncio +async def test_decorator_has_different_name(): + with ContextHelper(): + exporter = MockExporter() + trace = tracer_module.Tracer(sampler=AlwaysOnSampler(), exporter=exporter) + with trace.span("overall"): + client = MockClient() + await client.check_name_is_different() + trace.finish() + exporter.build_tree() + parent = exporter.root + assert len(parent.children) == 2 + assert parent.children[0].span_data.name == "MockClient.__init__" + assert parent.children[1].span_data.name == "different name" + @pytest.mark.asyncio async def test_with_nothing_imported(): diff --git a/sdk/core/azure-core/tests/test_tracing_decorator.py b/sdk/core/azure-core/tests/test_tracing_decorator.py index 385bda8e4f92..873871f5b222 100644 --- a/sdk/core/azure-core/tests/test_tracing_decorator.py +++ b/sdk/core/azure-core/tests/test_tracing_decorator.py @@ -63,6 +63,10 @@ def get_foo(self): time.sleep(0.001) return 5 + @distributed_trace(name_of_span="different name") + def check_name_is_different(self): + time.sleep(0.001) + class TestCommon(object): def test_set_span_context(self): @@ -112,6 +116,20 @@ def test_should_use_trace(self): class TestDecorator(object): + def test_decorator_has_different_name(self): + with ContextHelper(): + exporter = MockExporter() + trace = tracer_module.Tracer(sampler=AlwaysOnSampler(), exporter=exporter) + with trace.span("overall"): + client = MockClient() + client.check_name_is_different() + trace.finish() + exporter.build_tree() + parent = exporter.root + assert len(parent.children) == 2 + assert parent.children[0].span_data.name == "MockClient.__init__" + assert parent.children[1].span_data.name == "different name" + def test_with_nothing_imported(self): with ContextHelper(): opencensus = sys.modules["opencensus"]