diff --git a/doc/sphinx/_static/js/get_options.js b/doc/sphinx/_static/js/get_options.js index 4e79921b4066..46ac0f0df866 100644 --- a/doc/sphinx/_static/js/get_options.js +++ b/doc/sphinx/_static/js/get_options.js @@ -23,11 +23,11 @@ function currentPackage(){ function httpGetAsync(targetUrl, callback) { var xmlHttp = new XMLHttpRequest(); - xmlHttp.onreadystatechange = function() { + xmlHttp.onreadystatechange = function() { if (xmlHttp.readyState == 4 && xmlHttp.status == 200) callback(xmlHttp.responseText); } - xmlHttp.open("GET", targetUrl, true); // true for asynchronous + xmlHttp.open("GET", targetUrl, true); // true for asynchronous xmlHttp.send(null); } @@ -45,8 +45,8 @@ function hideSelectors(selectors){ function populateOptions(optionSelector, otherSelectors){ if(currentPackage()){ - var versionRequestUrl = "https://azuresdkdocs.blob.core.windows.net/$web/" + SELECTED_LANGUAGE + "/" + currentPackage() + "/versioning/versions" - + var versionRequestUrl = "https://azuresdkdocs.z19.web.core.windows.net/" + SELECTED_LANGUAGE + "/" + currentPackage() + "/versioning/versions" + httpGetAsync(versionRequestUrl, function(responseText){ if(responseText){ options = responseText.match(/[^\r\n]+/g) @@ -68,7 +68,7 @@ function populateOptions(optionSelector, otherSelectors){ function populateVersionDropDown(selector, values){ var select = $(selector) - + $('option', select).remove() $.each(values, function(index, text) { @@ -80,17 +80,17 @@ function populateVersionDropDown(selector, values){ select.selectedIndex = 0 } else { - select.val(version) + select.val(version) } } function getPackageUrl(language, package, version){ - return "https://azuresdkdocs.blob.core.windows.net/$web/" + language + "/" + package + "/"+ version + "/index.html" + return "https://azuresdkdocs.z19.web.core.windows.net/" + language + "/" + package + "/"+ version + "/index.html" } function populateIndexList(selector, packageName) { - url = "https://azuresdkdocs.blob.core.windows.net/$web/" + SELECTED_LANGUAGE + "/" + packageName + "/versioning/versions" + url = "https://azuresdkdocs.z19.web.windows.net/" + SELECTED_LANGUAGE + "/" + packageName + "/versioning/versions" httpGetAsync(url, function (responseText){ if(responseText){ diff --git a/doc/sphinx/conf.py b/doc/sphinx/conf.py index f0d6ce555521..f2feed4d3c96 100644 --- a/doc/sphinx/conf.py +++ b/doc/sphinx/conf.py @@ -66,8 +66,8 @@ 'trio': ('https://trio.readthedocs.io/en/stable/', None), 'msal': ('https://msal-python.readthedocs.io/en/latest/', None), # Azure packages - 'azure-core': ('https://azuresdkdocs.blob.core.windows.net/$web/python/azure-core/latest/', None), - 'azure-identity': ('https://azuresdkdocs.blob.core.windows.net/$web/python/azure-identity/latest/', None), + 'azure-core': ('https://azuresdkdocs.z19.web.core.windows.net/python/azure-core/latest/', None), + 'azure-identity': ('https://azuresdkdocs.z19.web.core.windows.net/python/azure-identity/latest/', None), } autodoc_member_order = 'groupwise' diff --git a/doc/sphinx/individual_build_conf.py b/doc/sphinx/individual_build_conf.py index a289c668e3ed..74df67599c51 100644 --- a/doc/sphinx/individual_build_conf.py +++ b/doc/sphinx/individual_build_conf.py @@ -64,8 +64,8 @@ 'trio': ('https://trio.readthedocs.io/en/stable/', None), 'msal': ('https://msal-python.readthedocs.io/en/latest/', None), # Azure packages - 'azure-core': ('https://azuresdkdocs.blob.core.windows.net/$web/python/azure-core/1.0.0/', None), - 'azure-identity': ('https://azuresdkdocs.blob.core.windows.net/$web/python/azure-identity/1.0.0/', None), + 'azure-core': ('https://azuresdkdocs.z19.web.core.windows.net/python/azure-core/1.0.0/', None), + 'azure-identity': ('https://azuresdkdocs.z19.web.core.windows.net/python/azure-identity/1.0.0/', None), } autodoc_member_order = 'groupwise' diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 82dac7bf1c26..0e52da484037 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,5 +1,11 @@ ## Release History +### 4.9.1 (2025-10-03) + +#### Other Changes +* Added an option to override AAD audience scope through environment variable. See [PR 42228](https://github.com/Azure/azure-sdk-for-python/pull/42228). +* Added a fallback mechanism to AAD scope override. See [PR 42731](https://github.com/Azure/azure-sdk-for-python/pull/42731). + ### 4.9.0 (2024-11-18) #### Features Added diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py index 8f3316870607..83418e1f375d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py @@ -3,20 +3,32 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from typing import TypeVar, Any, MutableMapping, cast +from typing import TypeVar, Any, MutableMapping, cast, Optional from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import BearerTokenCredentialPolicy from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest from azure.core.rest import HttpRequest from azure.core.credentials import AccessToken +from azure.core.exceptions import HttpResponseError from .http_constants import HttpHeaders +from ._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) - +# NOTE: This class accesses protected members (_scopes, _token) of the parent class +# to implement fallback and scope-switching logic not exposed by the public API. +# Composition was considered, but still required accessing protected members, so inheritance is retained +# for seamless Azure SDK pipeline integration. class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): + AadDefaultScope = Constants.AAD_DEFAULT_SCOPE + + def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None): + self._account_scope = account_scope + self._override_scope = override_scope + self._current_scope = override_scope or account_scope + super().__init__(credential, self._current_scope) @staticmethod def _update_headers(headers: MutableMapping[str, str], token: str) -> None: @@ -34,9 +46,26 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: :param ~azure.core.pipeline.PipelineRequest request: the request """ - super().on_request(request) - # The None-check for self._token is done in the parent on_request - self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + tried_fallback = False + while True: + try: + super().on_request(request) + # The None-check for self._token is done in the parent on_request + self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + break + except HttpResponseError as ex: + # Only fallback if not using override, not already tried, and error is AADSTS500011 + if ( + not self._override_scope and + not tried_fallback and + self._current_scope != self.AadDefaultScope and + "AADSTS500011" in str(ex) + ): + self._scopes = (self.AadDefaultScope,) + self._current_scope = self.AadDefaultScope + tried_fallback = True + continue + raise def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -47,6 +76,7 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: :param ~azure.core.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ + super().authorize_request(request, *scopes, **kwargs) # The None-check for self._token is done in the parent authorize_request self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 4aef29a8b3dc..4f26e2e3365e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -43,6 +43,9 @@ class _Constants: # ServiceDocument Resource EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations" + AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default" + AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE" + # Error code translations ERROR_TRANSLATIONS: Dict[int, str] = { 400: "BAD_REQUEST - Request being sent is invalid.", diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 03cdcbb8e214..97d97a8e1a63 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -85,7 +85,7 @@ class CredentialDict(TypedDict, total=False): clientSecretCredential: TokenCredential -class CosmosClientConnection: # pylint: disable=too-many-public-methods,too-many-instance-attributes +class CosmosClientConnection: # pylint: disable=too-many-public-methods,too-many-instance-attributes,too-many-statements """Represents a document client. Provides a client-side logical representation of the Azure Cosmos @@ -129,7 +129,6 @@ def __init__( The connection policy for the client. :param documents.ConsistencyLevel consistency_level: The default consistency policy for client operations. - """ self.url_connection = url_connection self.master_key: Optional[str] = None @@ -195,9 +194,13 @@ def __init__( credentials_policy = None if self.aad_credentials: - scope = base.create_scope_from_url(self.url_connection) - credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope) - + scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "") + account_scope = base.create_scope_from_url(self.url_connection) + credentials_policy = CosmosBearerTokenCredentialPolicy( + self.aad_credentials, + account_scope=account_scope, + override_scope=scope_override if scope_override else None + ) policies = [ HeadersPolicy(**kwargs), ProxyPolicy(proxies=proxies), diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py index 5b6088546c61..6ffc17562066 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py @@ -19,4 +19,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -VERSION = "4.9.0" +VERSION = "4.9.1" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py index 6e018d88cdfd..ea1a86b120a1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py @@ -4,20 +4,32 @@ # license information. # ------------------------------------------------------------------------- -from typing import Any, MutableMapping, TypeVar, cast +from typing import Any, MutableMapping, TypeVar, cast, Optional from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.pipeline import PipelineRequest from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest from azure.core.rest import HttpRequest from azure.core.credentials import AccessToken +from azure.core.exceptions import HttpResponseError from ..http_constants import HttpHeaders +from .._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) - +# NOTE: This class accesses protected members (_scopes, _token) of the parent class +# to implement fallback and scope-switching logic not exposed by the public API. +# Composition was considered, but still required accessing protected members, so inheritance is retained +# for seamless Azure SDK pipeline integration. class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): + AadDefaultScope = Constants.AAD_DEFAULT_SCOPE + + def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None): + self._account_scope = account_scope + self._override_scope = override_scope + self._current_scope = override_scope or account_scope + super().__init__(credential, self._current_scope) @staticmethod def _update_headers(headers: MutableMapping[str, str], token: str) -> None: @@ -35,9 +47,26 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - await super().on_request(request) - # The None-check for self._token is done in the parent on_request - self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + tried_fallback = False + while True: + try: + await super().on_request(request) + # The None-check for self._token is done in the parent on_request + self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) + break + except HttpResponseError as ex: + # Only fallback if not using override, not already tried, and error is AADSTS500011 + if ( + not self._override_scope and + not tried_fallback and + self._current_scope != self.AadDefaultScope and + "AADSTS500011" in str(ex) + ): + self._scopes = (self.AadDefaultScope,) + self._current_scope = self.AadDefaultScope + tried_fallback = True + continue + raise async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -48,6 +77,7 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc :param ~azure.core.pipeline.PipelineRequest request: the request :param str scopes: required scopes of authentication """ + await super().authorize_request(request, *scopes, **kwargs) # The None-check for self._token is done in the parent authorize_request self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index b91f53394f53..6c63d5363cc4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -88,7 +88,7 @@ class CredentialDict(TypedDict, total=False): clientSecretCredential: AsyncTokenCredential -class CosmosClientConnection: # pylint: disable=too-many-public-methods,too-many-instance-attributes +class CosmosClientConnection: # pylint: disable=too-many-public-methods,too-many-instance-attributes,too-many-statements """Represents a document client. Provides a client-side logical representation of the Azure Cosmos @@ -132,7 +132,6 @@ def __init__( The connection policy for the client. :param documents.ConsistencyLevel consistency_level: The default consistency policy for client operations. - """ self.url_connection = url_connection self.master_key: Optional[str] = None @@ -200,9 +199,13 @@ def __init__( credentials_policy = None if self.aad_credentials: - scope = base.create_scope_from_url(self.url_connection) - credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope) - + scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "") + account_scope = base.create_scope_from_url(self.url_connection) + credentials_policy = AsyncCosmosBearerTokenCredentialPolicy( + self.aad_credentials, + account_scope, + scope_override + ) policies = [ HeadersPolicy(**kwargs), ProxyPolicy(proxies=proxies), diff --git a/sdk/cosmos/azure-cosmos/test/test_aad.py b/sdk/cosmos/azure-cosmos/test/test_aad.py index c13838571a43..cb3db3417d70 100644 --- a/sdk/cosmos/azure-cosmos/test/test_aad.py +++ b/sdk/cosmos/azure-cosmos/test/test_aad.py @@ -5,6 +5,7 @@ import json import time import unittest +import os from io import StringIO import pytest @@ -13,7 +14,7 @@ import azure.cosmos.cosmos_client as cosmos_client import test_config from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions - +from azure.core.exceptions import HttpResponseError def _remove_padding(encoded_string): while encoded_string.endswith("="): @@ -33,7 +34,6 @@ def get_test_item(num): class CosmosEmulatorCredential(object): - def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for the emulator. Based on Azure Core's Access Token Credential. @@ -118,5 +118,125 @@ def test_aad_credentials(self): print("403 error assertion success") + def _run_with_scope_capture(self, credential_cls, action, *args, **kwargs): + scopes_captured = [] + original_get_token = credential_cls.get_token + + def capturing_get_token(self, *scopes, **kwargs): + scopes_captured.extend(scopes) + return original_get_token(self, *scopes, **kwargs) + + credential_cls.get_token = capturing_get_token + try: + result = action(scopes_captured, *args, **kwargs) + finally: + credential_cls.get_token = original_get_token + return scopes_captured, result + + def test_override_scope_no_fallback(self): + """When override scope is provided, only that scope is used and no fallback occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + def action(scopes_captured): + credential = CosmosEmulatorCredential() + client = cosmos_client.CosmosClient(self.host, credential) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(10)) + return container + + scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action) + try: + assert all(scope == override_scope for scope in scopes), f"Expected only override scope(s), got: {scopes}" + finally: + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + try: + container.delete_item(item='Item_10', partition_key='pk') + except Exception: + pass + + def test_override_scope_auth_error_no_fallback(self): + """When override scope is provided and auth fails, no fallback to other scopes occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + class FailingCredential(CosmosEmulatorCredential): + def get_token(self, *scopes, **kwargs): + raise Exception("Simulated auth error for override scope") + + def action(scopes_captured): + with pytest.raises(Exception) as excinfo: + client = cosmos_client.CosmosClient(self.host, FailingCredential()) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(11)) + assert "Simulated auth error" in str(excinfo.value) + return None + + scopes, _ = self._run_with_scope_capture(FailingCredential, action) + try: + assert scopes == [override_scope], f"Expected only override scope, got: {scopes}" + finally: + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + + def test_account_scope_only(self): + """When account scope is provided, only that scope is used.""" + account_scope = "https://localhost/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + def action(scopes_captured): + credential = CosmosEmulatorCredential() + client = cosmos_client.CosmosClient(self.host, credential) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(12)) + return container + + scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action) + try: + # Accept multiple calls, but only the account_scope should be used + assert all(scope == account_scope for scope in scopes), f"Expected only account scope, got: {scopes}" + finally: + try: + container.delete_item(item='Item_12', partition_key='pk') + except Exception: + pass + + def test_account_scope_fallback_on_error(self): + """When account scope is provided and auth fails, fallback to default scope occurs.""" + account_scope = "https://localhost/.default" + fallback_scope = "https://cosmos.azure.com/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + class FallbackCredential(CosmosEmulatorCredential): + def __init__(self): + self.call_count = 0 + + def get_token(self, *scopes, **kwargs): + self.call_count += 1 + if self.call_count == 1: + raise HttpResponseError(message="AADSTS500011: Simulated error for fallback") + return super().get_token(*scopes, **kwargs) + + def action(scopes_captured): + credential = FallbackCredential() + client = cosmos_client.CosmosClient(self.host, credential) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(13)) + return container + + scopes, container = self._run_with_scope_capture(FallbackCredential, action) + try: + # Accept multiple calls, but the first should be account_scope, and fallback_scope should appear after error + assert account_scope in scopes and fallback_scope in scopes, f"Expected fallback to default scope, got: {scopes}" + finally: + try: + container.delete_item(item='Item_13', partition_key='pk') + except Exception: + pass + + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/test/test_aad_async.py b/sdk/cosmos/azure-cosmos/test/test_aad_async.py index d96375ea53d7..4b200ec3928a 100644 --- a/sdk/cosmos/azure-cosmos/test/test_aad_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_aad_async.py @@ -5,6 +5,7 @@ import json import time import unittest +import os from io import StringIO import pytest @@ -13,7 +14,7 @@ import test_config from azure.cosmos import exceptions from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy - +from azure.core.exceptions import HttpResponseError def _remove_padding(encoded_string): while encoded_string.endswith("="): @@ -33,7 +34,6 @@ def get_test_item(num): class CosmosEmulatorCredential(object): - async def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken """Request an access token for the emulator. Based on Azure Core's Access Token Credential. @@ -119,6 +119,141 @@ async def test_aad_credentials_async(self): assert e.status_code == 403 print("403 error assertion success") + async def _run_with_scope_capture_async(self, credential_cls, action): + scopes_captured = [] + + orig_get_token = credential_cls.get_token + + async def capturing_get_token(self, *scopes, **kwargs): + scopes_captured.extend(scopes) + return await orig_get_token(self, *scopes, **kwargs) + + credential_cls.get_token = capturing_get_token + try: + result = await action(scopes_captured) + return scopes_captured, result + finally: + credential_cls.get_token = orig_get_token + + async def test_override_scope_no_fallback_async(self): + """When override scope is provided, only that scope is used and no fallback occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + async def action(scopes_captured): + credential = CosmosEmulatorCredential() + client = CosmosClient(self.host, credential) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + await container.create_item(get_test_item(20)) + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(CosmosEmulatorCredential, action) + try: + assert all(scope == override_scope for scope in scopes), f"Expected only override scope, got: {scopes}" + finally: + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + try: + await container.delete_item(item='Item_20', partition_key='pk') + except Exception: + pass + + async def test_override_scope_no_fallback_on_error_async(self): + """When override scope is provided and auth fails, no fallback occurs.""" + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + class FailingCredential(CosmosEmulatorCredential): + async def get_token(self, *scopes, **kwargs): + raise Exception("AADSTS500011: Simulated error for override scope") + + async def action(scopes_captured): + credential = FailingCredential() + client = CosmosClient(self.host, credential) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + try: + await container.create_item(get_test_item(21)) + except Exception: + pass + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(FailingCredential, action) + try: + assert all(scope == override_scope for scope in scopes), f"Expected only override scope, got: {scopes}" + finally: + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + try: + await container.delete_item(item='Item_21', partition_key='pk') + except Exception: + pass + + async def test_account_scope_only_async(self): + """When account scope is provided, only that scope is used.""" + account_scope = "https://localhost/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + async def action(scopes_captured): + credential = CosmosEmulatorCredential() + client = CosmosClient(self.host, credential) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + await container.create_item(get_test_item(22)) + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(CosmosEmulatorCredential, action) + try: + assert all(scope == account_scope for scope in scopes), f"Expected only account scope, got: {scopes}" + finally: + try: + await container.delete_item(item='Item_22', partition_key='pk') + except Exception: + pass + + async def test_account_scope_fallback_on_error_async(self): + """When account scope is provided and auth fails, fallback to default scope occurs.""" + account_scope = "https://localhost/.default" + fallback_scope = "https://cosmos.azure.com/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = "" + + class FallbackCredential(CosmosEmulatorCredential): + def __init__(self): + self.call_count = 0 + + async def get_token(self, *scopes, **kwargs): + self.call_count += 1 + if self.call_count == 1: + raise HttpResponseError(message="AADSTS500011: Simulated error for fallback") + return await super().get_token(*scopes, **kwargs) + + async def action(scopes_captured): + credential = FallbackCredential() + client = CosmosClient(self.host, credential) + try: + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + await container.create_item(get_test_item(23)) + return container + finally: + await client.close() + + scopes, container = await self._run_with_scope_capture_async(FallbackCredential, action) + try: + assert account_scope in scopes and fallback_scope in scopes, f"Expected fallback to default scope, got: {scopes}" + finally: + try: + await container.delete_item(item='Item_23', partition_key='pk') + except Exception: + pass if __name__ == "__main__": unittest.main()