diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index a4e71416808..ee2cd6d9de6 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -220,6 +220,10 @@ def login(self, return deepcopy(consolidated) def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=None): + if _on_azure_arc_windows(): + return self.login_with_managed_identity_azure_arc_windows( + identity_id=identity_id, allow_no_subscriptions=allow_no_subscriptions) + import jwt from azure.mgmt.core.tools import is_valid_resource_id from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper @@ -282,6 +286,33 @@ def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=N self._set_subscriptions(consolidated) return deepcopy(consolidated) + def login_with_managed_identity_azure_arc_windows(self, identity_id=None, allow_no_subscriptions=None): + import jwt + identity_type = MsiAccountTypes.system_assigned + from .auth.msal_credentials import ManagedIdentityCredential + + cred = ManagedIdentityCredential() + token = cred.get_token(*self._arm_scope).token + logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...') + decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False}) + tenant = decode['tid'] + + subscription_finder = SubscriptionFinder(self.cli_ctx) + subscriptions = subscription_finder.find_using_specific_tenant(tenant, cred) + base_name = ('{}-{}'.format(identity_type, identity_id) if identity_id else identity_type) + user = _USER_ASSIGNED_IDENTITY if identity_id else _SYSTEM_ASSIGNED_IDENTITY + if not subscriptions: + if allow_no_subscriptions: + subscriptions = self._build_tenant_level_accounts([tenant]) + else: + raise CLIError('No access was configured for the managed identity, hence no subscriptions were found. ' + "If this is expected, use '--allow-no-subscriptions' to have tenant level access.") + + consolidated = self._normalize_properties(user, subscriptions, is_service_principal=True, + user_assigned_identity_id=base_name) + self._set_subscriptions(consolidated) + return deepcopy(consolidated) + def login_in_cloud_shell(self): import jwt from .auth.msal_credentials import CloudShellCredential @@ -354,13 +385,18 @@ def get_login_credentials(self, resource=None, client_id=None, subscription_id=N # Cloud Shell from .auth.msal_credentials import CloudShellCredential from azure.cli.core.auth.credential_adaptor import CredentialAdaptor - cs_cred = CloudShellCredential() - # The cloud shell credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs. - cred = CredentialAdaptor(cs_cred, resource=resource) + # The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs. + cred = CredentialAdaptor(CloudShellCredential(), resource=resource) elif managed_identity_type: # managed identity - cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource) + if _on_azure_arc_windows(): + from .auth.msal_credentials import ManagedIdentityCredential + from azure.cli.core.auth.credential_adaptor import CredentialAdaptor + # The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs. + cred = CredentialAdaptor(ManagedIdentityCredential(), resource=resource) + else: + cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource) else: # user and service principal @@ -415,9 +451,13 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No # managed identity if tenant: raise CLIError("Tenant shouldn't be specified for managed identity account") - from .auth.util import scopes_to_resource - cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, - scopes_to_resource(scopes)) + if _on_azure_arc_windows(): + from .auth.msal_credentials import ManagedIdentityCredential + cred = ManagedIdentityCredential() + else: + from .auth.util import scopes_to_resource + cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, + scopes_to_resource(scopes)) else: cred = self._create_credential(account, tenant) @@ -918,3 +958,8 @@ def _create_identity_instance(cli_ctx, *args, **kwargs): return Identity(*args, encrypt=encrypt, use_msal_http_cache=use_msal_http_cache, enable_broker_on_windows=enable_broker_on_windows, instance_discovery=instance_discovery, **kwargs) + + +def _on_azure_arc_windows(): + # This indicates an Azure Arc-enabled Windows server + return "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py index 3b58ecdaa48..c15d7ea5f7b 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py @@ -18,7 +18,8 @@ from knack.log import get_logger from knack.util import CLIError -from msal import PublicClientApplication, ConfidentialClientApplication +from msal import (PublicClientApplication, ConfidentialClientApplication, + ManagedIdentityClient, SystemAssignedManagedIdentity) from .constants import AZURE_CLI_CLIENT_ID from .util import check_result, build_sdk_access_token @@ -131,3 +132,21 @@ def get_token(self, *scopes, **kwargs): result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs) check_result(result, scopes=scopes) return build_sdk_access_token(result) + + +class ManagedIdentityCredential: # pylint: disable=too-few-public-methods + """Managed identity credential implementing get_token interface. + Currently, only Azure Arc's system-assigned managed identity is supported. + """ + + def __init__(self): + import requests + self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session()) + + def get_token(self, *scopes, **kwargs): + logger.debug("ManagedIdentityCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) + + from .util import scopes_to_resource + result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes)) + check_result(result) + return build_sdk_access_token(result)