Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Remove all AT, RT, FRT belongs to current account
  • Loading branch information
rayluo committed May 10, 2019
commit 167e954c50c11e84b00d633e42ddd94443081fcc
41 changes: 14 additions & 27 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,36 +280,23 @@ def _get_authority_aliases(self, instance):
return [alias for alias in group if alias != instance]
return []

def sign_out(self, account):
def remove_account(self, home_account):
"""Remove all relevant RTs and ATs from token cache"""
owned_by_account = {
"environment": account["environment"],
"home_account_id": (account or {}).get("home_account_id"),}

owned_by_account_and_app = dict(owned_by_account, client=self.client_id)
for rt in self.token_cache.find( # Remove RTs
TokenCache.CredentialType.REFRESH_TOKEN,
query=owned_by_account_and_app):
"environment": home_account["environment"],
"home_account_id": home_account["home_account_id"],} # realm-independent
for rt in self.token_cache.find( # Remove RTs, and RTs are realm-independent
TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_account):
self.token_cache.remove_rt(rt)
for at in self.token_cache.find( # Remove ATs
TokenCache.CredentialType.ACCESS_TOKEN,
query=owned_by_account_and_app): # regardless of realm
self.token_cache.remove_at(at) # TODO

app_metadata = self._get_app_metadata(account["environment"])
if app_metadata.get("family_id"): # Now let's settle family business
for rt in self.token_cache.find( # Remove FRTs
TokenCache.CredentialType.REFRESH_TOKEN, query=dict(
owned_by_account,
family_id=app_metadata["family_id"])):
self.token_cache.remove_rt(rt)
for sibling_app in self.token_cache.find( # Remove siblings' ATs
TokenCache.CredentialType.APP_METADATA,
query={"family_id": app_metadata.get["family_id"]}):
for at in self.token_cache.find( # Remove ATs, regardless of realm
TokenCache.CredentialType.ACCESS_TOKEN, query=dict(
owned_by_account, client_id=sibling_app["client_id"])):
self.token_cache.remove_at(at) # TODO
for at in self.token_cache.find( # Remove ATs, regardless of realm
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_account):
self.token_cache.remove_at(at)
for idt in self.token_cache.find( # Remove IDTs, regardless of realm
TokenCache.CredentialType.ID_TOKEN, query=owned_by_account):
self.token_cache.remove_idt(idt)
for a in self.token_cache.find( # Remove Accounts, regardless of realm
TokenCache.CredentialType.ACCOUNT, query=owned_by_account):
self.token_cache.remove_account(a)

def acquire_token_silent(
self,
Expand Down
85 changes: 63 additions & 22 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,9 @@ def add(self, event, now=None):
with self._lock:

if access_token:
key = "-".join([
home_account_id or "",
environment or "",
self.CredentialType.ACCESS_TOKEN,
event.get("client_id", ""),
realm or "",
target,
]).lower()
key = self._build_at_key(
home_account_id, environment, event.get("client_id", ""),
realm, target)
now = time.time() if now is None else now
expires_in = response.get("expires_in", 3599)
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
Expand All @@ -110,11 +105,7 @@ def add(self, event, now=None):
if client_info:
decoded_id_token = json.loads(
base64decode(id_token.split('.')[1])) if id_token else {}
key = "-".join([
home_account_id or "",
environment or "",
realm or "",
]).lower()
key = self._build_account_key(home_account_id, environment, realm)
self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = {
"home_account_id": home_account_id,
"environment": environment,
Expand All @@ -129,14 +120,8 @@ def add(self, event, now=None):
}

if id_token:
key = "-".join([
home_account_id or "",
environment or "",
self.CredentialType.ID_TOKEN,
event.get("client_id", ""),
realm or "",
"" # Albeit irrelevant, schema requires an empty scope here
]).lower()
key = self._build_idt_key(
home_account_id, environment, event.get("client_id", ""), realm)
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
"credential_type": self.CredentialType.ID_TOKEN,
"secret": id_token,
Expand Down Expand Up @@ -178,7 +163,7 @@ def _build_appmetadata_key(environment, client_id):
def _build_rt_key(
cls,
home_account_id=None, environment=None, client_id=None, target=None,
**ignored):
**ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
Expand All @@ -189,17 +174,73 @@ def _build_rt_key(
]).lower()

def remove_rt(self, rt_item):
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
key = self._build_rt_key(**rt_item)
with self._lock:
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None)

def update_rt(self, rt_item, new_rt):
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
key = self._build_rt_key(**rt_item)
with self._lock:
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
rt["secret"] = new_rt

@classmethod
def _build_at_key(cls,
home_account_id=None, environment=None, client_id=None,
realm=None, target=None, **ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
cls.CredentialType.ACCESS_TOKEN,
client_id,
realm or "",
target or "",
]).lower()

def remove_at(self, at_item):
assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
key = self._build_at_key(**at_item)
with self._lock:
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {}).pop(key, None)

@classmethod
def _build_idt_key(cls,
home_account_id=None, environment=None, client_id=None, realm=None,
**ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
cls.CredentialType.ID_TOKEN,
client_id or "",
realm or "",
"" # Albeit irrelevant, schema requires an empty scope here
]).lower()

def remove_idt(self, idt_item):
assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
key = self._build_idt_key(**idt_item)
with self._lock:
self._cache.setdefault(self.CredentialType.ID_TOKEN, {}).pop(key, None)

@classmethod
def _build_account_key(cls,
home_account_id=None, environment=None, realm=None,
**ignored_payload_from_a_real_entry):
return "-".join([
home_account_id or "",
environment or "",
realm or "",
]).lower()

def remove_account(self, account_item):
assert "authority_type" in account_item
key = self._build_account_key(**account_item)
with self._lock:
self._cache.setdefault(self.CredentialType.ACCOUNT, {}).pop(key, None)


class SerializableTokenCache(TokenCache):
"""This serialization can be a starting point to implement your own persistence.
Expand Down
30 changes: 30 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def setUp(self):
"scope": self.scopes,
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
"response": TokenCacheTestCase.build_response(
access_token="Siblings won't share AT. test_remove_account() will.",
id_token=TokenCacheTestCase.build_id_token(),
uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"),
}) # The add(...) helper populates correct home_account_id for future searching

Expand Down Expand Up @@ -239,6 +241,34 @@ def tester(url, data=None, **kwargs):

# Will not test scenario of app leaving family. Per specs, it won't happen.

def test_get_remove_account(self):
logger.debug("%s.cache = %s", self.id(), self.cache.serialize())
app = ClientApplication(
"family_app_2", authority=self.authority_url, token_cache=self.cache)
account = app.get_accounts()[0]
mine = {"home_account_id": account["home_account_id"]}

self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ID_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ACCOUNT, query=mine))

app.remove_account(account)

self.assertEqual([], self.cache.find(
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.ID_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.ACCOUNT, query=mine))


class TestClientApplicationForAuthorityMigration(unittest.TestCase):

@classmethod
Expand Down