diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 8f2ab323402c..160d929c0d56 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -53,6 +53,7 @@ class _Constants: MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000 CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" + AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE" # Only applicable when circuit breaker is enabled ------------------------- CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ" CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10 diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 0f6f246eb5f5..1cf44026d333 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -201,7 +201,11 @@ def __init__( # pylint: disable=too-many-statements credentials_policy = None if self.aad_credentials: - scope = base.create_scope_from_url(self.url_connection) + scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "") + if scope_override: + scope = scope_override + else: + scope = base.create_scope_from_url(self.url_connection) credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope) policies = [ diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index aafb9d26f4f0..636eae1e642e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -211,7 +211,11 @@ def __init__( # pylint: disable=too-many-statements credentials_policy = None if self.aad_credentials: - scope = base.create_scope_from_url(self.url_connection) + scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "") + if scope_override: + scope = scope_override + else: + scope = base.create_scope_from_url(self.url_connection) credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope) policies = [ diff --git a/sdk/cosmos/azure-cosmos/tests/test_aad.py b/sdk/cosmos/azure-cosmos/tests/test_aad.py index c13838571a43..39ff7500e30a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_aad.py +++ b/sdk/cosmos/azure-cosmos/tests/test_aad.py @@ -3,6 +3,7 @@ import base64 import json +import os import time import unittest from io import StringIO @@ -117,6 +118,33 @@ def test_aad_credentials(self): assert e.status_code == 403 print("403 error assertion success") + def test_aad_scope_override(self): + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + scopes_captured = [] + original_get_token = CosmosEmulatorCredential.get_token + + def capturing_get_token(self, *scopes, **kwargs): + scopes_captured.extend(scopes) + return original_get_token(self, *scopes, **kwargs) + + CosmosEmulatorCredential.get_token = capturing_get_token + + try: + credential = CosmosEmulatorCredential() + client = cosmos_client.CosmosClient(self.host, credential) + db = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + container.create_item(get_test_item(1)) + assert override_scope in scopes_captured + finally: + CosmosEmulatorCredential.get_token = original_get_token + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + try: + container.delete_item(item='Item_1', partition_key='pk') + except Exception: + pass if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_aad_async.py b/sdk/cosmos/azure-cosmos/tests/test_aad_async.py index 16b081bd1d88..c39516a06279 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_aad_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_aad_async.py @@ -4,6 +4,7 @@ import base64 import json import time +import os import unittest from io import StringIO @@ -130,6 +131,38 @@ async def test_aad_credentials_async(self): assert e.status_code == 403 print("403 error assertion success") + async def test_aad_scope_override_async(self): + override_scope = "https://my.custom.scope/.default" + os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope + + scopes_captured = [] + original_get_token = CosmosEmulatorCredential.get_token + + async def capturing_get_token(self, *scopes, **kwargs): + scopes_captured.extend(scopes) + # Await the original method! + return await original_get_token(self, *scopes, **kwargs) + + CosmosEmulatorCredential.get_token = capturing_get_token + + try: + credential = CosmosEmulatorCredential() + client = CosmosClient(self.host, credential) + database = client.get_database_client(self.configs.TEST_DATABASE_ID) + container = database.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) + + await container.create_item(get_test_item(1)) + item = await container.read_item(item='Item_1', partition_key='pk') + assert item["id"] == "Item_1" + assert override_scope in scopes_captured + finally: + CosmosEmulatorCredential.get_token = original_get_token + del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] + try: + await container.delete_item(item='Item_1', partition_key='pk') + except Exception: + pass + await client.close() if __name__ == "__main__": unittest.main()