Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
tests
  • Loading branch information
chlowell committed Jun 4, 2020
commit dab341706af719ac3259d2ccbe8d11b3ab01ac5a
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
# ------------------------------------
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential
from azure.identity import (
AuthenticationRecord,
CredentialUnavailableError,
SharedTokenCacheCredential,
)
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
KNOWN_ALIASES,
Expand Down Expand Up @@ -502,6 +506,98 @@ def test_authority_environment_variable():
assert token.token == expected_access_token


def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_authentication_record_no_match():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_authentication_record():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(account)

transport = validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = credential.get_token("scope")
assert token.token == expected_access_token


def test_auth_record_multiple_accounts_for_username():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
expected_account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(
expected_account,
get_account_event( # this account matches all but the record's tenant
username,
object_id,
"different-" + tenant_id,
authority=authority,
client_id=client_id,
refresh_token="not-" + expected_refresh_token,
),
)

transport = validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = credential.get_token("scope")
assert token.token == expected_access_token


def get_account_event(
username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.identity import CredentialUnavailableError, KnownAuthorities
from azure.identity import AuthenticationRecord, CredentialUnavailableError
from azure.identity.aio import SharedTokenCacheCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
Expand Down Expand Up @@ -566,3 +566,99 @@ async def test_authority_environment_variable():
credential = SharedTokenCacheCredential(transport=transport, _cache=cache)
token = await credential.get_token("scope")
assert token.token == expected_access_token


@pytest.mark.asyncio
async def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")


@pytest.mark.asyncio
async def test_authentication_record_no_match():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")


@pytest.mark.asyncio
async def test_authentication_record():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(account)

transport = async_validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = await credential.get_token("scope")
assert token.token == expected_access_token


@pytest.mark.asyncio
async def test_auth_record_multiple_accounts_for_username():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
expected_account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(
expected_account,
get_account_event( # this account matches all but the record's tenant
username,
object_id,
"different-" + tenant_id,
authority=authority,
client_id=client_id,
refresh_token="not-" + expected_refresh_token,
),
)

transport = async_validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = await credential.get_token("scope")
assert token.token == expected_access_token