Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from azure.core.pipeline.transport import HttpTransport, RequestsTransport
from msal import TokenCache

Expand Down Expand Up @@ -73,6 +74,10 @@ def _deserialize_and_cache_token(self, response, scopes, request_time):

# now we have an int expires_on, ensure the cache entry gets it
payload["expires_on"] = expires_on
if "expires_in" in payload:
payload["expires_in"] = int(payload["expires_in"])
if "ext_expires_in" in payload:
payload["ext_expires_in"] = int(payload["ext_expires_in"])

self._cache.add({"response": payload, "scope": scopes})

Expand Down Expand Up @@ -119,7 +124,7 @@ class AuthnClient(AuthnClientBase):
def __init__(self, auth_url, config=None, policies=None, transport=None, **kwargs):
# type: (str, Optional[Configuration], Optional[Iterable[HTTPPolicy]], Optional[HttpTransport], Mapping[str, Any]) -> None
config = config or self.create_config(**kwargs)
policies = policies or [ContentDecodePolicy(), config.logging_policy, config.retry_policy]
policies = policies or [ContentDecodePolicy(), config.logging_policy, config.retry_policy, DistributedTracingPolicy()]
if not transport:
transport = RequestsTransport(**kwargs)
self._pipeline = Pipeline(transport=transport, policies=policies)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from azure.core.credentials import AccessToken

from azure.core import Configuration
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import ClientAuthenticationError, HttpResponseError
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, RetryPolicy

Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self, config=None, **kwargs):
super(ImdsCredential, self).__init__(endpoint=Endpoints.IMDS, client_cls=AuthnClient, config=config, **kwargs)
self._endpoint_available = None # type: Optional[bool]

@distributed_trace
def get_token(self, *scopes):
# type: (*str) -> AccessToken
"""
Expand Down Expand Up @@ -137,6 +139,7 @@ def __init__(self, config=None, **kwargs):
endpoint=endpoint, client_cls=AuthnClient, config=config, **kwargs
)

@distributed_trace
def get_token(self, *scopes):
# type: (*str) -> AccessToken
"""
Expand Down
34 changes: 20 additions & 14 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from azure.core import Configuration
from azure.core.credentials import AccessToken
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from azure.core.pipeline.policies import AsyncRetryPolicy, ContentDecodePolicy, HTTPPolicy, NetworkTraceLoggingPolicy
from azure.core.pipeline.transport import AsyncHttpTransport
from azure.core.pipeline.transport.requests_asyncio import AsyncioRequestsTransport
Expand All @@ -19,28 +20,33 @@ class AsyncAuthnClient(AuthnClientBase):
"""Async authentication client"""

def __init__(
self,
auth_url: str,
config: Optional[Configuration] = None,
policies: Optional[Iterable[HTTPPolicy]] = None,
transport: Optional[AsyncHttpTransport] = None,
**kwargs: Mapping[str, Any]
self,
auth_url: str,
config: Optional[Configuration] = None,
policies: Optional[Iterable[HTTPPolicy]] = None,
transport: Optional[AsyncHttpTransport] = None,
**kwargs: Mapping[str, Any]
) -> None:
config = config or self.create_config(**kwargs)
policies = policies or [ContentDecodePolicy(), config.logging_policy, config.retry_policy]
policies = policies or [
ContentDecodePolicy(),
config.logging_policy,
config.retry_policy,
DistributedTracingPolicy(),
]
if not transport:
transport = AsyncioRequestsTransport(**kwargs)
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
super(AsyncAuthnClient, self).__init__(auth_url, **kwargs)

async def request_token(
self,
scopes: Iterable[str],
method: Optional[str] = "POST",
headers: Optional[Mapping[str, str]] = None,
form_data: Optional[Mapping[str, str]] = None,
params: Optional[Dict[str, str]] = None,
**kwargs: Any
self,
scopes: Iterable[str],
method: Optional[str] = "POST",
headers: Optional[Mapping[str, str]] = None,
form_data: Optional[Mapping[str, str]] = None,
params: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> AccessToken:
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
request_time = int(time.time())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from azure.core import Configuration
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError, HttpResponseError
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, AsyncRetryPolicy

from ._authn_client import AsyncAuthnClient
Expand Down Expand Up @@ -41,6 +42,7 @@ def __init__(self, config: Optional[Configuration] = None, **kwargs: Any) -> Non
super().__init__(endpoint=Endpoints.IMDS, config=config, **kwargs)
self._endpoint_available = None # type: Optional[bool]

@distributed_trace_async
async def get_token(self, *scopes: str) -> AccessToken:
"""
Asynchronously request an access token for `scopes`.
Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(self, config: Optional[Configuration] = None, **kwargs: Any) -> Non
if self._endpoint_available:
super().__init__(endpoint=endpoint, config=config, **kwargs) # type: ignore

@distributed_trace_async
async def get_token(self, *scopes: str) -> AccessToken:
"""
Asynchronously request an access token for `scopes`.
Expand Down
6 changes: 6 additions & 0 deletions sdk/identity/azure-identity/azure/identity/aio/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from azure.core import Configuration
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, AsyncRetryPolicy

from ._authn_client import AsyncAuthnClient
Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(
super(ClientSecretCredential, self).__init__(client_id, secret, tenant_id, **kwargs)
self._client = AsyncAuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), config, **kwargs)

@distributed_trace_async
async def get_token(self, *scopes: str) -> AccessToken:
"""
Asynchronously request an access token for `scopes`.
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
super(CertificateCredential, self).__init__(client_id, tenant_id, certificate_path, **kwargs)
self._client = AsyncAuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), config, **kwargs)

@distributed_trace_async
async def get_token(self, *scopes: str) -> AccessToken:
"""
Asynchronously request an access token for `scopes`.
Expand Down Expand Up @@ -130,6 +133,7 @@ def __init__(self, **kwargs: Mapping[str, Any]) -> None:
**kwargs
)

@distributed_trace_async
async def get_token(self, *scopes: str) -> AccessToken:
"""
Asynchronously request an access token for `scopes`.
Expand Down Expand Up @@ -174,6 +178,7 @@ def create_config(**kwargs: Dict[str, Any]) -> Configuration:
"""
return Configuration(**kwargs)

@distributed_trace_async
async def get_token(self, *scopes: str) -> AccessToken:
"""
Asynchronously request an access token for `scopes`.
Expand All @@ -194,6 +199,7 @@ class ChainedTokenCredential(ChainedTokenCredential):
:type credentials: :class:`azure.core.credentials.TokenCredential`
"""

@distributed_trace_async
async def get_token(self, *scopes: str) -> AccessToken: # type: ignore
"""
Asynchronously request a token from each credential, in order, returning the first token
Expand Down
6 changes: 6 additions & 0 deletions sdk/identity/azure-identity/azure/identity/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from azure.core import Configuration
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, RetryPolicy

from ._authn_client import AuthnClient
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self, client_id, secret, tenant_id, config=None, **kwargs):
super(ClientSecretCredential, self).__init__(client_id, secret, tenant_id, **kwargs)
self._client = AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), config, **kwargs)

@distributed_trace
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Expand Down Expand Up @@ -78,6 +80,7 @@ def __init__(self, client_id, tenant_id, certificate_path, config=None, **kwargs
self._client = AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), config, **kwargs)
super(CertificateCredential, self).__init__(client_id, tenant_id, certificate_path, **kwargs)

@distributed_trace
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Expand Down Expand Up @@ -129,6 +132,7 @@ def __init__(self, **kwargs):
**kwargs
)

@distributed_trace
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Expand Down Expand Up @@ -176,6 +180,7 @@ def create_config(**kwargs):
"""
return Configuration(**kwargs)

@distributed_trace
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Expand Down Expand Up @@ -203,6 +208,7 @@ def __init__(self, *credentials):
raise ValueError("at least one credential is required")
self._credentials = credentials

@distributed_trace
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Expand Down