Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

### Features Added

- `AccessToken` now has an optional `refresh_on` attribute that can be used to specify when the token should be refreshed. #36183
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check the `refresh_on` attribute when determining if a token request should be made.
- Added `azure.core.AzureClouds` enum to represent the different Azure clouds.
- Added azure.core.AzureClouds enum to represent the different Azure clouds.
- Added two new credential protocol classes, `SupportsTokenInfo` and `AsyncSupportsTokenInfo`, to offer more extensibility in supporting various token acquisition scenarios. #36565
- Each new protocol class defines a `get_token_info` method that returns an `AccessTokenInfo` object.
- Added a new `TokenRequestOptions` class, which is a `TypedDict` with optional parameters, that can be used to define options for token requests through the `get_token_info` method. #36565
- Added a new `AccessTokenInfo` class, which is returned by `get_token_info` implementations. This class contains the token, its expiration time, and optional additional information like when a token should be refreshed. #36565
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now first check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. Otherwise, the `get_token` method is used. #36565
- These policies now also check the `refresh_on` attribute when determining if a new token request should be made.

### Breaking Changes

Expand Down
82 changes: 76 additions & 6 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,64 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from typing import Any, NamedTuple, Optional
from typing import Any, NamedTuple, Optional, TypedDict, Union
from typing_extensions import Protocol, runtime_checkable


class AccessToken(NamedTuple):
"""Represents an OAuth access token."""

token: str
"""The token string."""
expires_on: int
refresh_on: Optional[int] = None
"""The token's expiration time in Unix time."""


AccessToken.token.__doc__ = """The token string."""
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""
AccessToken.refresh_on.__doc__ = """When the token should be refreshed in Unix time."""
class AccessTokenInfo:
"""Information about an OAuth access token.

This class is an alternative to `AccessToken` which provides additional information about the token.

:param str token: The token string.
:param int expires_on: The token's expiration time in Unix time.
:keyword str token_type: The type of access token. Defaults to 'Bearer'.
:keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively
refreshed. Optional.
"""

token: str
"""The token string."""
expires_on: int
"""The token's expiration time in Unix time."""
token_type: str
"""The type of access token."""
refresh_on: Optional[int]
"""Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional."""

def __init__(
self, token: str, expires_on: int, *, token_type: str = "Bearer", refresh_on: Optional[int] = None
) -> None:
self.token = token
self.expires_on = expires_on
self.token_type = token_type
self.refresh_on = refresh_on

def __repr__(self) -> str:
return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format(
self.token, self.expires_on, self.token_type, self.refresh_on
)


class TokenRequestOptions(TypedDict, total=False):
"""Options to use for access token requests. All parameters are optional."""

claims: str
"""Additional claims required in the token, such as those returned in a resource provider's claims
challenge following an authorization failure."""
tenant_id: str
"""The tenant ID to include in the token request."""
enable_cae: bool
"""Indicates whether to enable Continuous Access Evaluation (CAE) for the requested token."""


@runtime_checkable
Expand All @@ -30,7 +73,7 @@ def get_token(
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
enable_cae: bool = False,
**kwargs: Any
**kwargs: Any,
) -> AccessToken:
"""Request an access token for `scopes`.

Expand All @@ -48,6 +91,29 @@ def get_token(
...


@runtime_checkable
class SupportsTokenInfo(Protocol):
"""Protocol for classes able to provide OAuth access tokens with additional properties."""

def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
"""Request an access token for `scopes`.

This is an alternative to `get_token` to enable certain scenarios that require additional properties
on the token.

:param str scopes: The type of access needed.
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
:paramtype options: TokenRequestOptions

:rtype: AccessTokenInfo
:return: An AccessTokenInfo instance containing information about the token.
"""
...


TokenProvider = Union[TokenCredential, SupportsTokenInfo]


class AzureNamedKey(NamedTuple):
"""Represents a name and key pair."""

Expand All @@ -59,8 +125,12 @@ class AzureNamedKey(NamedTuple):
"AzureKeyCredential",
"AzureSasCredential",
"AccessToken",
"AccessTokenInfo",
"SupportsTokenInfo",
"AzureNamedKeyCredential",
"TokenCredential",
"TokenRequestOptions",
"TokenProvider",
]


Expand Down
42 changes: 40 additions & 2 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
# ------------------------------------
from __future__ import annotations
from types import TracebackType
from typing import Any, Optional, AsyncContextManager, Type
from typing import Any, Optional, AsyncContextManager, Type, Union
from typing_extensions import Protocol, runtime_checkable
from .credentials import AccessToken as _AccessToken
from .credentials import (
AccessToken as _AccessToken,
AccessTokenInfo as _AccessTokenInfo,
TokenRequestOptions as _TokenRequestOptions,
)


@runtime_checkable
Expand Down Expand Up @@ -46,3 +50,37 @@ async def __aexit__(
traceback: Optional[TracebackType] = None,
) -> None:
pass


@runtime_checkable
class AsyncSupportsTokenInfo(Protocol, AsyncContextManager["AsyncSupportsTokenInfo"]):
"""Protocol for classes able to provide OAuth access tokens with additional properties."""

async def get_token_info(self, *scopes: str, options: Optional[_TokenRequestOptions] = None) -> _AccessTokenInfo:
"""Request an access token for `scopes`.

This is an alternative to `get_token` to enable certain scenarios that require additional properties
on the token.

:param str scopes: The type of access needed.
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
:paramtype options: TokenRequestOptions

:rtype: AccessTokenInfo
:return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time.
"""
...

async def close(self) -> None:
pass

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
pass


AsyncTokenProvider = Union[AsyncTokenCredential, AsyncSupportsTokenInfo]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# license information.
# -------------------------------------------------------------------------
import time
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
from azure.core.credentials import TokenCredential, SupportsTokenInfo, TokenRequestOptions, TokenProvider
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.transport import HttpResponse as LegacyHttpResponse, HttpRequest as LegacyHttpRequest
from azure.core.rest import HttpResponse, HttpRequest
Expand All @@ -15,7 +16,7 @@
# pylint:disable=unused-import
from azure.core.credentials import (
AccessToken,
TokenCredential,
AccessTokenInfo,
AzureKeyCredential,
AzureSasCredential,
)
Expand All @@ -29,17 +30,17 @@ class _BearerTokenCredentialPolicyBase:
"""Base class for a Bearer Token Credential Policy.

:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials.TokenProvider
:param str scopes: Lets you specify the type of access needed.
:keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
tokens. Defaults to False.
"""

def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs: Any) -> None:
def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessToken"] = None
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

@staticmethod
Expand Down Expand Up @@ -70,11 +71,29 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@property
def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.

This will call the credential's appropriate method to get a token and store it in the policy.

:param str scopes: The type of access needed.
"""
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)

if hasattr(self._credential, "get_token_info"):
options: TokenRequestOptions = {}
# Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
for key in list(kwargs.keys()):
if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]

self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
else:
self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand All @@ -98,11 +117,9 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
self._enforce_https(request)

if self._token is None or self._need_new_token:
if self._enable_cae:
self._token = self._credential.get_token(*self._scopes, enable_cae=self._enable_cae)
else:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)
self._request_token(*self._scopes)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
self._update_headers(request.http_request.headers, bearer_token)

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.
Expand All @@ -113,10 +130,9 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)
self._token = self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)
self._request_token(*scopes, **kwargs)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
self._update_headers(request.http_request.headers, bearer_token)

def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
"""Authorize request with a bearer token and send it to the next policy
Expand Down
Loading