Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 73 additions & 20 deletions src/azure-cli-core/azure/cli/core/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@
from knack.util import CLIError
from msal import PublicClientApplication, ConfidentialClientApplication

# Service principal entry properties
from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION, \
_USE_CERT_SN_ISSUER
from .msal_authentication import UserCredential, ServicePrincipalCredential
from .msal_credentials import UserCredential, ServicePrincipalCredential
from .persistence import load_persisted_token_cache, file_extensions, load_secret_store
from .util import check_result

AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'

# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters:
# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow
_TENANT = 'tenant'
_CLIENT_ID = 'client_id'
_CLIENT_SECRET = 'client_secret'
_CERTIFICATE = 'certificate'
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'
_CLIENT_ASSERTION = 'client_assertion'

# For environment credential
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
Expand Down Expand Up @@ -187,10 +192,9 @@ def login_with_service_principal(self, client_id, credential, scopes):
`credential` is a dict returned by ServicePrincipalAuth.build_credential
"""
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)

# This cred means SDK credential object
cred = ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs)
result = cred.acquire_token_for_client(scopes)
client_credential = sp_auth.get_msal_client_credential()
cca = ConfidentialClientApplication(client_id, client_credential, **self._msal_app_kwargs)
result = cca.acquire_token_for_client(scopes)
Comment on lines +195 to +197
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because ServicePrincipalCredential stops inheriting from msal.ConfidentialClientApplication, it is no longer possible to call acquire_token_for_client on the cred. We prepare client_credential and directly create a ConfidentialClientApplication instance.

check_result(result)

# Only persist the service principal after a successful login
Expand Down Expand Up @@ -246,32 +250,47 @@ def get_user_credential(self, username):

def get_service_principal_credential(self, client_id):
entry = self._service_principal_store.load_entry(client_id, self.tenant_id)
sp_auth = ServicePrincipalAuth(entry)
return ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs)
client_credential = ServicePrincipalAuth(entry).get_msal_client_credential()
return ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)

def get_managed_identity_credential(self, client_id=None):
raise NotImplementedError


class ServicePrincipalAuth:

class ServicePrincipalAuth: # pylint: disable=too-many-instance-attributes
def __init__(self, entry):
# Initialize all attributes first, so that we don't need to call getattr to check their existence
self.client_id = None
self.tenant = None
# secret
self.client_secret = None
# certificate
self.certificate = None
self.use_cert_sn_issuer = None
# federated identity credential
self.client_assertion = None

# Internal attributes for certificate
# They are computed at runtime and not persisted in the service principal entry.
self._certificate_string = None
self._thumbprint = None
self._public_certificate = None

self.__dict__.update(entry)

if _CERTIFICATE in entry:
if self.certificate:
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error
self.public_certificate = None
try:
with open(self.certificate, 'r') as file_reader:
self.certificate_string = file_reader.read()
cert = load_certificate(FILETYPE_PEM, self.certificate_string)
self.thumbprint = cert.digest("sha1").decode().replace(':', '')
self._certificate_string = file_reader.read()
cert = load_certificate(FILETYPE_PEM, self._certificate_string)
self._thumbprint = cert.digest("sha1").decode().replace(':', '')
if entry.get(_USE_CERT_SN_ISSUER):
# low-tech but safe parsing based on
# https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h
match = re.search(r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
self.certificate_string, re.I)
self.public_certificate = match.group()
self._certificate_string, re.I)
self._public_certificate = match.group()
except (UnicodeDecodeError, Error) as ex:
raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex))

Expand Down Expand Up @@ -307,8 +326,42 @@ def build_credential(cls, secret_or_certificate=None, client_assertion=None, use
return entry

def get_entry_to_persist(self):
"""Get a service principal entry that can be persisted by ServicePrincipalStore."""
persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION]
return {k: v for k, v in self.__dict__.items() if k in persisted_keys}
# Only persist certain attributes whose values are not None
return {k: v for k, v in self.__dict__.items() if k in persisted_keys and v}

def get_msal_client_credential(self):
"""Get a client_credential that can be consumed by msal.ConfidentialClientApplication."""
client_credential = None

# client_secret
# "your client secret"
if self.client_secret:
client_credential = self.client_secret

# certificate
# {
# "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format",
# "thumbprint": "A1B2C3D4E5F6...",
# "public_certificate": "...-----BEGIN CERTIFICATE-----...",
# }
if self.certificate:
client_credential = {
"private_key": self._certificate_string,
"thumbprint": self._thumbprint
}
if self._public_certificate:
client_credential['public_certificate'] = self._public_certificate

# client_assertion
# {
# "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
# }
if self.client_assertion:
client_credential = {'client_assertion': self.client_assertion}

return client_credential


class ServicePrincipalStore:
Expand Down
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is renamed to more accurately reflect its content.

Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,23 @@

from .util import check_result, build_sdk_access_token

# OAuth 2.0 client credentials flow parameter
# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow
_TENANT = 'tenant'
_CLIENT_ID = 'client_id'
_CLIENT_SECRET = 'client_secret'
_CERTIFICATE = 'certificate'
_CLIENT_ASSERTION = 'client_assertion'
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'

logger = get_logger(__name__)


class UserCredential(PublicClientApplication):
class UserCredential: # pylint: disable=too-few-public-methods
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interestingly, pylint doesn't think this is worth a class. Haha

Used when class has too few public methods, so be sure it's really worth it.

Ref: https://pylint.readthedocs.io/en/latest/user_guide/messages/refactor/too-few-public-methods.html

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to define such class so that it can be passed to SDK as a credential.


def __init__(self, client_id, username, **kwargs):
"""User credential implementing get_token interface.

:param client_id: Client ID of the CLI.
:param username: The username for user credential.
"""
super().__init__(client_id, **kwargs)
self._msal_app = PublicClientApplication(client_id, **kwargs)

# Make sure username is specified, otherwise MSAL returns all accounts
assert username, "username must be specified, got {!r}".format(username)

accounts = self.get_accounts(username)
accounts = self._msal_app.get_accounts(username)

# Usernames are usually unique. We are collecting corner cases to better understand its behavior.
if len(accounts) > 1:
Expand All @@ -65,8 +56,9 @@ def get_token(self, *scopes, claims=None, **kwargs):

if claims:
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
self.authority.tenant, claims)
result = self.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims, **kwargs)
self._msal_app.authority.tenant, claims)
result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By saving msal.PublicClientApplication instance to self._msal_app, self._account is no longer saved as an attribute of msal.PublicClientApplication. This avoids the possible conflict of msal.PublicClientApplication introducing an _account attribute of its own.

**kwargs)

from azure.cli.core.azclierror import AuthenticationError
try:
Expand All @@ -82,13 +74,14 @@ def get_token(self, *scopes, claims=None, **kwargs):
logger.warning(ex)
logger.warning("\nThe default web browser has been opened at %s for scope '%s'. "
"Please continue the login in the web browser.",
self.authority.authorization_endpoint, ' '.join(scopes))
self._msal_app.authority.authorization_endpoint, ' '.join(scopes))

from .util import read_response_templates
success_template, error_template = read_response_templates()

result = self.acquire_token_interactive(
list(scopes), login_hint=self._account['username'], port=8400 if self.authority.is_adfs else None,
result = self._msal_app.acquire_token_interactive(
list(scopes), login_hint=self._account['username'],
port=8400 if self._msal_app.authority.is_adfs else None,
success_template=success_template, error_template=error_template, **kwargs)
check_result(result)

Expand All @@ -99,42 +92,19 @@ def get_token(self, *scopes, claims=None, **kwargs):
return build_sdk_access_token(result)


class ServicePrincipalCredential(ConfidentialClientApplication):
class ServicePrincipalCredential: # pylint: disable=too-few-public-methods

def __init__(self, service_principal_auth, **kwargs):
def __init__(self, client_id, client_credential, **kwargs):
"""Service principal credential implementing get_token interface.

:param service_principal_auth: An instance of ServicePrincipalAuth.
:param client_id: The service principal's client ID.
:param client_credential: client_credential that will be passed to MSAL.
"""
client_credential = None

# client_secret
client_secret = getattr(service_principal_auth, _CLIENT_SECRET, None)
if client_secret:
client_credential = client_secret

# certificate
certificate = getattr(service_principal_auth, _CERTIFICATE, None)
if certificate:
client_credential = {
"private_key": getattr(service_principal_auth, 'certificate_string'),
"thumbprint": getattr(service_principal_auth, 'thumbprint')
}
public_certificate = getattr(service_principal_auth, 'public_certificate', None)
if public_certificate:
client_credential['public_certificate'] = public_certificate

# client_assertion
client_assertion = getattr(service_principal_auth, _CLIENT_ASSERTION, None)
if client_assertion:
client_credential = {'client_assertion': client_assertion}

super().__init__(service_principal_auth.client_id, client_credential=client_credential, **kwargs)
self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs)

def get_token(self, *scopes, **kwargs):
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)

scopes = list(scopes)
result = self.acquire_token_for_client(scopes, **kwargs)
result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
check_result(result)
return build_sdk_access_token(result)
Loading