diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py new file mode 100644 index 000000000000..9ea29a25784d --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -0,0 +1,6 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from .msal_credentials import ConfidentialClientCredential +from .msal_transport_adapter import MsalTransportResponse, MsalTransportAdapter diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py new file mode 100644 index 000000000000..9bf44cbb3219 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -0,0 +1,88 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Credentials wrapping MSAL applications and delegating token acquisition and caching to them. +This entails monkeypatching MSAL's OAuth client with an adapter substituting an azure-core pipeline for Requests. +""" + +import time + +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +try: + from unittest import mock +except ImportError: # python < 3.3 + import mock # type: ignore + +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Any, Mapping, Optional, Union + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError +import msal + +from .msal_transport_adapter import MsalTransportAdapter + + +class MsalCredential(object): + """Base class for credentials wrapping MSAL applications""" + + def __init__(self, client_id, authority, app_class, client_credential=None, **kwargs): + # type: (str, str, msal.ClientApplication, Optional[Union[str, Mapping[str, str]]], Any) -> None + self._authority = authority + self._client_credential = client_credential + self._client_id = client_id + + self._adapter = kwargs.pop("msal_adapter", None) or MsalTransportAdapter(**kwargs) + + # postpone creating the wrapped application because its initializer uses the network + self._app_class = app_class + self._msal_app = None # type: Optional[msal.ClientApplication] + + @property + def _app(self): + # type: () -> msal.ClientApplication + """The wrapped MSAL application""" + + if not self._msal_app: + # MSAL application initializers use msal.authority to send AAD tenant discovery requests + with mock.patch("msal.authority.requests", self._adapter): + app = self._app_class( + client_id=self._client_id, client_credential=self._client_credential, authority=self._authority + ) + + # monkeypatch the app to replace requests.Session with MsalTransportAdapter + app.client.session = self._adapter + self._msal_app = app + + return self._msal_app + + +class ConfidentialClientCredential(MsalCredential): + """Wraps an MSAL ConfidentialClientApplication with the TokenCredential API""" + + def __init__(self, **kwargs): + # type: (Any) -> None + super(ConfidentialClientCredential, self).__init__(app_class=msal.ConfidentialClientApplication, **kwargs) + + def get_token(self, *scopes): + # type: (str) -> AccessToken + + # MSAL requires scopes be a list + scopes = list(scopes) # type: ignore + now = int(time.time()) + + # First try to get a cached access token or if a refresh token is cached, redeem it for an access token. + # Failing that, acquire a new token. + app = self._app # type: msal.ConfidentialClientApplication + result = app.acquire_token_silent(scopes, account=None) or app.acquire_token_for_client(scopes) + + if "access_token" not in result: + raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description"))) + + return AccessToken(result["access_token"], now + int(result["expires_in"])) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py new file mode 100644 index 000000000000..1a19beaf1c3d --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_transport_adapter.py @@ -0,0 +1,88 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Adapter to substitute an azure-core pipeline for Requests in MSAL application token acquisition methods.""" + +import json + +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Any, Dict, Mapping, Optional + from azure.core.pipeline import PipelineResponse + +from azure.core.configuration import Configuration +from azure.core.exceptions import ClientAuthenticationError +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy +from azure.core.pipeline.transport import HttpRequest, RequestsTransport + + +class MsalTransportResponse: + """Wraps an azure-core PipelineResponse with the shape of requests.Response""" + + def __init__(self, pipeline_response): + # type: (PipelineResponse) -> None + self._response = pipeline_response.http_response + self.status_code = self._response.status_code + self.text = self._response.text() + + def json(self, **kwargs): + # type: (Any) -> Mapping[str, Any] + return json.loads(self.text, **kwargs) + + def raise_for_status(self): + # type: () -> None + raise ClientAuthenticationError("authentication failed", self._response) + + +class MsalTransportAdapter(object): + """Wraps an azure-core pipeline with the shape of requests.Session""" + + def __init__(self, **kwargs): + # type: (Any) -> None + super(MsalTransportAdapter, self).__init__() + self._pipeline = self._build_pipeline(**kwargs) + + @staticmethod + def create_config(**kwargs): + # type: (Any) -> Configuration + config = Configuration(**kwargs) + config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) + config.retry_policy = RetryPolicy(**kwargs) + return config + + def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs): + config = config or self.create_config(**kwargs) + policies = policies or [ContentDecodePolicy(), config.retry_policy, config.logging_policy] + if not transport: + transport = RequestsTransport(configuration=config) + return Pipeline(transport=transport, policies=policies) + + def get(self, url, headers=None, params=None, timeout=None, verify=None, **kwargs): + # type: (str, Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse + request = HttpRequest("GET", url, headers=headers) + if params: + request.format_parameters(params) + response = self._pipeline.run( + request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs + ) + return MsalTransportResponse(response) + + def post(self, url, data=None, headers=None, params=None, timeout=None, verify=None, **kwargs): + # type: (str, Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]], float, bool, Any) -> MsalTransportResponse + request = HttpRequest("POST", url, headers=headers) + if params: + request.format_parameters(params) + if data: + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + request.set_formdata_body(data) + response = self._pipeline.run( + request, stream=False, connection_timeout=timeout, connection_verify=verify, **kwargs + ) + return MsalTransportResponse(response) diff --git a/sdk/identity/azure-identity/azure/identity/_internal.py b/sdk/identity/azure-identity/azure/identity/_managed_identity.py similarity index 100% rename from sdk/identity/azure-identity/azure/identity/_internal.py rename to sdk/identity/azure-identity/azure/identity/_managed_identity.py diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal.py b/sdk/identity/azure-identity/azure/identity/aio/_managed_identity.py similarity index 99% rename from sdk/identity/azure-identity/azure/identity/aio/_internal.py rename to sdk/identity/azure-identity/azure/identity/aio/_managed_identity.py index 4f502e95ed17..ec698b52c98b 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_managed_identity.py @@ -12,7 +12,7 @@ from ._authn_client import AsyncAuthnClient from ..constants import Endpoints, EnvironmentVariables -from .._internal import _ManagedIdentityBase +from .._managed_identity import _ManagedIdentityBase class _AsyncManagedIdentityBase(_ManagedIdentityBase): diff --git a/sdk/identity/azure-identity/azure/identity/aio/credentials.py b/sdk/identity/azure-identity/azure/identity/aio/credentials.py index 7f9f846f9ea1..c43dc917e518 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/aio/credentials.py @@ -14,7 +14,7 @@ from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, AsyncRetryPolicy from ._authn_client import AsyncAuthnClient -from ._internal import ImdsCredential, MsiCredential +from ._managed_identity import ImdsCredential, MsiCredential from .._base import ClientSecretCredentialBase, CertificateCredentialBase from ..constants import Endpoints, EnvironmentVariables from ..credentials import ChainedTokenCredential diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 6a172995d563..d53edf8e2c62 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -14,7 +14,7 @@ from ._authn_client import AuthnClient from ._base import ClientSecretCredentialBase, CertificateCredentialBase -from ._internal import ImdsCredential, MsiCredential +from ._managed_identity import ImdsCredential, MsiCredential from .constants import Endpoints, EnvironmentVariables try: diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 72d9a31dbf0a..c9756ad3b344 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -22,7 +22,7 @@ ManagedIdentityCredential, ChainedTokenCredential, ) -from azure.identity._internal import ImdsCredential +from azure.identity._managed_identity import ImdsCredential from azure.identity.constants import EnvironmentVariables from helpers import mock_response, Request, validating_transport diff --git a/sdk/identity/azure-identity/tests/test_identity_async.py b/sdk/identity/azure-identity/tests/test_identity_async.py index 78230c94bb09..ba203cd2eb59 100644 --- a/sdk/identity/azure-identity/tests/test_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_identity_async.py @@ -19,7 +19,7 @@ EnvironmentCredential, ManagedIdentityCredential, ) -from azure.identity.aio._internal import ImdsCredential +from azure.identity.aio._managed_identity import ImdsCredential from azure.identity.constants import EnvironmentVariables from helpers import mock_response, Request, async_validating_transport diff --git a/sdk/identity/azure-identity/tests/test_live.py b/sdk/identity/azure-identity/tests/test_live.py index 891524a48929..ddff3d83fa3a 100644 --- a/sdk/identity/azure-identity/tests/test_live.py +++ b/sdk/identity/azure-identity/tests/test_live.py @@ -2,16 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import os - -try: - from unittest import mock -except ImportError: # python < 3.3 - import mock # type: ignore - from azure.identity import DefaultAzureCredential, CertificateCredential, ClientSecretCredential -from azure.identity.constants import EnvironmentVariables -import pytest +from azure.identity._internal import ConfidentialClientCredential ARM_SCOPE = "https://management.azure.com/.default" @@ -46,3 +38,15 @@ def test_default_credential(live_identity_settings): assert token assert token.token assert token.expires_on + + +def test_confidential_client_credential(live_identity_settings): + credential = ConfidentialClientCredential( + client_id=live_identity_settings["client_id"], + client_credential=live_identity_settings["client_secret"], + authority="https://login.microsoftonline.com/" + live_identity_settings["tenant_id"], + ) + token = credential.get_token(ARM_SCOPE) + assert token + assert token.token + assert token.expires_on