Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor get_msal_token function
  • Loading branch information
arrownj committed Apr 24, 2020
commit 6f12c34308f1bcc6e832175c1ed134898108c2af
2 changes: 1 addition & 1 deletion src/azure-cli-core/azure/cli/core/_msal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from msal import ClientApplication


class SSHCertificateClientApplication(ClientApplication):
class AdalRefreshTokenBasedClientApplication(ClientApplication):
"""
This class is added only for vmssh feature.
This is a temporary solution and will deprecate after adoption to MSAL completely.
Expand Down
50 changes: 12 additions & 38 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,16 +525,6 @@ def get_access_token_for_resource(self, username, tenant, resource):
username, tenant, resource)
return access_token

def _get_ssh_certificate_for_resource(self, tenant, modulus, exponent):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after adoption to MSAL completely.
"""
tenant = tenant or 'common'
_, refresh_token, _, _ = self.get_refresh_token()
_, cert, _ = self._creds_cache.retrieve_ssh_certificate_for_user(tenant, modulus, exponent, refresh_token)
return cert

@staticmethod
def _try_parse_msi_account_name(account):
msi_info, user = account[_USER_ENTITY].get(_ASSIGNED_IDENTITY_INFO), account[_USER_ENTITY].get(_USER_NAME)
Expand Down Expand Up @@ -603,16 +593,17 @@ def _retrieve_tokens_from_external_tenants():
str(account[_SUBSCRIPTION_ID]),
str(account[_TENANT_ID]))

def get_ssh_credentials(self, modulus, exponent):
def get_msal_token(self, scopes, data):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after adoption to MSAL completely.
It is a temporary solution and will deprecate after MSAL adopted completely.
"""
account = self.get_subscription()
username = account[_USER_ENTITY][_USER_NAME]
return username, self._get_ssh_certificate_for_resource(
account[_TENANT_ID], modulus, exponent
)
tenant = account[_TENANT_ID] or 'common'
_, refresh_token, _, _ = self.get_refresh_token()
certificate = self._creds_cache.retrieve_msal_token(tenant, scopes, data, refresh_token)
return username, certificate

def get_refresh_token(self, resource=None,
subscription=None):
Expand Down Expand Up @@ -1029,35 +1020,18 @@ def retrieve_token_for_user(self, username, tenant, resource):
self.persist_cached_creds()
return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN], token_entry)

def retrieve_ssh_certificate_for_user(self, tenant, modulus, exponent, refresh_token):
def retrieve_msal_token(self, tenant, scopes, data, refresh_token):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after adoption to MSAL completely.
It is a temporary solution and will deprecate after MSAL adopted completely.
"""
from azure.cli.core._msal import SSHCertificateClientApplication
import hashlib
scopes = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]
from azure.cli.core._msal import AdalRefreshTokenBasedClientApplication
tenant = tenant or 'organizations'
authority = self._ctx.cloud.endpoints.active_directory + '/' + tenant
app = SSHCertificateClientApplication(_CLIENT_ID, authority=authority)

key_hash = hashlib.sha256()
key_hash.update(modulus.encode('utf-8'))
key_hash.update(exponent.encode('utf-8'))
key_id = key_hash.hexdigest()

jwk = {
"kty": "RSA",
"n": modulus,
"e": exponent,
"kid": key_id
}
json_jwk = json.dumps(jwk)
result = app.acquire_token_silent(scopes, None,
data={"token_type": "ssh-cert", "req_cnf": json_jwk, "key_id": jwk["kid"]},
refresh_token=refresh_token)
app = AdalRefreshTokenBasedClientApplication(_CLIENT_ID, authority=authority)
result = app.acquire_token_silent(scopes, None, data=data, refresh_token=refresh_token)

return result["token_type"], result["access_token"], result
return result["access_token"]

def retrieve_token_for_service_principal(self, sp_id, resource, tenant, use_cert_sn_issuer=False):
self.load_adal_token_cache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
logger = log.get_logger(__name__)


def get_ssh_credentials(cli_ctx, modulus, exponent):
def get_ssh_credentials(cli_ctx, scopes, data):
from azure.cli.core._profile import Profile
logger.debug("Getting SSH credentials")
profile = Profile(cli_ctx=cli_ctx)

user, cert = profile.get_ssh_credentials(modulus, exponent)
user, cert = profile.get_ssh_credentials(scopes, data)
return SSHCredentials(user, cert)


Expand Down