Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 92 additions & 19 deletions src/azure-cli-core/azure/cli/core/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,31 @@
from azure.cli.core._environment import get_config_dir
from knack.log import get_logger
from knack.util import CLIError
from msal import PublicClientApplication
from msal import PublicClientApplication, ConfidentialClientApplication

# Service principal entry properties
from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION, \
_USE_CERT_SN_ISSUER
from .msal_authentication import UserCredential, ServicePrincipalCredential
from .msal_credentials import UserCredential, ServicePrincipalCredential
from .persistence import load_persisted_token_cache, file_extensions, load_secret_store
from .util import check_result

AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'

# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters:
# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow
_TENANT = 'tenant'
_CLIENT_ID = 'client_id'
_CLIENT_SECRET = 'client_secret'
_CERTIFICATE = 'certificate'
_CLIENT_ASSERTION = 'client_assertion'
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'

# For environment credential
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
AZURE_TENANT_ID = "AZURE_TENANT_ID"
AZURE_CLIENT_ID = "AZURE_CLIENT_ID"
AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"

FEDERATED_IDENTITY = "FEDERATED_IDENTITY"

WAM_PROMPT = (
"Select the account you want to log in with. "
"For more information on login with Azure CLI, see https://go.microsoft.com/fwlink/?linkid=2271136")
Expand Down Expand Up @@ -187,10 +194,9 @@ def login_with_service_principal(self, client_id, credential, scopes):
`credential` is a dict returned by ServicePrincipalAuth.build_credential
"""
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)

# This cred means SDK credential object
cred = ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs)
result = cred.acquire_token_for_client(scopes)
client_credential = sp_auth.get_msal_client_credential()
cca = ConfidentialClientApplication(client_id, client_credential=client_credential, **self._msal_app_kwargs)
result = cca.acquire_token_for_client(scopes)
check_result(result)

# Only persist the service principal after a successful login
Expand Down Expand Up @@ -236,31 +242,45 @@ def get_user_credential(self, username):

def get_service_principal_credential(self, client_id):
entry = self._service_principal_store.load_entry(client_id, self.tenant_id)
sp_auth = ServicePrincipalAuth(entry)
return ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs)
client_credential = ServicePrincipalAuth(entry).get_msal_client_credential()
return ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)

def get_managed_identity_credential(self, client_id=None):
raise NotImplementedError


class ServicePrincipalAuth:

def __init__(self, entry):
# Initialize all attributes first, so that we don't need to call getattr to check their existence
self.client_id = None
self.tenant = None
# secret
self.client_secret = None
# certificate
self.certificate = None
self.use_cert_sn_issuer = None
# federated identity credential
self.client_assertion = None

# Internal attributes for certificate
self._certificate_string = None
self._thumbprint = None
self._public_certificate = None

self.__dict__.update(entry)

if _CERTIFICATE in entry:
if self.certificate:
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error
self.public_certificate = None
try:
with open(self.certificate, 'r') as file_reader:
self.certificate_string = file_reader.read()
cert = load_certificate(FILETYPE_PEM, self.certificate_string)
self.thumbprint = cert.digest("sha1").decode().replace(':', '')
self._certificate_string = file_reader.read()
cert = load_certificate(FILETYPE_PEM, self._certificate_string)
self._thumbprint = cert.digest("sha1").decode().replace(':', '')
if entry.get(_USE_CERT_SN_ISSUER):
# low-tech but safe parsing based on
# https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h
match = re.search(r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
self.certificate_string, re.I)
self._certificate_string, re.I)
self.public_certificate = match.group()
except (UnicodeDecodeError, Error) as ex:
raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex))
Expand Down Expand Up @@ -298,7 +318,41 @@ def build_credential(cls, secret_or_certificate=None, client_assertion=None, use

def get_entry_to_persist(self):
persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION]
return {k: v for k, v in self.__dict__.items() if k in persisted_keys}
# Only persist certain attributes whose values are not None
return {k: v for k, v in self.__dict__.items() if k in persisted_keys and v}

def get_msal_client_credential(self):
client_credential = None

# client_secret
# "your client secret"
if self.client_secret:
client_credential = self.client_secret

# certificate
# {
# "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format",
# "thumbprint": "A1B2C3D4E5F6...",
# "public_certificate": "...-----BEGIN CERTIFICATE-----...",
# }
if self.certificate:
client_credential = {
"private_key": self._certificate_string,
"thumbprint": self._thumbprint
}
if self._public_certificate:
client_credential['public_certificate'] = self._public_certificate

# client_assertion
# {
# "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
# }
if self.client_assertion:
client_credential = {
'client_assertion': get_id_token_on_github if self.client_assertion == FEDERATED_IDENTITY
else self.client_assertion}

return client_credential


class ServicePrincipalStore:
Expand Down Expand Up @@ -405,3 +459,22 @@ def get_environment_credential():
getenv(AZURE_TENANT_ID))
credentials = ServicePrincipalCredential(sp_auth, authority=authority)
return credentials


def get_id_token_on_github():
import os
from urllib.parse import quote
import requests
token = os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']
url = os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']
encodedAudience = quote('api://AzureADTokenExchange')
url = f'{url}&audience={encodedAudience}'
headers = {
'Authorization': f'bearer {token}',
'Accept': 'application/json; api-version=2.0',
'Content-Type': 'application/json'
}
result = requests.get(url, headers=headers)
id_token = result.json()['value']
logger.warning('Got ID token: %s', id_token)
return id_token
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,23 @@

from .util import check_result, build_sdk_access_token

# OAuth 2.0 client credentials flow parameter
# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow
_TENANT = 'tenant'
_CLIENT_ID = 'client_id'
_CLIENT_SECRET = 'client_secret'
_CERTIFICATE = 'certificate'
_CLIENT_ASSERTION = 'client_assertion'
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'

logger = get_logger(__name__)


class UserCredential(PublicClientApplication):
class UserCredential: # pylint: disable=too-few-public-methods

def __init__(self, client_id, username, **kwargs):
"""User credential implementing get_token interface.

:param client_id: Client ID of the CLI.
:param username: The username for user credential.
"""
super().__init__(client_id, **kwargs)
self._msal_app = PublicClientApplication(client_id, **kwargs)

# Make sure username is specified, otherwise MSAL returns all accounts
assert username, "username must be specified, got {!r}".format(username)

accounts = self.get_accounts(username)
accounts = self._msal_app.get_accounts(username)

# Usernames are usually unique. We are collecting corner cases to better understand its behavior.
if len(accounts) > 1:
Expand All @@ -65,8 +56,9 @@ def get_token(self, *scopes, claims=None, **kwargs):

if claims:
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
self.authority.tenant, claims)
result = self.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims, **kwargs)
self._msal_app.authority.tenant, claims)
result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims,
**kwargs)

from azure.cli.core.azclierror import AuthenticationError
try:
Expand All @@ -82,13 +74,14 @@ def get_token(self, *scopes, claims=None, **kwargs):
logger.warning(ex)
logger.warning("\nThe default web browser has been opened at %s for scope '%s'. "
"Please continue the login in the web browser.",
self.authority.authorization_endpoint, ' '.join(scopes))
self._msal_app.authority.authorization_endpoint, ' '.join(scopes))

from .util import read_response_templates
success_template, error_template = read_response_templates()

result = self.acquire_token_interactive(
list(scopes), login_hint=self._account['username'], port=8400 if self.authority.is_adfs else None,
result = self._msal_app.acquire_token_interactive(
list(scopes), login_hint=self._account['username'],
port=8400 if self._msal_app.authority.is_adfs else None,
success_template=success_template, error_template=error_template, **kwargs)
check_result(result)

Expand All @@ -99,42 +92,18 @@ def get_token(self, *scopes, claims=None, **kwargs):
return build_sdk_access_token(result)


class ServicePrincipalCredential(ConfidentialClientApplication):
class ServicePrincipalCredential: # pylint: disable=too-few-public-methods

def __init__(self, service_principal_auth, **kwargs):
def __init__(self, client_id, client_credential, **kwargs):
"""Service principal credential implementing get_token interface.

:param service_principal_auth: An instance of ServicePrincipalAuth.
"""
client_credential = None

# client_secret
client_secret = getattr(service_principal_auth, _CLIENT_SECRET, None)
if client_secret:
client_credential = client_secret

# certificate
certificate = getattr(service_principal_auth, _CERTIFICATE, None)
if certificate:
client_credential = {
"private_key": getattr(service_principal_auth, 'certificate_string'),
"thumbprint": getattr(service_principal_auth, 'thumbprint')
}
public_certificate = getattr(service_principal_auth, 'public_certificate', None)
if public_certificate:
client_credential['public_certificate'] = public_certificate

# client_assertion
client_assertion = getattr(service_principal_auth, _CLIENT_ASSERTION, None)
if client_assertion:
client_credential = {'client_assertion': client_assertion}

super().__init__(service_principal_auth.client_id, client_credential=client_credential, **kwargs)
self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs)

def get_token(self, *scopes, **kwargs):
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)

scopes = list(scopes)
result = self.acquire_token_for_client(scopes, **kwargs)
result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
check_result(result)
return build_sdk_access_token(result)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from azure.cli.core import AzCommandsLoader
from azure.cli.core.commands import CliCommandType
from azure.cli.core.commands.parameters import get_enum_type
from azure.cli.core.commands.parameters import get_enum_type, get_three_state_flag

from azure.cli.command_modules.profile._format import transform_account_list
import azure.cli.command_modules.profile._help # pylint: disable=unused-import
Expand Down Expand Up @@ -58,6 +58,8 @@ def load_arguments(self, command):
c.argument('use_cert_sn_issuer', action='store_true', help='used with a service principal configured with Subject Name and Issuer Authentication in order to support automatic certificate rolls')
c.argument('scopes', options_list=['--scope'], nargs='+', help='Used in the /authorize request. It can cover only one static resource.')
c.argument('client_assertion', options_list=['--federated-token'], help='Federated token that can be used for OIDC token exchange.')
c.argument('federated_identity', options_list=['--federated-identity'], arg_type=get_three_state_flag(),
help='Use federated identity credential.')

with self.argument_context('logout') as c:
c.argument('username', help='account user, if missing, logout the current active account')
Expand Down
13 changes: 10 additions & 3 deletions src/azure-cli/azure/cli/command_modules/profile/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def account_clear(cmd):

# pylint: disable=inconsistent-return-statements, too-many-branches
def login(cmd, username=None, password=None, service_principal=None, tenant=None, allow_no_subscriptions=False,
identity=False, use_device_code=False, use_cert_sn_issuer=None, scopes=None, client_assertion=None):
identity=False, use_device_code=False, use_cert_sn_issuer=None, scopes=None, client_assertion=None,
federated_identity=None):
"""Log in to access Azure subscriptions"""

# quick argument usage check
Expand All @@ -128,6 +129,9 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None
raise CLIError("usage error: '--use-sn-issuer' is only applicable with a service principal")
if service_principal and not username:
raise CLIError('usage error: --service-principal --username NAME --password SECRET --tenant TENANT')
if client_assertion and federated_identity:
raise CLIError('usage error: Only one of --federated-token and --federated-identity can be specified')

if username and not service_principal and not identity:
logger.warning(USERNAME_PASSWORD_DEPRECATION_WARNING)

Expand All @@ -143,7 +147,7 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None
logger.warning(_CLOUD_CONSOLE_LOGIN_WARNING)

if username:
if not (password or client_assertion):
if not (password or client_assertion or federated_identity):
try:
password = prompt_pass('Password: ')
except NoTTYException:
Expand All @@ -153,7 +157,10 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None

if service_principal:
from azure.cli.core.auth.identity import ServicePrincipalAuth
password = ServicePrincipalAuth.build_credential(password, client_assertion, use_cert_sn_issuer)
password = ServicePrincipalAuth.build_credential(
secret_or_certificate=password,
client_assertion='FEDERATED_IDENTITY' if federated_identity else client_assertion,
use_cert_sn_issuer=use_cert_sn_issuer)

login_experience_v2 = cmd.cli_ctx.config.getboolean('core', 'login_experience_v2', fallback=True)
# Send login_experience_v2 config to telemetry
Expand Down