diff --git a/src/azure-cli-core/azure/cli/core/auth/identity.py b/src/azure-cli-core/azure/cli/core/auth/identity.py index 89501d31a32..49894903fc0 100644 --- a/src/azure-cli-core/azure/cli/core/auth/identity.py +++ b/src/azure-cli-core/azure/cli/core/auth/identity.py @@ -13,15 +13,20 @@ from knack.util import CLIError from msal import PublicClientApplication, ConfidentialClientApplication -# Service principal entry properties -from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION, \ - _USE_CERT_SN_ISSUER -from .msal_authentication import UserCredential, ServicePrincipalCredential +from .msal_credentials import UserCredential, ServicePrincipalCredential from .persistence import load_persisted_token_cache, file_extensions, load_secret_store from .util import check_result AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' +# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters: +# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow +_TENANT = 'tenant' +_CLIENT_ID = 'client_id' +_CLIENT_SECRET = 'client_secret' +_CERTIFICATE = 'certificate' +_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer' +_CLIENT_ASSERTION = 'client_assertion' # For environment credential AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST" @@ -187,10 +192,9 @@ def login_with_service_principal(self, client_id, credential, scopes): `credential` is a dict returned by ServicePrincipalAuth.build_credential """ sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential) - - # This cred means SDK credential object - cred = ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) - result = cred.acquire_token_for_client(scopes) + client_credential = sp_auth.get_msal_client_credential() + cca = ConfidentialClientApplication(client_id, client_credential, **self._msal_app_kwargs) + result = cca.acquire_token_for_client(scopes) check_result(result) # Only persist the service principal after a successful login @@ -246,32 +250,47 @@ def get_user_credential(self, username): def get_service_principal_credential(self, client_id): entry = self._service_principal_store.load_entry(client_id, self.tenant_id) - sp_auth = ServicePrincipalAuth(entry) - return ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) + client_credential = ServicePrincipalAuth(entry).get_msal_client_credential() + return ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs) def get_managed_identity_credential(self, client_id=None): raise NotImplementedError -class ServicePrincipalAuth: - +class ServicePrincipalAuth: # pylint: disable=too-many-instance-attributes def __init__(self, entry): + # Initialize all attributes first, so that we don't need to call getattr to check their existence + self.client_id = None + self.tenant = None + # secret + self.client_secret = None + # certificate + self.certificate = None + self.use_cert_sn_issuer = None + # federated identity credential + self.client_assertion = None + + # Internal attributes for certificate + # They are computed at runtime and not persisted in the service principal entry. + self._certificate_string = None + self._thumbprint = None + self._public_certificate = None + self.__dict__.update(entry) - if _CERTIFICATE in entry: + if self.certificate: from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error - self.public_certificate = None try: with open(self.certificate, 'r') as file_reader: - self.certificate_string = file_reader.read() - cert = load_certificate(FILETYPE_PEM, self.certificate_string) - self.thumbprint = cert.digest("sha1").decode().replace(':', '') + self._certificate_string = file_reader.read() + cert = load_certificate(FILETYPE_PEM, self._certificate_string) + self._thumbprint = cert.digest("sha1").decode().replace(':', '') if entry.get(_USE_CERT_SN_ISSUER): # low-tech but safe parsing based on # https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h match = re.search(r'-----BEGIN CERTIFICATE-----(?P[^-]+)-----END CERTIFICATE-----', - self.certificate_string, re.I) - self.public_certificate = match.group() + self._certificate_string, re.I) + self._public_certificate = match.group() except (UnicodeDecodeError, Error) as ex: raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex)) @@ -307,8 +326,42 @@ def build_credential(cls, secret_or_certificate=None, client_assertion=None, use return entry def get_entry_to_persist(self): + """Get a service principal entry that can be persisted by ServicePrincipalStore.""" persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION] - return {k: v for k, v in self.__dict__.items() if k in persisted_keys} + # Only persist certain attributes whose values are not None + return {k: v for k, v in self.__dict__.items() if k in persisted_keys and v} + + def get_msal_client_credential(self): + """Get a client_credential that can be consumed by msal.ConfidentialClientApplication.""" + client_credential = None + + # client_secret + # "your client secret" + if self.client_secret: + client_credential = self.client_secret + + # certificate + # { + # "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format", + # "thumbprint": "A1B2C3D4E5F6...", + # "public_certificate": "...-----BEGIN CERTIFICATE-----...", + # } + if self.certificate: + client_credential = { + "private_key": self._certificate_string, + "thumbprint": self._thumbprint + } + if self._public_certificate: + client_credential['public_certificate'] = self._public_certificate + + # client_assertion + # { + # "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..." + # } + if self.client_assertion: + client_credential = {'client_assertion': self.client_assertion} + + return client_credential class ServicePrincipalStore: diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py similarity index 65% rename from src/azure-cli-core/azure/cli/core/auth/msal_authentication.py rename to src/azure-cli-core/azure/cli/core/auth/msal_credentials.py index b7b43ae32ba..8c6dfd0daf3 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py @@ -22,19 +22,10 @@ from .util import check_result, build_sdk_access_token -# OAuth 2.0 client credentials flow parameter -# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow -_TENANT = 'tenant' -_CLIENT_ID = 'client_id' -_CLIENT_SECRET = 'client_secret' -_CERTIFICATE = 'certificate' -_CLIENT_ASSERTION = 'client_assertion' -_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer' - logger = get_logger(__name__) -class UserCredential(PublicClientApplication): +class UserCredential: # pylint: disable=too-few-public-methods def __init__(self, client_id, username, **kwargs): """User credential implementing get_token interface. @@ -42,12 +33,12 @@ def __init__(self, client_id, username, **kwargs): :param client_id: Client ID of the CLI. :param username: The username for user credential. """ - super().__init__(client_id, **kwargs) + self._msal_app = PublicClientApplication(client_id, **kwargs) # Make sure username is specified, otherwise MSAL returns all accounts assert username, "username must be specified, got {!r}".format(username) - accounts = self.get_accounts(username) + accounts = self._msal_app.get_accounts(username) # Usernames are usually unique. We are collecting corner cases to better understand its behavior. if len(accounts) > 1: @@ -65,8 +56,9 @@ def get_token(self, *scopes, claims=None, **kwargs): if claims: logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s', - self.authority.tenant, claims) - result = self.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims, **kwargs) + self._msal_app.authority.tenant, claims) + result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims, + **kwargs) from azure.cli.core.azclierror import AuthenticationError try: @@ -82,13 +74,14 @@ def get_token(self, *scopes, claims=None, **kwargs): logger.warning(ex) logger.warning("\nThe default web browser has been opened at %s for scope '%s'. " "Please continue the login in the web browser.", - self.authority.authorization_endpoint, ' '.join(scopes)) + self._msal_app.authority.authorization_endpoint, ' '.join(scopes)) from .util import read_response_templates success_template, error_template = read_response_templates() - result = self.acquire_token_interactive( - list(scopes), login_hint=self._account['username'], port=8400 if self.authority.is_adfs else None, + result = self._msal_app.acquire_token_interactive( + list(scopes), login_hint=self._account['username'], + port=8400 if self._msal_app.authority.is_adfs else None, success_template=success_template, error_template=error_template, **kwargs) check_result(result) @@ -99,42 +92,19 @@ def get_token(self, *scopes, claims=None, **kwargs): return build_sdk_access_token(result) -class ServicePrincipalCredential(ConfidentialClientApplication): +class ServicePrincipalCredential: # pylint: disable=too-few-public-methods - def __init__(self, service_principal_auth, **kwargs): + def __init__(self, client_id, client_credential, **kwargs): """Service principal credential implementing get_token interface. - :param service_principal_auth: An instance of ServicePrincipalAuth. + :param client_id: The service principal's client ID. + :param client_credential: client_credential that will be passed to MSAL. """ - client_credential = None - - # client_secret - client_secret = getattr(service_principal_auth, _CLIENT_SECRET, None) - if client_secret: - client_credential = client_secret - - # certificate - certificate = getattr(service_principal_auth, _CERTIFICATE, None) - if certificate: - client_credential = { - "private_key": getattr(service_principal_auth, 'certificate_string'), - "thumbprint": getattr(service_principal_auth, 'thumbprint') - } - public_certificate = getattr(service_principal_auth, 'public_certificate', None) - if public_certificate: - client_credential['public_certificate'] = public_certificate - - # client_assertion - client_assertion = getattr(service_principal_auth, _CLIENT_ASSERTION, None) - if client_assertion: - client_credential = {'client_assertion': client_assertion} - - super().__init__(service_principal_auth.client_id, client_credential=client_credential, **kwargs) + self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs) def get_token(self, *scopes, **kwargs): logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) - scopes = list(scopes) - result = self.acquire_token_for_client(scopes, **kwargs) + result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs) check_result(result) return build_sdk_access_token(result) diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/test_identity.py b/src/azure-cli-core/azure/cli/core/auth/tests/test_identity.py index cc89788f4d0..8f753e264ee 100644 --- a/src/azure-cli-core/azure/cli/core/auth/tests/test_identity.py +++ b/src/azure-cli-core/azure/cli/core/auth/tests/test_identity.py @@ -12,98 +12,108 @@ _get_authority_url) from knack.util import CLIError +# CERTIFICATE section in sp_cert.pem +PUBLIC_CERTIFICATE = """-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIJAPMNsT0qjg1ZMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTcwMzEwMDQ0NjEyWhcNMTgwMzEwMDQ0NjEyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAxec32tnXNiPz2WBTpv7ccZvYqBR2Gr8vimQbiNgT3aHY/dzV26pYv/88 +X5PbkibAr3YXJP64nGI/0MGvFWYi6c6C0Ar6QL/MgRLIGIO8JePTxKu9ZDx+5Crw +beJRQgz7nEtCWsIx5WiIx5/yjUR5AqrNwSxNWo6Ct3E1YWzGyI03gEEr82tEG9Vd +ObIRq05v1hHKTm27xln41JZI1aUMzd/K/pckb6nQLtV6OpOmzZQILMOV95SKJ8+k +1gnxfOX2t9JPgTuiVmwvgYLb1k7Hfqs1/KZt4IyIRkBaXPy2j5Guz09uR1Dg4tOc +oSPwDeN0aQQSucRsk0iaof3DXMfVLQIDAQABo4GnMIGkMB0GA1UdDgQWBBRpCyBM +VgNXHqX5MrBdAQ1Hzf8l7jB1BgNVHSMEbjBsgBRpCyBMVgNXHqX5MrBdAQ1Hzf8l +7qFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV +BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAPMNsT0qjg1ZMAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAEH/nmErQLSxsMDk3LgTpBY6ibl6xU0k +Lt1wbC+Z3sgpt82oA4BiulcJtTf3IrvBXJNRaB++ChjqRnK8O6uWbBQxvz/V8l+9 +g3s49VSaX3QB74Rh1NIfKhUyYlG3yi8qBJA6tlCNNXGQoYvND9Y3gorj+LzH3Eqf +9g2oBm2jWaiPBHjuuUbd+SBS2hQn/i2huWnz1yewrtfVpRwWrQQHa1Qv3ivKDK2H +2LOdn2Xs3/ZGsi1ySfjzxjTbuPhUaEUy+ZfV2dgmqiS//BAWI5opo7TgeplrGk2P +h5Fwbt0FxaqFCNZdrPI7FRnbKZwvGx0A+Zj8ZpNjft3QjuUg+xqMKMs= +-----END CERTIFICATE-----""" + + +TEST_CERT = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'sp_cert.pem') + +with open(TEST_CERT) as f: + CERTIFICATE_STRING = f.read() + class TestIdentity(unittest.TestCase): @mock.patch("azure.cli.core.auth.identity.ServicePrincipalStore.save_entry") @mock.patch("msal.application.ConfidentialClientApplication.acquire_token_for_client") - @mock.patch("msal.application.ConfidentialClientApplication.__init__") + @mock.patch("msal.application.ConfidentialClientApplication.__init__", return_value=None) def test_login_with_service_principal_secret(self, init_mock, acquire_token_for_client_mock, save_entry_mock): acquire_token_for_client_mock.return_value = {'access_token': "test_token"} - identity = Identity('https://login.microsoftonline.com', tenant_id='my-tenant') - - identity.login_with_service_principal("00000000-0000-0000-0000-000000000000", - {"client_secret": "test_secret"}, "openid") + identity = Identity('https://login.microsoftonline.com', tenant_id='tenant1') + identity.login_with_service_principal("sp_id1", {"client_secret": "test_secret"}, "openid") - assert init_mock.call_args[0][0] == '00000000-0000-0000-0000-000000000000' - assert init_mock.call_args[1]['client_credential'] == 'test_secret' - assert init_mock.call_args[1]['authority'] == 'https://login.microsoftonline.com/my-tenant' + assert init_mock.call_args.args == ('sp_id1',) + assert init_mock.call_args.kwargs['client_credential'] == 'test_secret' + assert init_mock.call_args.kwargs['authority'] == 'https://login.microsoftonline.com/tenant1' - assert save_entry_mock.call_args[0][0] == { - 'tenant': 'my-tenant', - 'client_id': '00000000-0000-0000-0000-000000000000', + assert save_entry_mock.call_args.args[0] == { + 'client_id': 'sp_id1', + 'tenant': 'tenant1', 'client_secret': 'test_secret' } @mock.patch("azure.cli.core.auth.identity.ServicePrincipalStore.save_entry") @mock.patch("msal.application.ConfidentialClientApplication.acquire_token_for_client") - @mock.patch("msal.application.ConfidentialClientApplication.__init__") + @mock.patch("msal.application.ConfidentialClientApplication.__init__", return_value=None) def test_login_with_service_principal_certificate(self, init_mock, acquire_token_for_client_mock, save_entry_mock): acquire_token_for_client_mock.return_value = {'access_token': "test_token"} - identity = Identity('https://login.microsoftonline.com', tenant_id='my-tenant') - - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - - with open(test_cert_file) as cert_file: - cert_file_string = cert_file.read() + identity = Identity('https://login.microsoftonline.com', tenant_id='tenant1') + identity.login_with_service_principal("sp_id1", {'certificate': TEST_CERT}, 'openid') - identity.login_with_service_principal("00000000-0000-0000-0000-000000000000", - {'certificate': test_cert_file}, 'openid') - - assert init_mock.call_args[0][0] == '00000000-0000-0000-0000-000000000000' - assert init_mock.call_args[1]['client_credential'] == { - 'private_key': cert_file_string, + assert init_mock.call_args.args == ('sp_id1',) + assert init_mock.call_args.kwargs['client_credential'] == { + 'private_key': CERTIFICATE_STRING, 'thumbprint': 'F06A53848BBE714A4290D69D335279C1D01073FD' } - assert init_mock.call_args[1]['authority'] == 'https://login.microsoftonline.com/my-tenant' + assert init_mock.call_args.kwargs['authority'] == 'https://login.microsoftonline.com/tenant1' assert save_entry_mock.call_args[0][0] == { - 'tenant': 'my-tenant', - 'client_id': '00000000-0000-0000-0000-000000000000', - 'certificate': test_cert_file + 'client_id': 'sp_id1', + 'tenant': 'tenant1', + 'certificate': TEST_CERT } @mock.patch("azure.cli.core.auth.identity.ServicePrincipalStore.save_entry") @mock.patch("msal.application.ConfidentialClientApplication.acquire_token_for_client") - @mock.patch("msal.application.ConfidentialClientApplication.__init__") + @mock.patch("msal.application.ConfidentialClientApplication.__init__", return_value=None) def test_login_with_service_principal_certificate_sn_issuer(self, init_mock, acquire_token_for_client_mock, save_entry_mock): acquire_token_for_client_mock.return_value = {'access_token': "test_token"} - identity = Identity('https://login.microsoftonline.com', tenant_id='my-tenant') - - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - - with open(test_cert_file) as cert_file: - cert_file_string = cert_file.read() - - match = re.search(r'-+BEGIN CERTIFICATE-+(?P[^-]+)-+END CERTIFICATE-+', cert_file_string, re.I) - public_certificate = match.group().strip() - - identity.login_with_service_principal("00000000-0000-0000-0000-000000000000", + identity = Identity('https://login.microsoftonline.com', tenant_id='tenant1') + identity.login_with_service_principal("sp_id1", { - 'certificate': test_cert_file, + 'certificate': TEST_CERT, 'use_cert_sn_issuer': True, }, "openid") - assert init_mock.call_args[0][0] == '00000000-0000-0000-0000-000000000000' - assert init_mock.call_args[1]['client_credential'] == { - "private_key": cert_file_string, + assert init_mock.call_args.args == ('sp_id1',) + assert init_mock.call_args.kwargs['client_credential'] == { + "private_key": CERTIFICATE_STRING, "thumbprint": 'F06A53848BBE714A4290D69D335279C1D01073FD', - "public_certificate": public_certificate + "public_certificate": PUBLIC_CERTIFICATE } - assert init_mock.call_args[1]['authority'] == 'https://login.microsoftonline.com/my-tenant' + assert init_mock.call_args.kwargs['authority'] == 'https://login.microsoftonline.com/tenant1' - assert save_entry_mock.call_args[0][0] == { - 'tenant': 'my-tenant', - 'client_id': '00000000-0000-0000-0000-000000000000', - 'certificate': test_cert_file, + assert save_entry_mock.call_args.args[0] == { + 'client_id': 'sp_id1', + 'tenant': 'tenant1', + 'certificate': TEST_CERT, 'use_cert_sn_issuer': True } @@ -114,8 +124,27 @@ def test_login_with_service_principal_certificate_cert_err(self): test_cert_file = os.path.join(current_dir, 'err_sp_cert.pem') with self.assertRaisesRegex(CLIError, "Invalid certificate"): - identity.login_with_service_principal("00000000-0000-0000-0000-000000000000", - {"certificate": test_cert_file}, "openid") + identity.login_with_service_principal("sp_id1", {"certificate": test_cert_file}, "openid") + + @mock.patch("azure.cli.core.auth.identity.ServicePrincipalStore.save_entry") + @mock.patch("msal.application.ConfidentialClientApplication.acquire_token_for_client") + @mock.patch("msal.application.ConfidentialClientApplication.__init__", return_value=None) + def test_login_with_service_principal_client_assertion(self, init_mock, acquire_token_for_client_mock, + save_entry_mock): + acquire_token_for_client_mock.return_value = {'access_token': "test_token"} + + identity = Identity('https://login.microsoftonline.com', tenant_id='tenant1') + identity.login_with_service_principal("sp_id1", {'client_assertion': 'test_jwt'}, "openid") + + assert init_mock.call_args.args == ('sp_id1',) + assert init_mock.call_args.kwargs['client_credential'] == {"client_assertion": 'test_jwt'} + assert init_mock.call_args.kwargs['authority'] == 'https://login.microsoftonline.com/tenant1' + + assert save_entry_mock.call_args.args[0] == { + 'client_id': 'sp_id1', + 'tenant': 'tenant1', + 'client_assertion': 'test_jwt', + } @mock.patch("msal.application.PublicClientApplication.remove_account") @mock.patch("msal.application.PublicClientApplication.get_accounts") @@ -134,7 +163,7 @@ def test_logout_user(self, get_accounts_mock, remove_account_mock): get_accounts_mock.return_value = accounts identity = Identity('https://login.microsoftonline.com') - identity.logout_user('00000000-0000-0000-0000-000000000000') + identity.logout_user('test@test.com') remove_account_mock.assert_called_with(accounts[0]) @mock.patch("azure.cli.core.auth.identity.ServicePrincipalStore.remove_entry") @@ -142,7 +171,7 @@ def test_logout_user(self, get_accounts_mock, remove_account_mock): @mock.patch("msal.application.ConfidentialClientApplication.__init__", return_value=None) def test_logout_service_principal(self, init_mock, remove_tokens_for_client_mock, remove_entry_mock): identity = Identity('https://login.microsoftonline.com') - client_id = '00000000-0000-0000-0000-000000000000' + client_id = 'sp_id1' identity.logout_service_principal(client_id) assert init_mock.call_args.args[0] == client_id remove_tokens_for_client_mock.assert_called_once() @@ -153,57 +182,87 @@ class TestServicePrincipalAuth(unittest.TestCase): def test_service_principal_auth_client_secret(self): sp_auth = ServicePrincipalAuth.build_from_credential('tenant1', 'sp_id1', {'client_secret': "test_secret"}) - result = sp_auth.get_entry_to_persist() - assert result == { + # Verify persist entry + entry = sp_auth.get_entry_to_persist() + assert entry == { 'client_id': 'sp_id1', 'tenant': 'tenant1', 'client_secret': 'test_secret' } + # Verify msal client_credential + client_credential = sp_auth.get_msal_client_credential() + assert client_credential == 'test_secret' + def test_service_principal_auth_certificate(self): - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - sp_auth = ServicePrincipalAuth.build_from_credential('tenant1', 'sp_id1', {'certificate': test_cert_file}) + sp_auth = ServicePrincipalAuth.build_from_credential('tenant1', 'sp_id1', {'certificate': TEST_CERT}) - result = sp_auth.get_entry_to_persist() - # To compute the thumb print: + # To compute the thumbprint: # openssl x509 -in sp_cert.pem -noout -fingerprint - assert sp_auth.thumbprint == 'F06A53848BBE714A4290D69D335279C1D01073FD' - assert result == { + assert sp_auth._thumbprint == 'F06A53848BBE714A4290D69D335279C1D01073FD' + + # Verify persist entry + entry = sp_auth.get_entry_to_persist() + assert entry == { 'client_id': 'sp_id1', 'tenant': 'tenant1', - 'certificate': test_cert_file + 'certificate': TEST_CERT } - def test_service_principal_auth_certificate_sn_issuer(self): - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - - with open(test_cert_file) as cert_file: - cert_file_string = cert_file.read() - match = re.search(r'-+BEGIN CERTIFICATE-+(?P[^-]+)-+END CERTIFICATE-+', cert_file_string, re.I) - public_certificate = match.group().strip() + # Verify msal client_credential + client_credential = sp_auth.get_msal_client_credential() + assert client_credential == { + 'private_key': CERTIFICATE_STRING, + 'thumbprint': 'F06A53848BBE714A4290D69D335279C1D01073FD' + } + def test_service_principal_auth_certificate_sn_issuer(self): sp_auth = ServicePrincipalAuth.build_from_credential('tenant1', 'sp_id1', { - 'certificate': test_cert_file, + 'certificate': TEST_CERT, 'use_cert_sn_issuer': True, }) - result = sp_auth.get_entry_to_persist() - # To compute the thumb print: + # To compute the thumbprint: # openssl x509 -in sp_cert.pem -noout -fingerprint - assert sp_auth.thumbprint == 'F06A53848BBE714A4290D69D335279C1D01073FD' - assert sp_auth.public_certificate == public_certificate + assert sp_auth._thumbprint == 'F06A53848BBE714A4290D69D335279C1D01073FD' + assert sp_auth._public_certificate == PUBLIC_CERTIFICATE - assert result == { + # Verify persist entry + entry = sp_auth.get_entry_to_persist() + assert entry == { 'client_id': 'sp_id1', 'tenant': 'tenant1', - 'certificate': test_cert_file, + 'certificate': TEST_CERT, 'use_cert_sn_issuer': True, } + # Verify msal client_credential + client_credential = sp_auth.get_msal_client_credential() + assert client_credential == { + 'private_key': CERTIFICATE_STRING, + 'thumbprint': 'F06A53848BBE714A4290D69D335279C1D01073FD', + 'public_certificate': PUBLIC_CERTIFICATE + } + + def test_service_principal_auth_client_assertion(self): + sp_auth = ServicePrincipalAuth.build_from_credential('tenant1', 'sp_id1', + {'client_assertion': 'test_jwt'}) + assert sp_auth.client_assertion == 'test_jwt' + + # Verify persist entry + entry = sp_auth.get_entry_to_persist() + assert entry == { + 'client_id': 'sp_id1', + 'tenant': 'tenant1', + 'client_assertion': 'test_jwt' + } + + # Verify msal client_credential + client_credential = sp_auth.get_msal_client_credential() + assert client_credential == {'client_assertion': 'test_jwt'} + def test_build_credential(self): # secret cred = ServicePrincipalAuth.build_credential("test_secret")