Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor ServicePrincipalAuth
  • Loading branch information
jiasli committed Jul 25, 2024
commit c299ddd65197498924f113f41a21c309471e6ee8
116 changes: 65 additions & 51 deletions src/azure-cli-core/azure/cli/core/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@
from knack.util import CLIError
from msal import PublicClientApplication, ConfidentialClientApplication

# Service principal entry properties
from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION, \
_USE_CERT_SN_ISSUER
from .msal_authentication import UserCredential, ServicePrincipalCredential
from .msal_credentials import UserCredential, ServicePrincipalCredential
from .persistence import load_persisted_token_cache, file_extensions, load_secret_store
from .util import check_result

AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'

# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters:
# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow
_TENANT = 'tenant'
_CLIENT_ID = 'client_id'
_CLIENT_SECRET = 'client_secret'
_CERTIFICATE = 'certificate'
_CLIENT_ASSERTION = 'client_assertion'
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'

# For environment credential
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
Expand Down Expand Up @@ -187,7 +192,7 @@ def login_with_service_principal(self, client_id, credential, scopes):
`credential` is a dict returned by ServicePrincipalAuth.build_credential
"""
sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential)
client_credential = _build_msal_client_credential(sp_auth)
client_credential = sp_auth.get_msal_client_credential()
cca = ConfidentialClientApplication(client_id, client_credential=client_credential, **self._msal_app_kwargs)
result = cca.acquire_token_for_client(scopes)
check_result(result)
Expand Down Expand Up @@ -235,32 +240,45 @@ def get_user_credential(self, username):

def get_service_principal_credential(self, client_id):
entry = self._service_principal_store.load_entry(client_id, self.tenant_id)
sp_auth = ServicePrincipalAuth(entry)
client_credential = _build_msal_client_credential(sp_auth)
client_credential = ServicePrincipalAuth(entry).get_msal_client_credential()
return ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs)

def get_managed_identity_credential(self, client_id=None):
raise NotImplementedError


class ServicePrincipalAuth:

def __init__(self, entry):
# Initialize all attributes first, so that we don't need to call getattr to check their existence
self.client_id = None
self.tenant = None
# secret
self.client_secret = None
# certificate
self.certificate = None
self.use_cert_sn_issuer = None
# federated identity credential
self.client_assertion = None

# Internal attributes for certificate
self._certificate_string = None
self._thumbprint = None
self._public_certificate = None

self.__dict__.update(entry)

if _CERTIFICATE in entry:
if self.certificate:
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error
self.public_certificate = None
try:
with open(self.certificate, 'r') as file_reader:
self.certificate_string = file_reader.read()
cert = load_certificate(FILETYPE_PEM, self.certificate_string)
self.thumbprint = cert.digest("sha1").decode().replace(':', '')
self._certificate_string = file_reader.read()
cert = load_certificate(FILETYPE_PEM, self._certificate_string)
self._thumbprint = cert.digest("sha1").decode().replace(':', '')
if entry.get(_USE_CERT_SN_ISSUER):
# low-tech but safe parsing based on
# https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h
match = re.search(r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
self.certificate_string, re.I)
self._certificate_string, re.I)
self.public_certificate = match.group()
except (UnicodeDecodeError, Error) as ex:
raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex))
Expand Down Expand Up @@ -298,7 +316,39 @@ def build_credential(cls, secret_or_certificate=None, client_assertion=None, use

def get_entry_to_persist(self):
persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION]
return {k: v for k, v in self.__dict__.items() if k in persisted_keys}
# Only persist certain attributes whose values are not None
return {k: v for k, v in self.__dict__.items() if k in persisted_keys and v}

def get_msal_client_credential(self):
client_credential = None

# client_secret
# "your client secret"
if self.client_secret:
client_credential = self.client_secret

# certificate
# {
# "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format",
# "thumbprint": "A1B2C3D4E5F6...",
# "public_certificate": "...-----BEGIN CERTIFICATE-----...",
# }
if self.certificate:
client_credential = {
"private_key": self._certificate_string,
"thumbprint": self._thumbprint
}
if self._public_certificate:
client_credential['public_certificate'] = self._public_certificate

# client_assertion
# {
# "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
# }
if self.client_assertion:
client_credential = {'client_assertion': self.client_assertion}

return client_credential


class ServicePrincipalStore:
Expand Down Expand Up @@ -405,39 +455,3 @@ def get_environment_credential():
getenv(AZURE_TENANT_ID))
credentials = ServicePrincipalCredential(sp_auth, authority=authority)
return credentials


def _build_msal_client_credential(service_principal_auth):
client_credential = None

# client_secret
# "your client secret"
client_secret = getattr(service_principal_auth, _CLIENT_SECRET, None)
if client_secret:
client_credential = client_secret

# certificate
# {
# "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format",
# "thumbprint": "A1B2C3D4E5F6...",
# "public_certificate": "...-----BEGIN CERTIFICATE-----...",
# }
certificate = getattr(service_principal_auth, _CERTIFICATE, None)
if certificate:
client_credential = {
"private_key": getattr(service_principal_auth, 'certificate_string'),
"thumbprint": getattr(service_principal_auth, 'thumbprint')
}
public_certificate = getattr(service_principal_auth, 'public_certificate', None)
if public_certificate:
client_credential['public_certificate'] = public_certificate

# client_assertion
# {
# "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
# }
client_assertion = getattr(service_principal_auth, _CLIENT_ASSERTION, None)
if client_assertion:
client_credential = {'client_assertion': client_assertion}

return client_credential
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@

from .util import check_result, build_sdk_access_token

# OAuth 2.0 client credentials flow parameter
# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow
_TENANT = 'tenant'
_CLIENT_ID = 'client_id'
_CLIENT_SECRET = 'client_secret'
_CERTIFICATE = 'certificate'
_CLIENT_ASSERTION = 'client_assertion'
_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer'

logger = get_logger(__name__)


Expand Down