Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 52 additions & 7 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def login(self,
return deepcopy(consolidated)

def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=None):
if _on_azure_arc_windows():
return self.login_with_managed_identity_azure_arc_windows(
identity_id=identity_id, allow_no_subscriptions=allow_no_subscriptions)

import jwt
from azure.mgmt.core.tools import is_valid_resource_id
from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper
Expand Down Expand Up @@ -282,6 +286,33 @@ def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=N
self._set_subscriptions(consolidated)
return deepcopy(consolidated)

def login_with_managed_identity_azure_arc_windows(self, identity_id=None, allow_no_subscriptions=None):
import jwt
identity_type = MsiAccountTypes.system_assigned
from .auth.msal_credentials import ManagedIdentityCredential

cred = ManagedIdentityCredential()
token = cred.get_token(*self._arm_scope).token
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
tenant = decode['tid']

subscription_finder = SubscriptionFinder(self.cli_ctx)
subscriptions = subscription_finder.find_using_specific_tenant(tenant, cred)
base_name = ('{}-{}'.format(identity_type, identity_id) if identity_id else identity_type)
user = _USER_ASSIGNED_IDENTITY if identity_id else _SYSTEM_ASSIGNED_IDENTITY
if not subscriptions:
if allow_no_subscriptions:
subscriptions = self._build_tenant_level_accounts([tenant])
else:
raise CLIError('No access was configured for the managed identity, hence no subscriptions were found. '
"If this is expected, use '--allow-no-subscriptions' to have tenant level access.")

consolidated = self._normalize_properties(user, subscriptions, is_service_principal=True,
user_assigned_identity_id=base_name)
self._set_subscriptions(consolidated)
return deepcopy(consolidated)

def login_in_cloud_shell(self):
import jwt
from .auth.msal_credentials import CloudShellCredential
Expand Down Expand Up @@ -354,13 +385,18 @@ def get_login_credentials(self, resource=None, client_id=None, subscription_id=N
# Cloud Shell
from .auth.msal_credentials import CloudShellCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
cs_cred = CloudShellCredential()
# The cloud shell credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(cs_cred, resource=resource)
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(CloudShellCredential(), resource=resource)

elif managed_identity_type:
# managed identity
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource)
if _on_azure_arc_windows():
from .auth.msal_credentials import ManagedIdentityCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(ManagedIdentityCredential(), resource=resource)
else:
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource)

else:
# user and service principal
Expand Down Expand Up @@ -415,9 +451,13 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
# managed identity
if tenant:
raise CLIError("Tenant shouldn't be specified for managed identity account")
from .auth.util import scopes_to_resource
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))
if _on_azure_arc_windows():
from .auth.msal_credentials import ManagedIdentityCredential
cred = ManagedIdentityCredential()
else:
from .auth.util import scopes_to_resource
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))

else:
cred = self._create_credential(account, tenant)
Expand Down Expand Up @@ -918,3 +958,8 @@ def _create_identity_instance(cli_ctx, *args, **kwargs):
return Identity(*args, encrypt=encrypt, use_msal_http_cache=use_msal_http_cache,
enable_broker_on_windows=enable_broker_on_windows,
instance_discovery=instance_discovery, **kwargs)


def _on_azure_arc_windows():
# This indicates an Azure Arc-enabled Windows server
return "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ
21 changes: 20 additions & 1 deletion src/azure-cli-core/azure/cli/core/auth/msal_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from knack.log import get_logger
from knack.util import CLIError
from msal import PublicClientApplication, ConfidentialClientApplication
from msal import (PublicClientApplication, ConfidentialClientApplication,
ManagedIdentityClient, SystemAssignedManagedIdentity)

from .constants import AZURE_CLI_CLIENT_ID
from .util import check_result, build_sdk_access_token
Expand Down Expand Up @@ -131,3 +132,21 @@ def get_token(self, *scopes, **kwargs):
result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs)
check_result(result, scopes=scopes)
return build_sdk_access_token(result)


class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
"""Managed identity credential implementing get_token interface.
Currently, only Azure Arc's system-assigned managed identity is supported.
"""

def __init__(self):
import requests
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())

def get_token(self, *scopes, **kwargs):
logger.debug("ManagedIdentityCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)

from .util import scopes_to_resource
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
check_result(result)
return build_sdk_access_token(result)