Skip to content
Merged

FOCI #957

Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Token cache changes for saving AppMetadada and FRT
  • Loading branch information
bgavrilMS committed Mar 16, 2019
commit 36fd99b544d958b18346e65cf265bbb5323cf57e
2 changes: 1 addition & 1 deletion src/Microsoft.Identity.Client/Cache/CacheSessionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public Task<MsalRefreshTokenCacheItem> FindRefreshTokenAsync()

public Task<bool?> IsAppFociMemberAsync(string familyId)
{
return TokenCacheInternal.CheckAppIsFamilyMemberAsync(_requestParams, familyId);
return TokenCacheInternal.IsFociMemberAsync(_requestParams, familyId);
}
}
}
14 changes: 13 additions & 1 deletion src/Microsoft.Identity.Client/ITokenCacheInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ Tuple<MsalAccessTokenCacheItem, MsalIdTokenCacheItem> SaveTokenResponse(

Task<MsalAccessTokenCacheItem> FindAccessTokenAsync(AuthenticationRequestParameters authenticationRequestParameters);
MsalIdTokenCacheItem GetIdTokenCacheItem(MsalIdTokenCacheKey getIdTokenItemKey, RequestContext requestContext);
Task<MsalRefreshTokenCacheItem> FindRefreshTokenAsync(AuthenticationRequestParameters authenticationRequestParameters);

/// <summary>
/// Returns a RT for the request. If familyId is specified, it tries to return the FRT.
/// </summary>
Task<MsalRefreshTokenCacheItem> FindRefreshTokenAsync(
AuthenticationRequestParameters authenticationRequestParameters,
string familyId = null);

void SetIosKeychainSecurityGroup(string securityGroup);

Expand All @@ -68,6 +74,12 @@ Tuple<MsalAccessTokenCacheItem, MsalIdTokenCacheItem> SaveTokenResponse(
IEnumerable<MsalIdTokenCacheItem> GetAllIdTokens(bool filterByClientId);
IEnumerable<MsalAccountCacheItem> GetAllAccounts();

/// <summary>
/// FOCI - check in the app metadata to see if the app is part of the family
/// </summary>
/// <returns>null if unkown, true or false if app metadata has details</returns>
Task<bool?> IsFociMemberAsync(AuthenticationRequestParameters authenticationRequestParameters, string familyId);

void ClearAdalCache();
void ClearMsalCache();
void Clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,13 @@ public void Clear()
// app metadata isn't removable
}

public MsalAppMetadataCacheItem ReadAppMetadata(MsalAppMetadataCacheKey appMetadataKey)
public MsalAppMetadataCacheItem GetAppMetadata(MsalAppMetadataCacheKey appMetadataKey)
{
if (_appMetadataDictionary.TryGetValue(appMetadataKey.ToString(), out var cacheItem))
{
return cacheItem;
}
return null;
}

public void WriteAppMetadata(MsalAppMetadataCacheItem appMetadata)
{
_appMetadataDictionary[appMetadata.GetKey().ToString()] = appMetadata;
}

}
}
152 changes: 110 additions & 42 deletions src/Microsoft.Identity.Client/TokenCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ internal void SetServiceBundle(IServiceBundle serviceBundle)
/// <param name="serviceBundle"></param>
/// <param name="legacyCachePersistenceForTest"></param>
internal TokenCache(IServiceBundle serviceBundle, ILegacyCachePersistence legacyCachePersistenceForTest)
: this(serviceBundle.PlatformProxy)
{
ServiceBundle = serviceBundle;
_accessor = ServiceBundle.PlatformProxy.CreateTokenCacheAccessor();
SetServiceBundle(serviceBundle);

LegacyCachePersistence = legacyCachePersistenceForTest;
_defaultTokenCacheBlobStorage = ServiceBundle.PlatformProxy.CreateTokenCacheBlobStorage();
}

/// <summary>
Expand All @@ -140,6 +140,7 @@ internal TokenCache(IServiceBundle serviceBundle, ILegacyCachePersistence legacy

internal string ClientId => ServiceBundle.Config.ClientId;

#region Notifications
/// <summary>
/// Notification method called before any library method accesses the cache.
/// </summary>
Expand Down Expand Up @@ -208,6 +209,8 @@ internal void OnBeforeWrite(TokenCacheNotificationArgs args)
BeforeWrite?.Invoke(args);
}

#endregion

Tuple<MsalAccessTokenCacheItem, MsalIdTokenCacheItem> ITokenCacheInternal.SaveTokenResponse(
AuthenticationRequestParameters requestParams,
MsalTokenResponse response)
Expand Down Expand Up @@ -274,23 +277,26 @@ Tuple<MsalAccessTokenCacheItem, MsalIdTokenCacheItem> ITokenCacheInternal.SaveTo

_accessor.SaveAccessToken(msalAccessTokenCacheItem);

if (idToken != null)
{
_accessor.SaveIdToken(msalIdTokenCacheItem);

var msalAccountCacheItem = new MsalAccountCacheItem(preferredEnvironmentHost, response, preferredUsername, tenantId);

_accessor.SaveAccount(msalAccountCacheItem);
}
if (idToken != null)
{
_accessor.SaveIdToken(msalIdTokenCacheItem);
var msalAccountCacheItem = new MsalAccountCacheItem(preferredEnvironmentHost, response, preferredUsername, tenantId);
_accessor.SaveAccount(msalAccountCacheItem);
}

// if server returns the refresh token back, save it in the cache.
if (response.RefreshToken != null)
// if server returns the refresh token back, save it in the cache.
if (response.RefreshToken != null)
{
msalRefreshTokenCacheItem = new MsalRefreshTokenCacheItem(preferredEnvironmentHost, requestParams.ClientId, response);
if (!_featureFlags.IsFociEnabled)
{
msalRefreshTokenCacheItem = new MsalRefreshTokenCacheItem(preferredEnvironmentHost, requestParams.ClientId, response);
requestParams.RequestContext.Logger.Info("Saving RT in cache...");
_accessor.SaveRefreshToken(msalRefreshTokenCacheItem);
msalRefreshTokenCacheItem.FamilyId = null;
}

requestParams.RequestContext.Logger.Info("Saving RT in cache...");
_accessor.SaveRefreshToken(msalRefreshTokenCacheItem);
}

UpdateAppMetadata(requestParams.ClientId, preferredEnvironmentHost, response.FamilyId);

// save RT in ADAL cache for public clients
Expand Down Expand Up @@ -560,7 +566,9 @@ private string GetAccessTokenExpireLogMessageContent(MsalAccessTokenCacheItem ms
msalAccessTokenCacheItem.ExtendedExpiresOn);
}

async Task<MsalRefreshTokenCacheItem> ITokenCacheInternal.FindRefreshTokenAsync(AuthenticationRequestParameters requestParams)
async Task<MsalRefreshTokenCacheItem> ITokenCacheInternal.FindRefreshTokenAsync(
AuthenticationRequestParameters requestParams,
string familyId)
{
using (ServiceBundle.TelemetryManager.CreateTelemetryHelper(requestParams.RequestContext.TelemetryRequestId, requestParams.RequestContext.ClientId,
new CacheEvent(CacheEvent.TokenCacheLookup) { TokenType = CacheEvent.TokenTypes.RT }))
Expand Down Expand Up @@ -590,30 +598,24 @@ async Task<MsalRefreshTokenCacheItem> ITokenCacheInternal.FindRefreshTokenAsync(
Account = requestParams.Account
};

MsalRefreshTokenCacheKey key = new MsalRefreshTokenCacheKey(
preferredEnvironmentHost, requestParams.ClientId, requestParams.Account?.HomeAccountId?.Identifier);
// make sure to check preferredEnvironmentHost first
var allEnvAliases = new List<string>() { preferredEnvironmentHost };
allEnvAliases.AddRange(environmentAliases);

var keysAcrossEnvs = allEnvAliases.Select(ea => new MsalRefreshTokenCacheKey(
ea,
requestParams.ClientId,
requestParams.Account?.HomeAccountId?.Identifier,
familyId));


OnBeforeAccess(args);
try
{
MsalRefreshTokenCacheItem msalRefreshTokenCacheItem = _accessor.GetRefreshToken(key);

// trying to find rt by authority aliases
if (msalRefreshTokenCacheItem == null)
{
var refreshTokens = _accessor.GetAllRefreshTokens();

foreach (var refreshToken in refreshTokens)
{
if (refreshToken.ClientId.Equals(requestParams.ClientId, StringComparison.OrdinalIgnoreCase) &&
environmentAliases.Contains(refreshToken.Environment) &&
requestParams.Account?.HomeAccountId?.Identifier == refreshToken.HomeAccountId)
{
msalRefreshTokenCacheItem = refreshToken;
continue;
}
}
}
// Try to load from all env aliases, but stop at the first valid one
MsalRefreshTokenCacheItem msalRefreshTokenCacheItem = keysAcrossEnvs
.Select(key => _accessor.GetRefreshToken(key))
.FirstOrDefault(item => item != null);

requestParams.RequestContext.Logger.Info("Refresh token found in the cache? - " + (msalRefreshTokenCacheItem != null));

Expand Down Expand Up @@ -646,6 +648,55 @@ async Task<MsalRefreshTokenCacheItem> ITokenCacheInternal.FindRefreshTokenAsync(
}
}

async Task<bool?> ITokenCacheInternal.IsFociMemberAsync(AuthenticationRequestParameters requestParams, string familyId)
{
var logger = requestParams.RequestContext.Logger;
if (requestParams?.AuthorityInfo?.CanonicalAuthority == null)
{
logger.Warning("No authority details, can't check app metadta. Returning unkown");
return null;
}

var instanceDiscoveryMetadataEntry = await GetCachedOrDiscoverAuthorityMetaDataAsync(
requestParams.AuthorityInfo.CanonicalAuthority,
requestParams.RequestContext).ConfigureAwait(false);

var environmentAliases = GetEnvironmentAliases(
requestParams.AuthorityInfo.CanonicalAuthority,
instanceDiscoveryMetadataEntry);

TokenCacheNotificationArgs args = new TokenCacheNotificationArgs
{
TokenCache = this,
ClientId = ClientId,
Account = requestParams?.Account,
HasStateChanged = false
};

//TODO: bogavril - is the env ok here? Can I cache it or pass it in?
MsalAppMetadataCacheItem appMetadata;
lock (LockObject)
{
OnBeforeAccess(args);

appMetadata =
environmentAliases
.Select(env => _accessor.GetAppMetadata(new MsalAppMetadataCacheKey(ClientId, env)))
.FirstOrDefault(item => item != null);

OnAfterAccess(args);
}

if (appMetadata == null)
{
logger.Warning("No app metadata found. Returning unkown");
return null;
}

return appMetadata.FamilyId == familyId;
}


/// <inheritdoc />
public void SetIosKeychainSecurityGroup(string securityGroup)
{
Expand Down Expand Up @@ -856,6 +907,8 @@ IEnumerable<MsalAccountCacheItem> ITokenCacheInternal.GetAllAccounts()
}
}


#region Removal methods
void ITokenCacheInternal.RemoveAccount(IAccount account, RequestContext requestContext)
{
lock (LockObject)
Expand Down Expand Up @@ -901,7 +954,9 @@ void ITokenCacheInternal.RemoveMsalAccount(IAccount account, RequestContext requ
// adalv3 account
return;
}
var allRefreshTokens = ((ITokenCacheInternal)this).GetAllRefreshTokens(true)

// Delete ALL refresh tokens associated with this account
var allRefreshTokens = ((ITokenCacheInternal)this).GetAllRefreshTokens(false)
.Where(item => item.HomeAccountId.Equals(account.HomeAccountId.Identifier, StringComparison.OrdinalIgnoreCase))
.ToList();

Expand All @@ -911,7 +966,7 @@ void ITokenCacheInternal.RemoveMsalAccount(IAccount account, RequestContext requ
}

requestContext.Logger.Info("Deleted refresh token count - " + allRefreshTokens.Count);
IList<MsalAccessTokenCacheItem> allAccessTokens = ((ITokenCacheInternal)this).GetAllAccessTokens(true)
IList<MsalAccessTokenCacheItem> allAccessTokens = ((ITokenCacheInternal)this).GetAllAccessTokens(false)
.Where(item => item.HomeAccountId.Equals(account.HomeAccountId.Identifier, StringComparison.OrdinalIgnoreCase))
.ToList();
foreach (MsalAccessTokenCacheItem accessTokenCacheItem in allAccessTokens)
Expand All @@ -921,7 +976,7 @@ void ITokenCacheInternal.RemoveMsalAccount(IAccount account, RequestContext requ

requestContext.Logger.Info("Deleted access token count - " + allAccessTokens.Count);

var allIdTokens = ((ITokenCacheInternal)this).GetAllIdTokens(true)
var allIdTokens = ((ITokenCacheInternal)this).GetAllIdTokens(false)
.Where(item => item.HomeAccountId.Equals(account.HomeAccountId.Identifier, StringComparison.OrdinalIgnoreCase))
.ToList();
foreach (MsalIdTokenCacheItem idTokenCacheItem in allIdTokens)
Expand All @@ -930,6 +985,13 @@ void ITokenCacheInternal.RemoveMsalAccount(IAccount account, RequestContext requ
}

requestContext.Logger.Info("Deleted Id token count - " + allIdTokens.Count);

((ITokenCacheInternal)this).GetAllAccounts()
.Where(item => item.HomeAccountId.Equals(account.HomeAccountId.Identifier, StringComparison.OrdinalIgnoreCase) &&
item.PreferredUsername.Equals(account.Username, StringComparison.OrdinalIgnoreCase))
.ToList()
.ForEach(accItem => _accessor.DeleteAccount(accItem.GetKey()));

}

internal void RemoveAdalUser(IAccount account)
Expand Down Expand Up @@ -984,14 +1046,18 @@ void ITokenCacheInternal.ClearMsalCache()
_accessor.Clear();
}

#endregion

#region Serialization

private bool UserHasConfiguredBlobSerialization()
{
return _userConfiguredBeforeAccess != null ||
_userConfiguredBeforeAccess != null ||
_userConfiguredBeforeWrite != null;
}

#if !ANDROID_BUILDTIME && !iOS_BUILDTIME
#if !ANDROID_BUILDTIME && !iOS_BUILDTIME

/// <summary>
/// Sets a delegate to be notified before any library method accesses the cache. This gives an option to the
Expand Down Expand Up @@ -1259,7 +1325,9 @@ private static void GuardOnMobilePlatforms()
"For more details about custom token cache serialization, visit https://aka.ms/msal-net-serialization");
#endif
}
#endif // !ANDROID_BUILDTIME && !iOS_BUILDTIME


#endif // !ANDROID_BUILDTIME && !iOS_BUILDTIME
#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ internal void PopulateCacheForClientCredential(ITokenCacheAccessor accessor)
internal void PopulateCache(
ITokenCacheAccessor accessor,
string uid = MsalTestConstants.Uid,
string utid = MsalTestConstants.Utid)
string utid = MsalTestConstants.Utid,
string clientId = MsalTestConstants.ClientId)
{
MsalAccessTokenCacheItem atItem = new MsalAccessTokenCacheItem(
MsalTestConstants.ProductionPrefCacheEnvironment,
MsalTestConstants.ClientId,
clientId,
MsalTestConstants.Scope.AsSingleString(),
utid,
"",
Expand All @@ -72,8 +73,8 @@ internal void PopulateCache(
accessor.SaveAccessToken(atItem);

var idTokenCacheItem = new MsalIdTokenCacheItem(
MsalTestConstants.ProductionPrefCacheEnvironment,
MsalTestConstants.ClientId,
MsalTestConstants.ProductionPrefCacheEnvironment,
clientId,
MockHelpers.CreateIdToken(MsalTestConstants.UniqueId + "more", MsalTestConstants.DisplayableId),
MockHelpers.CreateClientInfo(uid, utid),
utid);
Expand All @@ -94,7 +95,7 @@ internal void PopulateCache(

atItem = new MsalAccessTokenCacheItem(
MsalTestConstants.ProductionPrefCacheEnvironment,
MsalTestConstants.ClientId,
clientId,
MsalTestConstants.ScopeForAnotherResource.AsSingleString(),
utid,
"",
Expand All @@ -105,7 +106,14 @@ internal void PopulateCache(
// add another access token
accessor.SaveAccessToken(atItem);

AddRefreshTokenToCache(accessor,uid, utid, MsalTestConstants.Name);
AddRefreshTokenToCache(accessor,uid, utid, clientId);

var appMetadataItem = new MsalAppMetadataCacheItem(
clientId,
MsalTestConstants.ProductionPrefCacheEnvironment,
null);

accessor.SaveAppMetadata(appMetadataItem);
}

internal void PopulateCacheWithOneAccessToken(ITokenCacheAccessor accessor)
Expand Down Expand Up @@ -149,10 +157,10 @@ internal void PopulateCacheWithOneAccessToken(ITokenCacheAccessor accessor)
AddRefreshTokenToCache(accessor, MsalTestConstants.Uid, MsalTestConstants.Utid, MsalTestConstants.Name);
}

public static void AddRefreshTokenToCache(ITokenCacheAccessor accessor, string uid, string utid, string name)
public static void AddRefreshTokenToCache(ITokenCacheAccessor accessor, string uid, string utid, string clientId = MsalTestConstants.ClientId)
{
var rtItem = new MsalRefreshTokenCacheItem
(MsalTestConstants.ProductionPrefCacheEnvironment, MsalTestConstants.ClientId, "someRT", MockHelpers.CreateClientInfo(uid, utid));
(MsalTestConstants.ProductionPrefCacheEnvironment, clientId, "someRT", MockHelpers.CreateClientInfo(uid, utid));

accessor.SaveRefreshToken(rtItem);
}
Expand All @@ -177,4 +185,4 @@ public static void AddAccountToCache(ITokenCacheAccessor accessor, string uid, s
accessor.SaveAccount(accountCacheItem);
}
}
}
}
Loading