diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index f683405407cb..ea418bf1c97f 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -15,6 +15,7 @@ #### Other Changes * Added session token false progress merge logic. See [42393](https://github.com/Azure/azure-sdk-for-python/pull/42393) +* Added a fallback mechanism to AAD scope override. See [PR 42731](https://github.com/Azure/azure-sdk-for-python/pull/42731). ### 4.14.0b2 (2025-08-12) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py index 8f3316870607..83418e1f375d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py @@ -3,20 +3,32 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from typing import TypeVar, Any, MutableMapping, cast +from typing import TypeVar, Any, MutableMapping, cast, Optional from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import BearerTokenCredentialPolicy from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest from azure.core.rest import HttpRequest from azure.core.credentials import AccessToken +from azure.core.exceptions import HttpResponseError from .http_constants import HttpHeaders +from ._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) - +# NOTE: This class accesses protected members (_scopes, _token) of the parent class +# to implement fallback and scope-switching logic not exposed by the public API. +# Composition was considered, but still required accessing protected members, so inheritance is retained +# for seamless Azure SDK pipeline integration. class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): + AadDefaultScope = Constants.AAD_DEFAULT_SCOPE + + def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None): + self._account_scope = account_scope + self._override_scope = override_scope + self._current_scope = override_scope or account_scope + super().__init__(credential, self._current_scope) @staticmethod def _update_headers(headers: MutableMapping[str, str], token: str) -> None: @@ -34,9 +46,26 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: :param ~azure.core.pipeline.PipelineRequest request: the request """ - super().on_request(request) - # The None-check for self._token is done in the parent on_request - self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + tried_fallback = False + while True: + try: + super().on_request(request) + # The None-check for self._token is done in the parent on_request + self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + break + except HttpResponseError as ex: + # Only fallback if not using override, not already tried, and error is AADSTS500011 + if ( + not self._override_scope and + not tried_fallback and + self._current_scope != self.AadDefaultScope and + "AADSTS500011" in str(ex) + ): + self._scopes = (self.AadDefaultScope,) + self._current_scope = self.AadDefaultScope + tried_fallback = True + continue + raise def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -47,6 +76,7 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: :param ~azure.core.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ + super().authorize_request(request, *scopes, **kwargs) # The None-check for self._token is done in the parent authorize_request self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index c7e1b89730ed..7a1675546934 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -56,6 +56,7 @@ class _Constants: CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE" + AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default" # Database Account Retry Policy constants AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 23406adfb212..180aacaf5611 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -132,7 +132,6 @@ def __init__( # pylint: disable=too-many-statements The connection policy for the client. :param documents.ConsistencyLevel consistency_level: The default consistency policy for client operations. - """ self.client_id = str(uuid.uuid4()) self.url_connection = url_connection @@ -205,11 +204,12 @@ def __init__( # pylint: disable=too-many-statements credentials_policy = None if self.aad_credentials: scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "") - if scope_override: - scope = scope_override - else: - scope = base.create_scope_from_url(self.url_connection) - credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope) + account_scope = base.create_scope_from_url(self.url_connection) + credentials_policy = CosmosBearerTokenCredentialPolicy( + self.aad_credentials, + account_scope=account_scope, + override_scope=scope_override if scope_override else None + ) self._enable_diagnostics_logging = kwargs.pop("enable_diagnostics_logging", False) policies = [ HeadersPolicy(**kwargs), diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py index 6e018d88cdfd..ea1a86b120a1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py @@ -4,20 +4,32 @@ # license information. # ------------------------------------------------------------------------- -from typing import Any, MutableMapping, TypeVar, cast +from typing import Any, MutableMapping, TypeVar, cast, Optional from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.pipeline import PipelineRequest from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest from azure.core.rest import HttpRequest from azure.core.credentials import AccessToken +from azure.core.exceptions import HttpResponseError from ..http_constants import HttpHeaders +from .._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) - +# NOTE: This class accesses protected members (_scopes, _token) of the parent class +# to implement fallback and scope-switching logic not exposed by the public API. +# Composition was considered, but still required accessing protected members, so inheritance is retained +# for seamless Azure SDK pipeline integration. class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): + AadDefaultScope = Constants.AAD_DEFAULT_SCOPE + + def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None): + self._account_scope = account_scope + self._override_scope = override_scope + self._current_scope = override_scope or account_scope + super().__init__(credential, self._current_scope) @staticmethod def _update_headers(headers: MutableMapping[str, str], token: str) -> None: @@ -35,9 +47,26 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - await super().on_request(request) - # The None-check for self._token is done in the parent on_request - self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + tried_fallback = False + while True: + try: + await super().on_request(request) + # The None-check for self._token is done in the parent on_request + self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + break + except HttpResponseError as ex: + # Only fallback if not using override, not already tried, and error is AADSTS500011 + if ( + not self._override_scope and + not tried_fallback and + self._current_scope != self.AadDefaultScope and + "AADSTS500011" in str(ex) + ): + self._scopes = (self.AadDefaultScope,) + self._current_scope = self.AadDefaultScope + tried_fallback = True + continue + raise async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -48,6 +77,7 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc :param ~azure.core.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ + await super().authorize_request(request, *scopes, **kwargs) # The None-check for self._token is done in the parent authorize_request self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index a7fbdf218bb2..b8af1b081d70 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -138,7 +138,6 @@ def __init__( # pylint: disable=too-many-statements The connection policy for the client. :param documents.ConsistencyLevel consistency_level: The default consistency policy for client operations. - """ self.client_id = str(uuid.uuid4()) self.url_connection = url_connection @@ -213,11 +212,12 @@ def __init__( # pylint: disable=too-many-statements credentials_policy = None if self.aad_credentials: scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "") - if scope_override: - scope = scope_override - else: - scope = base.create_scope_from_url(self.url_connection) - credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope) + account_scope = base.create_scope_from_url(self.url_connection) + credentials_policy = AsyncCosmosBearerTokenCredentialPolicy( + self.aad_credentials, + account_scope, + scope_override + ) self._enable_diagnostics_logging = kwargs.pop("enable_diagnostics_logging", False) policies = [ HeadersPolicy(**kwargs), diff --git a/sdk/cosmos/azure-cosmos/tests/test_aad.py b/sdk/cosmos/azure-cosmos/tests/test_aad.py index 39ff7500e30a..b1d593bd96a6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_aad.py +++ b/sdk/cosmos/azure-cosmos/tests/test_aad.py @@ -14,7 +14,7 @@ import azure.cosmos.cosmos_client as cosmos_client import test_config from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions - +from azure.core.exceptions import HttpResponseError def _remove_padding(encoded_string): while encoded_string.endswith("="): @@ -34,7 +34,6 @@ def get_test_item(num): class CosmosEmulatorCredential(object): - def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for the emulator. Based on Azure Core's Access Token Credential. @@ -118,33 +117,126 @@ def test_aad_credentials(self): assert e.status_code == 403 print("403 error assertion success") - def test_aad_scope_override(self): - override_scope = "https://my.custom.scope/.default" - os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + def _run_with_scope_capture(self, credential_cls, action, *args, **kwargs): scopes_captured = [] - original_get_token = CosmosEmulatorCredential.get_token + original_get_token = credential_cls.get_token def capturing_get_token(self, *scopes, **kwargs): scopes_captured.extend(scopes) return original_get_token(self, *scopes, **kwargs) - CosmosEmulatorCredential.get_token = capturing_get_token - + credential_cls.get_token = capturing_get_token try: + result = action(scopes_captured, *args, **kwargs) + finally: + credential_cls.get_token = original_get_token + return scopes_captured, result + + def test_override_scope_no_fallback(self): + """When override scope is provided, only that scope is used and no fallback occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + def action(scopes_captured): credential = CosmosEmulatorCredential() client = cosmos_client.CosmosClient(self.host, credential) db = client.get_database_client(self.configs.TEST_DATABASE_ID) container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) - container.create_item(get_test_item(1)) - assert override_scope in scopes_captured + container.create_item(get_test_item(10)) + return container + + scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action) + try: + assert all(scope == override_scope for scope in scopes), f"Expected only override scope(s), got: {scopes}" finally: - CosmosEmulatorCredential.get_token = original_get_token del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] try: - container.delete_item(item='Item_1', partition_key='pk') + container.delete_item(item='Item_10', partition_key='pk') except Exception: pass + def test_override_scope_auth_error_no_fallback(self): + """When override scope is provided and auth fails, no fallback to other scopes occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + class FailingCredential(CosmosEmulatorCredential): + def get_token(self, *scopes, **kwargs): + raise Exception("Simulated auth error for override scope") + + def action(scopes_captured): + with pytest.raises(Exception) as excinfo: + client = cosmos_client.CosmosClient(self.host, FailingCredential()) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(11)) + assert "Simulated auth error" in str(excinfo.value) + return None + + scopes, _ = self._run_with_scope_capture(FailingCredential, action) + try: + assert scopes == [override_scope], f"Expected only override scope, got: {scopes}" + finally: + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + + def test_account_scope_only(self): + """When account scope is provided, only that scope is used.""" + account_scope = "https://localhost/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + def action(scopes_captured): + credential = CosmosEmulatorCredential() + client = cosmos_client.CosmosClient(self.host, credential) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(12)) + return container + + scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action) + try: + # Accept multiple calls, but only the account_scope should be used + assert all(scope == account_scope for scope in scopes), f"Expected only account scope, got: {scopes}" + finally: + try: + container.delete_item(item='Item_12', partition_key='pk') + except Exception: + pass + + def test_account_scope_fallback_on_error(self): + """When account scope is provided and auth fails, fallback to default scope occurs.""" + account_scope = "https://localhost/.default" + fallback_scope = "https://cosmos.azure.com/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + class FallbackCredential(CosmosEmulatorCredential): + def __init__(self): + self.call_count = 0 + + def get_token(self, *scopes, **kwargs): + self.call_count += 1 + if self.call_count == 1: + raise HttpResponseError(message="AADSTS500011: Simulated error for fallback") + return super().get_token(*scopes, **kwargs) + + def action(scopes_captured): + credential = FallbackCredential() + client = cosmos_client.CosmosClient(self.host, credential) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(13)) + return container + + scopes, container = self._run_with_scope_capture(FallbackCredential, action) + try: + # Accept multiple calls, but the first should be account_scope, and fallback_scope should appear after error + assert account_scope in scopes and fallback_scope in scopes, f"Expected fallback to default scope, got: {scopes}" + finally: + try: + container.delete_item(item='Item_13', partition_key='pk') + except Exception: + pass + + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_aad_async.py b/sdk/cosmos/azure-cosmos/tests/test_aad_async.py index c39516a06279..6ce3cb4d1124 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_aad_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_aad_async.py @@ -14,7 +14,7 @@ import test_config from azure.cosmos import exceptions from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy - +from azure.core.exceptions import HttpResponseError def _remove_padding(encoded_string): while encoded_string.endswith("="): @@ -34,7 +34,6 @@ def get_test_item(num): class CosmosEmulatorCredential(object): - async def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for the emulator. Based on Azure Core's Access Token Credential. @@ -131,38 +130,141 @@ async def test_aad_credentials_async(self): assert e.status_code == 403 print("403 error assertion success") - async def test_aad_scope_override_async(self): - override_scope = "https://my.custom.scope/.default" - os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope - + async def _run_with_scope_capture_async(self, credential_cls, action): scopes_captured = [] - original_get_token = CosmosEmulatorCredential.get_token + + orig_get_token = credential_cls.get_token async def capturing_get_token(self, *scopes, **kwargs): scopes_captured.extend(scopes) - # Await the original method! - return await original_get_token(self, *scopes, **kwargs) - - CosmosEmulatorCredential.get_token = capturing_get_token + return await orig_get_token(self, *scopes, **kwargs) + credential_cls.get_token = capturing_get_token try: + result = await action(scopes_captured) + return scopes_captured, result + finally: + credential_cls.get_token = orig_get_token + + async def test_override_scope_no_fallback_async(self): + """When override scope is provided, only that scope is used and no fallback occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + async def action(scopes_captured): credential = CosmosEmulatorCredential() client = CosmosClient(self.host, credential) - database = client.get_database_client(self.configs.TEST_DATABASE_ID) - container = database.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + await container.create_item(get_test_item(20)) + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(CosmosEmulatorCredential, action) + try: + assert all(scope == override_scope for scope in scopes), f"Expected only override scope, got: {scopes}" + finally: + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + try: + await container.delete_item(item='Item_20', partition_key='pk') + except Exception: + pass + + async def test_override_scope_no_fallback_on_error_async(self): + """When override scope is provided and auth fails, no fallback occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + class FailingCredential(CosmosEmulatorCredential): + async def get_token(self, *scopes, **kwargs): + raise Exception("AADSTS500011: Simulated error for override scope") - await container.create_item(get_test_item(1)) - item = await container.read_item(item='Item_1', partition_key='pk') - assert item["id"] == "Item_1" - assert override_scope in scopes_captured + async def action(scopes_captured): + credential = FailingCredential() + client = CosmosClient(self.host, credential) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + try: + await container.create_item(get_test_item(21)) + except Exception: + pass + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(FailingCredential, action) + try: + assert all(scope == override_scope for scope in scopes), f"Expected only override scope, got: {scopes}" finally: - CosmosEmulatorCredential.get_token = original_get_token del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] try: - await container.delete_item(item='Item_1', partition_key='pk') + await container.delete_item(item='Item_21', partition_key='pk') + except Exception: + pass + + async def test_account_scope_only_async(self): + """When account scope is provided, only that scope is used.""" + account_scope = "https://localhost/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + async def action(scopes_captured): + credential = CosmosEmulatorCredential() + client = CosmosClient(self.host, credential) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + await container.create_item(get_test_item(22)) + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(CosmosEmulatorCredential, action) + try: + assert all(scope == account_scope for scope in scopes), f"Expected only account scope, got: {scopes}" + finally: + try: + await container.delete_item(item='Item_22', partition_key='pk') + except Exception: + pass + + async def test_account_scope_fallback_on_error_async(self): + """When account scope is provided and auth fails, fallback to default scope occurs.""" + account_scope = "https://localhost/.default" + fallback_scope = "https://cosmos.azure.com/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + class FallbackCredential(CosmosEmulatorCredential): + def __init__(self): + self.call_count = 0 + + async def get_token(self, *scopes, **kwargs): + self.call_count += 1 + if self.call_count == 1: + raise HttpResponseError(message="AADSTS500011: Simulated error for fallback") + return await super().get_token(*scopes, **kwargs) + + async def action(scopes_captured): + credential = FallbackCredential() + client = CosmosClient(self.host, credential) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + await container.create_item(get_test_item(23)) + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(FallbackCredential, action) + try: + assert account_scope in scopes and fallback_scope in scopes, f"Expected fallback to default scope, got: {scopes}" + finally: + try: + await container.delete_item(item='Item_23', partition_key='pk') except Exception: pass - await client.close() if __name__ == "__main__": unittest.main()