Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class _Constants:
MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False"
AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"
# Only applicable when circuit breaker is enabled -------------------------
CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ"
CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ def __init__( # pylint: disable=too-many-statements

credentials_policy = None
if self.aad_credentials:
scope = base.create_scope_from_url(self.url_connection)
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
if scope_override:
scope = scope_override
else:
scope = base.create_scope_from_url(self.url_connection)
credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)

policies = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,11 @@ def __init__( # pylint: disable=too-many-statements

credentials_policy = None
if self.aad_credentials:
scope = base.create_scope_from_url(self.url_connection)
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
if scope_override:
scope = scope_override
else:
scope = base.create_scope_from_url(self.url_connection)
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)

policies = [
Expand Down
28 changes: 28 additions & 0 deletions sdk/cosmos/azure-cosmos/tests/test_aad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import base64
import json
import os
import time
import unittest
from io import StringIO
Expand Down Expand Up @@ -117,6 +118,33 @@ def test_aad_credentials(self):
assert e.status_code == 403
print("403 error assertion success")

def test_aad_scope_override(self):
override_scope = "https://my.custom.scope/.default"
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope

scopes_captured = []
original_get_token = CosmosEmulatorCredential.get_token

def capturing_get_token(self, *scopes, **kwargs):
scopes_captured.extend(scopes)
return original_get_token(self, *scopes, **kwargs)

CosmosEmulatorCredential.get_token = capturing_get_token

try:
credential = CosmosEmulatorCredential()
client = cosmos_client.CosmosClient(self.host, credential)
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
container.create_item(get_test_item(1))
assert override_scope in scopes_captured
finally:
CosmosEmulatorCredential.get_token = original_get_token
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
try:
container.delete_item(item='Item_1', partition_key='pk')
except Exception:
pass

if __name__ == "__main__":
unittest.main()
33 changes: 33 additions & 0 deletions sdk/cosmos/azure-cosmos/tests/test_aad_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import json
import time
import os
import unittest
from io import StringIO

Expand Down Expand Up @@ -130,6 +131,38 @@ async def test_aad_credentials_async(self):
assert e.status_code == 403
print("403 error assertion success")

async def test_aad_scope_override_async(self):
override_scope = "https://my.custom.scope/.default"
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope

scopes_captured = []
original_get_token = CosmosEmulatorCredential.get_token

async def capturing_get_token(self, *scopes, **kwargs):
scopes_captured.extend(scopes)
# Await the original method!
return await original_get_token(self, *scopes, **kwargs)

CosmosEmulatorCredential.get_token = capturing_get_token

try:
credential = CosmosEmulatorCredential()
client = CosmosClient(self.host, credential)
database = client.get_database_client(self.configs.TEST_DATABASE_ID)
container = database.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)

await container.create_item(get_test_item(1))
item = await container.read_item(item='Item_1', partition_key='pk')
assert item["id"] == "Item_1"
assert override_scope in scopes_captured
finally:
CosmosEmulatorCredential.get_token = original_get_token
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
try:
await container.delete_item(item='Item_1', partition_key='pk')
except Exception:
pass
await client.close()

if __name__ == "__main__":
unittest.main()
Loading