Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#### Other Changes
* Added session token false progress merge logic. See [42393](https://github.com/Azure/azure-sdk-for-python/pull/42393)
* Added a fallback mechanism to AAD scope override. See [PR 42731](https://github.com/Azure/azure-sdk-for-python/pull/42731).

### 4.14.0b2 (2025-08-12)

Expand Down
40 changes: 35 additions & 5 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,32 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from typing import TypeVar, Any, MutableMapping, cast
from typing import TypeVar, Any, MutableMapping, cast, Optional

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
from azure.core.rest import HttpRequest
from azure.core.credentials import AccessToken
from azure.core.exceptions import HttpResponseError

from .http_constants import HttpHeaders
from ._constants import _Constants as Constants

HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)


# NOTE: This class accesses protected members (_scopes, _token) of the parent class
# to implement fallback and scope-switching logic not exposed by the public API.
# Composition was considered, but still required accessing protected members, so inheritance is retained
# for seamless Azure SDK pipeline integration.
class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
AadDefaultScope = Constants.AAD_DEFAULT_SCOPE

def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
self._account_scope = account_scope
self._override_scope = override_scope
self._current_scope = override_scope or account_scope
super().__init__(credential, self._current_scope)

@staticmethod
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
Expand All @@ -34,9 +46,26 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:

:param ~azure.core.pipeline.PipelineRequest request: the request
"""
super().on_request(request)
# The None-check for self._token is done in the parent on_request
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
tried_fallback = False
while True:
try:
super().on_request(request)
# The None-check for self._token is done in the parent on_request
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
break
except HttpResponseError as ex:
# Only fallback if not using override, not already tried, and error is AADSTS500011
if (
not self._override_scope and
not tried_fallback and
self._current_scope != self.AadDefaultScope and
"AADSTS500011" in str(ex)
):
self._scopes = (self.AadDefaultScope,)
self._current_scope = self.AadDefaultScope
tried_fallback = True
continue
raise

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.
Expand All @@ -47,6 +76,7 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""

super().authorize_request(request, *scopes, **kwargs)
# The None-check for self._token is done in the parent authorize_request
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class _Constants:
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False"
AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"
AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default"

# Database Account Retry Policy constants
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def __init__( # pylint: disable=too-many-statements
The connection policy for the client.
:param documents.ConsistencyLevel consistency_level:
The default consistency policy for client operations.

"""
self.client_id = str(uuid.uuid4())
self.url_connection = url_connection
Expand Down Expand Up @@ -205,11 +204,12 @@ def __init__( # pylint: disable=too-many-statements
credentials_policy = None
if self.aad_credentials:
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
if scope_override:
scope = scope_override
else:
scope = base.create_scope_from_url(self.url_connection)
credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
account_scope = base.create_scope_from_url(self.url_connection)
credentials_policy = CosmosBearerTokenCredentialPolicy(
self.aad_credentials,
account_scope=account_scope,
override_scope=scope_override if scope_override else None
)
self._enable_diagnostics_logging = kwargs.pop("enable_diagnostics_logging", False)
policies = [
HeadersPolicy(**kwargs),
Expand Down
40 changes: 35 additions & 5 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,32 @@
# license information.
# -------------------------------------------------------------------------

from typing import Any, MutableMapping, TypeVar, cast
from typing import Any, MutableMapping, TypeVar, cast, Optional

from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
from azure.core.rest import HttpRequest
from azure.core.credentials import AccessToken
from azure.core.exceptions import HttpResponseError

from ..http_constants import HttpHeaders
from .._constants import _Constants as Constants

HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)


# NOTE: This class accesses protected members (_scopes, _token) of the parent class
# to implement fallback and scope-switching logic not exposed by the public API.
# Composition was considered, but still required accessing protected members, so inheritance is retained
# for seamless Azure SDK pipeline integration.
class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
AadDefaultScope = Constants.AAD_DEFAULT_SCOPE

def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
self._account_scope = account_scope
self._override_scope = override_scope
self._current_scope = override_scope or account_scope
super().__init__(credential, self._current_scope)

@staticmethod
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
Expand All @@ -35,9 +47,26 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
await super().on_request(request)
# The None-check for self._token is done in the parent on_request
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
tried_fallback = False
while True:
try:
await super().on_request(request)
# The None-check for self._token is done in the parent on_request
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
break
except HttpResponseError as ex:
# Only fallback if not using override, not already tried, and error is AADSTS500011
if (
not self._override_scope and
not tried_fallback and
self._current_scope != self.AadDefaultScope and
"AADSTS500011" in str(ex)
):
self._scopes = (self.AadDefaultScope,)
self._current_scope = self.AadDefaultScope
tried_fallback = True
continue
raise

async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.
Expand All @@ -48,6 +77,7 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""

await super().authorize_request(request, *scopes, **kwargs)
# The None-check for self._token is done in the parent authorize_request
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def __init__( # pylint: disable=too-many-statements
The connection policy for the client.
:param documents.ConsistencyLevel consistency_level:
The default consistency policy for client operations.

"""
self.client_id = str(uuid.uuid4())
self.url_connection = url_connection
Expand Down Expand Up @@ -213,11 +212,12 @@ def __init__( # pylint: disable=too-many-statements
credentials_policy = None
if self.aad_credentials:
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
if scope_override:
scope = scope_override
else:
scope = base.create_scope_from_url(self.url_connection)
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
account_scope = base.create_scope_from_url(self.url_connection)
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(
self.aad_credentials,
account_scope,
scope_override
)
self._enable_diagnostics_logging = kwargs.pop("enable_diagnostics_logging", False)
policies = [
HeadersPolicy(**kwargs),
Expand Down
116 changes: 104 additions & 12 deletions sdk/cosmos/azure-cosmos/tests/test_aad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import azure.cosmos.cosmos_client as cosmos_client
import test_config
from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions

from azure.core.exceptions import HttpResponseError

def _remove_padding(encoded_string):
while encoded_string.endswith("="):
Expand All @@ -34,7 +34,6 @@ def get_test_item(num):


class CosmosEmulatorCredential(object):

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
Expand Down Expand Up @@ -118,33 +117,126 @@ def test_aad_credentials(self):
assert e.status_code == 403
print("403 error assertion success")

def test_aad_scope_override(self):
override_scope = "https://my.custom.scope/.default"
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope

def _run_with_scope_capture(self, credential_cls, action, *args, **kwargs):
scopes_captured = []
original_get_token = CosmosEmulatorCredential.get_token
original_get_token = credential_cls.get_token

def capturing_get_token(self, *scopes, **kwargs):
scopes_captured.extend(scopes)
return original_get_token(self, *scopes, **kwargs)

CosmosEmulatorCredential.get_token = capturing_get_token

credential_cls.get_token = capturing_get_token
try:
result = action(scopes_captured, *args, **kwargs)
finally:
credential_cls.get_token = original_get_token
return scopes_captured, result

def test_override_scope_no_fallback(self):
"""When override scope is provided, only that scope is used and no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope

def action(scopes_captured):
credential = CosmosEmulatorCredential()
client = cosmos_client.CosmosClient(self.host, credential)
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
container.create_item(get_test_item(1))
assert override_scope in scopes_captured
container.create_item(get_test_item(10))
return container

scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action)
try:
assert all(scope == override_scope for scope in scopes), f"Expected only override scope(s), got: {scopes}"
finally:
CosmosEmulatorCredential.get_token = original_get_token
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
try:
container.delete_item(item='Item_1', partition_key='pk')
container.delete_item(item='Item_10', partition_key='pk')
except Exception:
pass

def test_override_scope_auth_error_no_fallback(self):
"""When override scope is provided and auth fails, no fallback to other scopes occurs."""
override_scope = "https://my.custom.scope/.default"
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope

class FailingCredential(CosmosEmulatorCredential):
def get_token(self, *scopes, **kwargs):
raise Exception("Simulated auth error for override scope")

def action(scopes_captured):
with pytest.raises(Exception) as excinfo:
client = cosmos_client.CosmosClient(self.host, FailingCredential())
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
container.create_item(get_test_item(11))
assert "Simulated auth error" in str(excinfo.value)
return None

scopes, _ = self._run_with_scope_capture(FailingCredential, action)
try:
assert scopes == [override_scope], f"Expected only override scope, got: {scopes}"
finally:
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]

def test_account_scope_only(self):
"""When account scope is provided, only that scope is used."""
account_scope = "https://localhost/.default"
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""

def action(scopes_captured):
credential = CosmosEmulatorCredential()
client = cosmos_client.CosmosClient(self.host, credential)
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
container.create_item(get_test_item(12))
return container

scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action)
try:
# Accept multiple calls, but only the account_scope should be used
assert all(scope == account_scope for scope in scopes), f"Expected only account scope, got: {scopes}"
finally:
try:
container.delete_item(item='Item_12', partition_key='pk')
except Exception:
pass

def test_account_scope_fallback_on_error(self):
"""When account scope is provided and auth fails, fallback to default scope occurs."""
account_scope = "https://localhost/.default"
fallback_scope = "https://cosmos.azure.com/.default"
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""

class FallbackCredential(CosmosEmulatorCredential):
def __init__(self):
self.call_count = 0

def get_token(self, *scopes, **kwargs):
self.call_count += 1
if self.call_count == 1:
raise HttpResponseError(message="AADSTS500011: Simulated error for fallback")
return super().get_token(*scopes, **kwargs)

def action(scopes_captured):
credential = FallbackCredential()
client = cosmos_client.CosmosClient(self.host, credential)
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
container.create_item(get_test_item(13))
return container

scopes, container = self._run_with_scope_capture(FallbackCredential, action)
try:
# Accept multiple calls, but the first should be account_scope, and fallback_scope should appear after error
assert account_scope in scopes and fallback_scope in scopes, f"Expected fallback to default scope, got: {scopes}"
finally:
try:
container.delete_item(item='Item_13', partition_key='pk')
except Exception:
pass


if __name__ == "__main__":
unittest.main()
Loading