Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,8 @@ await ValidateJWEAsync(jsonWebToken, validationParameters, currentConfiguration)
{
_telemetryClient.IncrementConfigurationRefreshRequestCounter(
validationParameters.ConfigurationManager.MetadataAddress,
TelemetryConstants.Protocols.Lkg);
TelemetryConstants.Protocols.Lkg,
TelemetryConstants.Protocols.ConfigurationSourceUnknown);

validationParameters.ConfigurationManager.RequestRefresh();
validationParameters.RefreshBeforeValidation = true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;

namespace Microsoft.IdentityModel.Protocols.Configuration
{
/// <summary>
/// Represents a configuration retrieval result.
/// </summary>
/// <typeparam name="T">The type of configuration.</typeparam>
public class ConfigurationEventHandlerResult<T> where T : class
{
/// <summary>
/// Represents a result indicating that configuration retrieval should proceed normally, with no result provided from the event handler.
/// </summary>
public static readonly ConfigurationEventHandlerResult<T> NoResult = new();

/// <summary>
/// Initializes a new instance of the <see cref="ConfigurationEventHandlerResult{T}"/> class with no result.
/// </summary>
private ConfigurationEventHandlerResult()
{
Configuration = null;
RetrievalTime = DateTimeOffset.MinValue;
}

/// <summary>
/// Initializes a new instance of the <see cref="ConfigurationEventHandlerResult{T}"/> class.
/// </summary>
/// <param name="configuration">The configuration retrieved.</param>
/// <param name="retrievalTime"> The time when the configuration was originally retrieved (UTC).</param>
/// <remarks>
/// Setting a <paramref name="configuration"/> on the <see cref="ConfigurationEventHandlerResult{T}"/> skips the existing
/// configuration retrieval process and sets <paramref name="configuration"/> as a current valid configuration.
/// </remarks>
public ConfigurationEventHandlerResult(T configuration, DateTimeOffset retrievalTime)
{
Configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
RetrievalTime = retrievalTime;
}

/// <summary>
/// Gets or sets the configuration.
/// </summary>
public T Configuration { get; }

/// <summary>
/// Gets or sets the time when the configuration was originally retrieved in UTC.
/// </summary>
/// <remarks>
/// This property will be set to <see cref="DateTimeOffset.MinValue"/> for <see cref="NoResult"/>.
/// </remarks>
public DateTimeOffset RetrievalTime { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
using System.Threading.Tasks;
using Microsoft.IdentityModel.Logging;
using Microsoft.IdentityModel.Protocols.Configuration;
using Microsoft.IdentityModel.Tokens;
using Microsoft.IdentityModel.Telemetry;
using Microsoft.IdentityModel.Tokens;

namespace Microsoft.IdentityModel.Protocols
{
Expand Down Expand Up @@ -41,6 +41,11 @@ public partial class ConfigurationManager<T> : BaseConfigurationManager, IConfig
internal TimeProvider TimeProvider = TimeProvider.System;
internal ITelemetryClient TelemetryClient = new TelemetryClient();

/// <summary>
/// Gets or sets the optional configuration event handler.
/// </summary>
public IConfigurationEventHandler<T> ConfigurationEventHandler { get; set; }

/// <summary>
/// Instantiates a new <see cref="ConfigurationManager{T}"/> that manages automatic and controls refreshing on configuration data.
/// </summary>
Expand Down Expand Up @@ -135,6 +140,25 @@ public ConfigurationManager(string metadataAddress, IConfigurationRetriever<T> c
_configValidator = configValidator;
}

/// <summary>
/// Instantiates a new <see cref="ConfigurationManager{T}"/> with configuration validator that manages automatic and controls refreshing on configuration data.
/// </summary>
/// <param name="metadataAddress">The address to obtain configuration.</param>
/// <param name="configRetriever">The <see cref="IConfigurationRetriever{T}"/>.</param>
/// <param name="docRetriever">The <see cref="IDocumentRetriever"/> that reaches out to obtain the configuration.</param>
/// <param name="configValidator">The <see cref="IConfigurationValidator{T}"/>.</param>
/// <param name="lkgCacheOptions">The <see cref="LastKnownGoodConfigurationCacheOptions"/>.</param>
/// <param name="configurationEventHandler">The <see cref="IConfigurationEventHandler{T}"/> that handles configuration events.</param>
/// <exception cref="ArgumentNullException">If 'configValidator' is null.</exception>
public ConfigurationManager(string metadataAddress, IConfigurationRetriever<T> configRetriever, IDocumentRetriever docRetriever, IConfigurationValidator<T> configValidator, LastKnownGoodConfigurationCacheOptions lkgCacheOptions, IConfigurationEventHandler<T> configurationEventHandler)
: this(metadataAddress, configRetriever, docRetriever, configValidator, lkgCacheOptions)
{
if (configurationEventHandler == null)
throw LogHelper.LogArgumentNullException(nameof(configurationEventHandler));

ConfigurationEventHandler = configurationEventHandler;
}

/// <summary>
/// Obtains an updated version of Configuration.
/// </summary>
Expand Down Expand Up @@ -203,6 +227,25 @@ private async Task<T> GetConfigurationNonBlockingAsync(CancellationToken cancel)

try
{
// Check if event handler can provide configuration.
// If provided configuration is valid, skip regular retriaval process and update current configuration.
if (ConfigurationEventHandler != null)
{
var configurationRetrieved = await HandleBeforeRetrieveAsync(cancel).ConfigureAwait(false);

// replicate the behavior of successful retrieval from endpoint
if (configurationRetrieved != null && configurationRetrieved.Configuration != null)
{
TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
TelemetryConstants.Protocols.FirstRefresh,
TelemetryConstants.Protocols.ConfigurationSourceHandler);

UpdateConfiguration(configurationRetrieved.Configuration, configurationRetrieved.RetrievalTime);
return _currentConfiguration;
}
}

// Don't use the individual CT here, this is a shared operation that shouldn't be affected by an individual's cancellation.
// The transport should have its own timeouts, etc.
T configuration = await _configRetriever.GetConfigurationAsync(
Expand All @@ -227,9 +270,10 @@ private async Task<T> GetConfigurationNonBlockingAsync(CancellationToken cancel)

TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
TelemetryConstants.Protocols.FirstRefresh);
TelemetryConstants.Protocols.FirstRefresh,
TelemetryConstants.Protocols.ConfigurationSourceRetriever);

UpdateConfiguration(configuration);
UpdateConfiguration(configuration, TimeProvider.GetUtcNow());
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception ex)
Expand All @@ -238,6 +282,7 @@ private async Task<T> GetConfigurationNonBlockingAsync(CancellationToken cancel)
TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
TelemetryConstants.Protocols.FirstRefresh,
TelemetryConstants.Protocols.ConfigurationSourceRetriever,
ex);

LogHelper.LogExceptionMessage(
Expand All @@ -260,9 +305,10 @@ private async Task<T> GetConfigurationNonBlockingAsync(CancellationToken cancel)
{
TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
TelemetryConstants.Protocols.Automatic);
TelemetryConstants.Protocols.Automatic,
TelemetryConstants.Protocols.ConfigurationSourceUnknown);

_ = Task.Run(UpdateCurrentConfiguration, CancellationToken.None);
_ = Task.Run(UpdateCurrentConfigurationAsync, CancellationToken.None);
}
}

Expand All @@ -285,25 +331,40 @@ private async Task<T> GetConfigurationNonBlockingAsync(CancellationToken cancel)
/// The Caller should first check the state checking state using:
/// if (Interlocked.CompareExchange(ref _configurationRetrieverState, ConfigurationRetrieverRunning, ConfigurationRetrieverIdle) == ConfigurationRetrieverIdle).
/// </summary>
private void UpdateCurrentConfiguration()
private async Task UpdateCurrentConfigurationAsync()
{
long startTimestamp = TimeProvider.GetTimestamp();

try
{
T configuration = _configRetriever.GetConfigurationAsync(
// Check if event handler can provide configuration
// If provided configuration is valid, skip regular retriaval process and update current configuration.
if (ConfigurationEventHandler != null)
{
var configurationRetrieved = await HandleBeforeRetrieveAsync().ConfigureAwait(false);
if (configurationRetrieved != null && configurationRetrieved.Configuration != null)
{
UpdateConfiguration(configurationRetrieved.Configuration, configurationRetrieved.RetrievalTime);

_onBackgroundTaskFinish?.Invoke();
return;
}
}

T configuration = await _configRetriever.GetConfigurationAsync(
MetadataAddress,
_docRetriever,
CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult();
CancellationToken.None).ConfigureAwait(false);

var elapsedTime = TimeProvider.GetElapsedTime(startTimestamp);
TelemetryClient.LogConfigurationRetrievalDuration(
MetadataAddress,
TelemetryConstants.Protocols.ConfigurationSourceRetriever,
elapsedTime);

if (_configValidator == null)
{
UpdateConfiguration(configuration);
UpdateConfiguration(configuration, TimeProvider.GetUtcNow());
}
else
{
Expand All @@ -316,7 +377,7 @@ private void UpdateCurrentConfiguration()
LogMessages.IDX20810,
result.ErrorMessage)));
else
UpdateConfiguration(configuration);
UpdateConfiguration(configuration, TimeProvider.GetUtcNow());
}
}
#pragma warning disable CA1031 // Do not catch general exception types
Expand All @@ -325,6 +386,7 @@ private void UpdateCurrentConfiguration()
var elapsedTime = TimeProvider.GetElapsedTime(startTimestamp);
TelemetryClient.LogConfigurationRetrievalDuration(
MetadataAddress,
TelemetryConstants.Protocols.ConfigurationSourceRetriever,
elapsedTime,
ex);

Expand All @@ -345,11 +407,33 @@ private void UpdateCurrentConfiguration()
_onBackgroundTaskFinish?.Invoke();
}

private void UpdateConfiguration(T configuration)
private void UpdateConfiguration(T configuration, DateTimeOffset retrievalTime)
{
_currentConfiguration = configuration;
_syncAfter = DateTimeUtil.Add(TimeProvider.GetUtcNow().UtcDateTime, AutomaticRefreshInterval +
_syncAfter = DateTimeUtil.Add(retrievalTime.UtcDateTime, AutomaticRefreshInterval +
TimeSpan.FromSeconds(new Random().Next((int)AutomaticRefreshInterval.TotalSeconds / 20)));

if (ConfigurationEventHandler != null)
{
// fire-and-forget an after update task
_ = Task.Run(async () =>
{
try
{
await ConfigurationEventHandler.AfterUpdateAsync(MetadataAddress, configuration).ConfigureAwait(false);
}
catch (Exception ex)
{
LogHelper.LogExceptionMessage(
new InvalidOperationException(
LogHelper.FormatInvariant(
LogMessages.IDX20813,
LogHelper.MarkAsNonPII(MetadataAddress ?? "null"),
ex),
ex));
}
});
}
}

/// <summary>
Expand Down Expand Up @@ -391,17 +475,75 @@ private void RequestRefreshBackgroundThread()
{
TelemetryClient.IncrementConfigurationRefreshRequestCounter(
MetadataAddress,
TelemetryConstants.Protocols.Manual);
TelemetryConstants.Protocols.Manual,
TelemetryConstants.Protocols.ConfigurationSourceUnknown);

_isFirstRefreshRequest = false;
if (Interlocked.CompareExchange(ref _configurationRetrieverState, ConfigurationRetrieverRunning, ConfigurationRetrieverIdle) == ConfigurationRetrieverIdle)
{
_ = Task.Run(UpdateCurrentConfiguration, CancellationToken.None);
_ = Task.Run(UpdateCurrentConfigurationAsync, CancellationToken.None);
_lastRequestRefresh = now;
}
}
}

private async Task<ConfigurationEventHandlerResult<T>> HandleBeforeRetrieveAsync(CancellationToken cancellationToken = default)
{
long beforeHandlerTimestamp = TimeProvider.GetTimestamp();

try
{
var handlerResult = await ConfigurationEventHandler.BeforeRetrieveAsync(MetadataAddress, cancellationToken).ConfigureAwait(false);
if (handlerResult != null && handlerResult.Configuration != null)
{
var handlerElapsedTime = TimeProvider.GetElapsedTime(beforeHandlerTimestamp);
TelemetryClient.LogConfigurationRetrievalDuration(
MetadataAddress,
TelemetryConstants.Protocols.ConfigurationSourceHandler,
handlerElapsedTime);

// Validate configuration from handler
if (_configValidator != null)
{
ConfigurationValidationResult result = _configValidator.Validate(handlerResult.Configuration);
if (!result.Succeeded)
{
// Just log the error and proceed to fetch from endpoint
LogHelper.LogExceptionMessage(
new InvalidConfigurationException(
LogHelper.FormatInvariant(
LogMessages.IDX20812,
result.ErrorMessage)));

return ConfigurationEventHandlerResult<T>.NoResult;
}
}

// No validator configured, return configuration
return handlerResult;
}
}
catch (Exception ex)
{
var handlerErrorElapsedTime = TimeProvider.GetElapsedTime(beforeHandlerTimestamp);
TelemetryClient.LogConfigurationRetrievalDuration(
MetadataAddress,
TelemetryConstants.Protocols.ConfigurationSourceHandler,
handlerErrorElapsedTime,
ex);

LogHelper.LogExceptionMessage(
new InvalidOperationException(
LogHelper.FormatInvariant(
LogMessages.IDX20811,
LogHelper.MarkAsNonPII(MetadataAddress ?? "null"),
ex),
ex));
}

return ConfigurationEventHandlerResult<T>.NoResult;
}

/// <summary>
/// 12 hours is the default time interval that afterwards, <see cref="GetBaseConfigurationAsync(CancellationToken)"/> will obtain new configuration.
/// </summary>
Expand Down
Loading
Loading